├── .gitignore
├── .idea
│   ├── Defect_Detection.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── workspace.xml
├── Detect
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── inference.cpython-36.pyc
│   │   ├── mynet.cpython-36.pyc
│   │   ├── submit.cpython-36.pyc
│   │   ├── train.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── config.py
│   ├── evaluate.py
│   ├── inference.py
│   ├── model_merge.py
│   ├── mynet.py
│   ├── phone_seg.py
│   ├── predict.py
│   ├── submit.py
│   ├── train.py
│   └── utils.py
├── Image_preprocessing
│   ├── __init__.py
│   ├── create_data.py
│   ├── image2npy.py
│   ├── image_augmentation.py
│   ├── image_crop.py
│   ├── image_segmentation.py
│   └── json2npy.py
├── Siamese
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── inference.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── bin_classify.py
│   ├── config.py
│   ├── data_cache.py
│   ├── evaluate.py
│   ├── inference.py
│   ├── inference_2.py
│   ├── predict.py
│   ├── train.py
│   ├── train_2.py
│   └── utils.py
├── Siamese_multi
│   └── __pycache__
│       ├── __init__.cpython-36.pyc
│       └── inference.cpython-36.pyc
├── camera
│   ├── DllTest.dll
│   ├── __init__.py
│   └── test.py
├── mv
│   ├── __init__.py
│   ├── base.py
│   ├── cig_detection.py
│   └── evaluate.py
├── nets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── alexnet.cpython-36.pyc
│   │   ├── inception_resnet_v2.cpython-36.pyc
│   │   ├── inception_utils.cpython-36.pyc
│   │   ├── inception_v4.cpython-36.pyc
│   │   ├── mobilenet_v1.cpython-36.pyc
│   │   ├── resnet_utils.cpython-36.pyc
│   │   ├── resnet_v2.cpython-36.pyc
│   │   └── vgg.cpython-36.pyc
│   ├── alexnet.py
│   ├── alexnet_test.py
│   ├── cifarnet.py
│   ├── cyclegan.py
│   ├── cyclegan_test.py
│   ├── dcgan.py
│   ├── dcgan_test.py
│   ├── inception.py
│   ├── inception_resnet_v2.py
│   ├── inception_resnet_v2_test.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   ├── inception_v1_test.py
│   ├── inception_v2.py
│   ├── inception_v2_test.py
│   ├── inception_v3.py
│   ├── inception_v3_test.py
│   ├── inception_v4.py
│   ├── inception_v4_test.py
│   ├── lenet.py
│   ├── mobilenet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── conv_blocks.py
│   │   ├── madds_top1_accuracy.png
│   │   ├── mnet_v1_vs_v2_pixel1_latency.png
│   │   ├── mobilenet.py
│   │   ├── mobilenet_example.ipynb
│   │   ├── mobilenet_v2.py
│   │   └── mobilenet_v2_test.py
│   ├── mobilenet_v1.md
│   ├── mobilenet_v1.png
│   ├── mobilenet_v1.py
│   ├── mobilenet_v1_eval.py
│   ├── mobilenet_v1_test.py
│   ├── mobilenet_v1_train.py
│   ├── nasnet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── nasnet.py
│   │   ├── nasnet_test.py
│   │   ├── nasnet_utils.py
│   │   ├── nasnet_utils_test.py
│   │   ├── pnasnet.py
│   │   └── pnasnet_test.py
│   ├── nets_factory.py
│   ├── nets_factory_test.py
│   ├── overfeat.py
│   ├── overfeat_test.py
│   ├── pix2pix.py
│   ├── pix2pix_test.py
│   ├── resnet_utils.py
│   ├── resnet_v1.py
│   ├── resnet_v1_test.py
│   ├── resnet_v2.py
│   ├── resnet_v2_test.py
│   ├── vgg.py
│   └── vgg_test.py
├── ocr
│   ├── __init__.py
│   └── ocr.py
├── qt
│   ├── 1.jpg
│   ├── LogDialog.py
│   ├── MainWindow.py
│   ├── __pycache__
│   │   ├── LogDialog.cpython-36.pyc
│   │   ├── MainWindow.cpython-36.pyc
│   │   ├── predict.cpython-36.pyc
│   │   ├── predict_dl.cpython-36.pyc
│   │   └── predict_ip.cpython-36.pyc
│   ├── cv.py
│   ├── hello-world-master
│   │   ├── 1.txt
│   │   └── README.md
│   ├── log
│   │   ├── 1
│   │   └── detect_20181016_134111.csv
│   ├── main.py
│   ├── nets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── alexnet.cpython-36.pyc
│   │   │   ├── mobilenet_v1.cpython-36.pyc
│   │   │   └── vgg.cpython-36.pyc
│   │   ├── alexnet.py
│   │   └── mobilenet_v1.py
│   ├── predict_dl.py
│   ├── predict_ip.py
│   ├── src
│   │   ├── 校标.png
│   │   └── 欢迎.jpg
│   ├── test.py
│   ├── test1.jpg
│   ├── ui_database_set.ui
│   ├── ui_defect_log.ui
│   ├── ui_detect_log.ui
│   └── ui_detection.ui
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /log/
3 | /qt/weight/
4 |
--------------------------------------------------------------------------------
/.idea/Defect_Detection.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/Detect/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__init__.py
--------------------------------------------------------------------------------
/Detect/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/inference.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/inference.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/mynet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/mynet.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/submit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/submit.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect configuration file, global parameters
3 | author: 王建坤
4 | date: 2018-12-3
5 | """
6 | import numpy as np
7 | import tensorflow as tf
8 | from nets import alexnet, mobilenet_v1, vgg, inception_v4, resnet_v2, inception_resnet_v2
9 | from Detect import mynet
10 | import os
11 |
12 | # Number of classes
13 | CLASSES = 5
14 | # Image size
15 | IMG_SIZE = 224
16 | # Number of image channels (2: the standard image stacked with the test image)
17 | CHANNEL = 2
18 | # Whether the input is a non-standard image size (if so, global pooling is needed)
19 | GLOBAL_POOL = True
20 | # Whether to use the GPU
21 | USE_GPU = True
22 | if not USE_GPU:
23 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
24 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
25 | # Model name: 'Mobile' or 'My'
26 | MODEL_NAME = 'My'
27 | # Data set paths
28 | date_set_name = 'sia_cig_' # b_cig_, sia_cig_, s_sia_cig_
29 | images_path = '../data/' + date_set_name + 'data_' + str(IMG_SIZE) + '.npy'
30 | labels_path = '../data/' + date_set_name + 'label_' + str(IMG_SIZE) + '.npy'
31 |
32 |
33 |
--------------------------------------------------------------------------------
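A minimal sanity-check sketch (not a file in the repository): loading the npy pair that images_path and labels_path resolve to under the defaults above (date_set_name = 'sia_cig_', IMG_SIZE = 224), and adding the channel axis for grayscale data the way the downstream scripts do when CHANNEL == 1.

    import numpy as np

    images = np.load('../data/sia_cig_data_224.npy')   # (N, 224, 224, 2): Siamese pairs
    labels = np.load('../data/sia_cig_label_224.npy')  # (N,)
    if images.ndim == 3:                               # grayscale data (CHANNEL == 1)
        images = np.expand_dims(images, axis=3)        # -> (N, 224, 224, 1)
    print(images.shape, labels.shape)
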
/Detect/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect performance evaluation
3 | author: 王建坤
4 | date: 2018-9-10
5 | """
6 | from Detect.config import *
7 | import math
8 | from sklearn.model_selection import train_test_split
9 |
10 | BATCH_SIZE = 1
11 | IS_TRAINING = False
12 |
13 |
14 | def evaluate(model=MODEL_NAME):
15 | """
16 | Evaluate the model
17 | :param model: model name
18 | :return: none
19 | """
20 | # Per-class counts: true samples, predicted samples, correctly predicted samples
21 | sample_labels = np.zeros(CLASSES, dtype=np.uint16)
22 | pre_pos = np.zeros(CLASSES, dtype=np.uint16)
23 | true_pos = np.zeros(CLASSES, dtype=np.uint16)
24 |
25 | # Load the test set
26 | images = np.load(images_path)
27 | labels = np.load(labels_path)
28 |
29 | _, val_data, _, val_label = train_test_split(images, labels, test_size=0.2, random_state=222)
30 | # If the input is a grayscale image, add a channel dimension
31 | if CHANNEL == 1:
32 | val_data = np.expand_dims(val_data, axis=3)
33 |
34 | test_num = val_data.shape[0]
35 | # Holds the prediction results
36 | res = np.zeros(test_num, dtype=np.uint16)
37 |
38 | # Placeholders
39 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL], name="x_input")
40 | y_ = tf.placeholder(tf.uint8, [None], name="y_input")
41 |
42 | # Model save path, model name, pre-trained weight path, forward pass
43 | if model == 'Alex':
44 | log_dir = "../log/Alex"
45 | y, _ = alexnet.alexnet_v2(x,
46 | num_classes=CLASSES, # number of classes
47 | is_training=IS_TRAINING, # whether we are training
48 | dropout_keep_prob=1.0, # dropout keep probability
49 | spatial_squeeze=True, # squeeze out the singleton spatial dimensions
50 | global_pool=GLOBAL_POOL) # needed when the input is not the canonical size
51 | elif model == 'My':
52 | log_dir = "../log/My"
53 | # y = mynet.mynet_v1(x, is_training=IS_TRAINING, num_classes=CLASSES)
54 | y, _ = mobilenet_v1.mobilenet_v1(x,
55 | num_classes=CLASSES,
56 | dropout_keep_prob=1.0,
57 | is_training=IS_TRAINING,
58 | min_depth=8,
59 | depth_multiplier=1.0,
60 | conv_defs=None,
61 | prediction_fn=None,
62 | spatial_squeeze=True,
63 | reuse=None,
64 | scope='MobilenetV1',
65 | global_pool=GLOBAL_POOL)
66 | elif model == 'Mobile':
67 | log_dir = "../log/Mobile"
68 | y, _ = mobilenet_v1.mobilenet_v1(x,
69 | num_classes=CLASSES,
70 | dropout_keep_prob=1.0,
71 | is_training=IS_TRAINING,
72 | min_depth=8,
73 | depth_multiplier=1.0,
74 | conv_defs=None,
75 | prediction_fn=None,
76 | spatial_squeeze=True,
77 | reuse=None,
78 | scope='MobilenetV1',
79 | global_pool=GLOBAL_POOL)
80 | elif model == 'VGG':
81 | log_dir = "../log/VGG"
82 | y, _ = vgg.vgg_16(x,
83 | num_classes=CLASSES,
84 | is_training=IS_TRAINING,
85 | dropout_keep_prob=1.0,
86 | spatial_squeeze=True,
87 | global_pool=GLOBAL_POOL)
88 | elif model == 'Incep4':
89 | log_dir = "E:/alum/log/Incep4"
90 | y, _ = inception_v4.inception_v4(x, num_classes=CLASSES,
91 | is_training=IS_TRAINING,
92 | dropout_keep_prob=1.0,
93 | reuse=None,
94 | scope='InceptionV4',
95 | create_aux_logits=True)
96 | elif model == 'Res':
97 | log_dir = "E:/alum/log/Res"
98 | y, _ = resnet_v2.resnet_v2_50(x,
99 | num_classes=CLASSES,
100 | is_training=IS_TRAINING,
101 | global_pool=GLOBAL_POOL,
102 | output_stride=None,
103 | spatial_squeeze=True,
104 | reuse=None,
105 | scope='resnet_v2_50')
106 | else:
107 | print('Error: model name not exist')
108 | return
109 |
110 | # Prediction
111 | y_pre = tf.argmax(y, 1)
112 | saver = tf.train.Saver(tf.global_variables())
113 |
114 | with tf.Session() as sess:
115 | # Restore the model weights
116 | print('Reading checkpoints of', model)
117 | ckpt = tf.train.get_checkpoint_state(log_dir)
118 | if ckpt and ckpt.model_checkpoint_path:
119 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
120 | saver.restore(sess, ckpt.model_checkpoint_path)
121 | print('Loading success, step is %s' % global_step)
122 | else:
123 | print('Error: no checkpoint file found')
124 | return
125 |
126 | # Number of iterations for one pass over the test set
127 | num_iter = int(math.ceil(test_num / BATCH_SIZE))
128 |
129 | step = 0
130 | start = 0
131 | while step < num_iter:
132 | # Fetch one batch
133 | if step == num_iter-1:
134 | end = test_num
135 | else:
136 | end = start+BATCH_SIZE
137 | image_batch = val_data[start:end]
138 | label_batch = val_label[start:end]
139 | # Run the model and collect the predictions
140 | pres, pre = sess.run([y, y_pre], feed_dict={x: image_batch, y_: label_batch})
141 | res[start:end] = pre
142 | start += BATCH_SIZE
143 | step += 1
144 | # Compute the metrics
145 | normal_num = 0
146 | for i in range(test_num):
147 | pre_pos[res[i]] += 1
148 | sample_labels[val_label[i]] += 1
149 | if res[i] == val_label[i]:
150 | true_pos[res[i]] += 1
151 | if val_label[i] == 0:
152 | normal_num += 1
153 |
154 | precision = true_pos/pre_pos
155 | recall = true_pos/sample_labels
156 | f1 = 2*precision*recall/(precision+recall)
157 | error = (pre_pos - true_pos) / (test_num - sample_labels)
158 | print('Test sample count:', test_num)
159 | print('Per-class sample counts:', sample_labels)
160 | print('Per-class predicted counts:', pre_pos)
161 | print('Per-class correct counts:', true_pos)
162 | print('Precision:', precision)
163 | print('Recall:', recall)
164 | print('F1:', f1)
165 | print('False detection rate:', error)
166 | print('Accuracy:', np.sum(true_pos)/test_num)
167 | print('Overall miss rate:', (pre_pos[0]-true_pos[0])/(test_num-normal_num))
168 | print('Overall overkill rate:', (normal_num-true_pos[0])/normal_num)
169 |
170 |
171 | if __name__ == '__main__':
172 | evaluate()
173 |
--------------------------------------------------------------------------------
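A hedged sketch (not repository code) of the per-class statistics evaluate() prints, reproduced with numpy's bincount on a tiny hand-made example so the formulas are easy to verify; class 0 plays the role of "normal".

    import numpy as np

    CLASSES = 3
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([0, 1, 1, 1, 2, 0])

    sample_labels = np.bincount(y_true, minlength=CLASSES)   # true per-class counts
    pre_pos = np.bincount(y_pred, minlength=CLASSES)         # predicted per-class counts
    true_pos = np.bincount(y_true[y_true == y_pred], minlength=CLASSES)

    precision = true_pos / pre_pos                           # [0.5, 0.67, 1.0]
    recall = true_pos / sample_labels                        # [0.5, 1.0, 0.5]
    f1 = 2 * precision * recall / (precision + recall)
    print(precision, recall, f1)
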
/Detect/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect model network structure (built from scratch)
3 | author: 王建坤
4 | date: 2018-8-10
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 |
10 | # Convolution layer
11 | def convolution(input_tensor, conv_height, conv_width, conv_deep, x_stride, y_stride, name, padding='SAME'):
12 | with tf.variable_scope(name):
13 | channel = int(input_tensor.get_shape()[-1])
14 | weights = tf.get_variable("weights", shape=[conv_height, conv_width, channel, conv_deep],
15 | initializer=tf.truncated_normal_initializer(stddev=0.025))
16 | bias = tf.get_variable("bias", shape=[conv_deep], initializer=tf.constant_initializer(0.025))
17 | conv = tf.nn.conv2d(input_tensor, weights, strides=[1, x_stride, y_stride, 1], padding=padding)
18 | relu = tf.nn.relu(tf.nn.bias_add(conv, bias))
19 | return relu
20 |
21 |
22 | # Max-pooling layer
23 | def max_pool(input_tensor, height, width, x_stride, y_stride, name, padding="SAME"):
24 | return tf.nn.max_pool(input_tensor, ksize=[1, height, width, 1], strides=[1, x_stride, y_stride, 1],
25 | padding=padding, name=name)
26 |
27 |
28 | # Local response normalization (LRN) layer
29 | def LRN(input_tensor, R, alpha, beta, name=None, bias=1.0):
30 | return tf.nn.local_response_normalization(input_tensor, depth_radius=R, bias=bias, alpha=alpha, beta=beta,
31 | name=name)
32 |
33 |
34 | # Dropout layer
35 | def dropout(input_tensor, prob, name):
36 | return tf.nn.dropout(input_tensor, prob, name=name)
37 |
38 |
39 | # Fully connected layer
40 | def full_connect(input_tensor, in_dimension, out_dimension, relu_flag, name):
41 | with tf.variable_scope(name):
42 | weights = tf.get_variable("weights", [in_dimension, out_dimension],
43 | initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2 / in_dimension)))
44 | bias = tf.get_variable("bias", [out_dimension], initializer=tf.constant_initializer(0.05))
45 | fc = tf.matmul(input_tensor, weights) + bias
46 | if relu_flag:
47 | fc = tf.nn.relu(fc)
48 | return fc
49 |
50 |
51 | def inference(input_tensor, train):
52 | conv1 = convolution(input_tensor, 11, 11, 96, 4, 4, "conv1", padding="VALID")
53 | pool1 = max_pool(conv1, 3, 3, 2, 2, "pool1", "VALID")
54 | lrn1 = LRN(pool1, 2, 2e-05, 0.75, "lrn1")
55 | conv2 = convolution(lrn1, 5, 5, 256, 1, 1, "conv2")
56 | pool2 = max_pool(conv2, 3, 3, 2, 2, "pool2", "VALID")
57 | lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2")
58 | conv3 = convolution(lrn2, 3, 3, 384, 1, 1, "conv3")
59 | conv4 = convolution(conv3, 3, 3, 384, 1, 1, "conv4")
60 | conv5 = convolution(conv4, 3, 3, 256, 1, 1, "conv5")
61 | pool5 = max_pool(conv5, 3, 3, 2, 2, "pool5", "VALID")
62 | fcin = tf.reshape(pool5, [-1, 256 * 6 * 6])
63 | fc1 = full_connect(fcin, 256 * 6 * 6, 4096, True, "fc6")
64 | if train:
65 | fc1 = dropout(fc1, 0.8, "drop6")
66 | fc2 = full_connect(fc1, 4096, 2, True, "fc7") # 2048
67 | # if train:
68 | # fc2 = dropout(fc2, 1, "drop7")
69 | fc3 = full_connect(fc2, 2, 2, True, "fc8")
70 | return fc3
71 |
--------------------------------------------------------------------------------
/Detect/model_merge.py:
--------------------------------------------------------------------------------
1 | """
2 | Model fusion -- merge models based on their prediction results
3 | author: 王建坤
4 | date: 2018-9-29
5 | """
6 | import math
7 | import os
8 | import pandas as pd
9 | import numpy as np
10 | import tensorflow as tf
11 | from PIL import Image
12 | from nets import alexnet, vgg, inception_v4, resnet_v2
13 |
14 |
15 | BATCH_SIZE = 32
16 | IMG_SIZE1 = 224
17 | IMG_SIZE2 = 299
18 | CLASSES = 12
19 | GLOBAL_POOL = True
20 | test_dir = 'E:/dataset/alum/guangdong_round1_test_a_20180916'
21 | word_class = {'norm': 0, 'defect1': 1, 'defect2': 2, 'defect3': 3, 'defect4': 4, 'defect5': 5, 'defect6': 6,
22 | 'defect7': 7, 'defect8': 8, 'defect9': 9, 'defect10': 10, 'defect11': 11}
23 | class_label = ['norm', 'defect1', 'defect2', 'defect3', 'defect4', 'defect5',
24 | 'defect6', 'defect7', 'defect8', 'defect9', 'defect10', 'defect11']
25 |
26 |
27 | def choose_model(x, model):
28 | """
29 | 选择模型
30 | :param x:
31 | :param model:
32 | :return:
33 | """
34 | # 模型保存路径,模型名,预训练文件路径,前向传播
35 | if model == 'Alex':
36 | log_dir = "E:/alum/log/Alex"
37 | y, _ = alexnet.alexnet_v2(x,
38 | num_classes=CLASSES, # 分类的类别
39 | is_training=True, # 是否在训练
40 | dropout_keep_prob=1.0, # 保留比率
41 | spatial_squeeze=True, # 压缩掉1维的维度
42 | global_pool=GLOBAL_POOL) # 输入不是规定的尺寸时,需要global_pool
43 | elif model == 'VGG':
44 | log_dir = "E:/alum/log/VGG"
45 | y, _ = vgg.vgg_16(x,
46 | num_classes=CLASSES,
47 | is_training=True,
48 | dropout_keep_prob=1.0,
49 | spatial_squeeze=True,
50 | global_pool=GLOBAL_POOL)
51 | elif model == 'VGG2':
52 | log_dir = "E:/alum/log/VGG2"
53 | y, _ = vgg.vgg_16(x,
54 | num_classes=CLASSES,
55 | is_training=True,
56 | dropout_keep_prob=1.0,
57 | spatial_squeeze=True,
58 | global_pool=GLOBAL_POOL)
59 | elif model == 'Incep4':
60 | log_dir = "E:/alum/log/Incep4"
61 | y, _ = inception_v4.inception_v4(x, num_classes=CLASSES,
62 | is_training=True,
63 | dropout_keep_prob=1.0,
64 | reuse=None,
65 | scope='InceptionV4',
66 | create_aux_logits=True)
67 | elif model == 'Res':
68 | log_dir = "E:/alum/log/Res"
69 | y, _ = resnet_v2.resnet_v2_50(x,
70 | num_classes=CLASSES,
71 | is_training=True,
72 | global_pool=GLOBAL_POOL,
73 | output_stride=None,
74 | spatial_squeeze=True,
75 | reuse=None,
76 | scope='resnet_v2_50')
77 | else:
78 | print('Error: model name not exist')
79 | return
80 |
81 | return y, log_dir
82 |
83 |
84 | def evaluate_model_merge(model1='VGG2', model2='VGG', submit=True):
85 | """
86 | Fuse two models and evaluate the result
87 | :param model1: name of the first model
88 | :param model2: name of the second model
89 | :param submit: whether to save the result as a csv file
90 | :return: none
91 | """
92 | # Holds the prediction results of models 1 and 2
93 | image_name_list = os.listdir(test_dir)
94 | image_labels = pd.read_csv('E:/WJK_File/Python_File/Defect_Detection/Detect/test_label.csv')
95 | res, res1, res2, res_name, labels = [], [], [], [], []
96 |
97 | # model 1
98 | g1 = tf.Graph()
99 | with g1.as_default():
100 | x1 = tf.placeholder(tf.float32, [None, IMG_SIZE1, IMG_SIZE1, 3], name="x_input_1")
101 | y1, log_dir1 = choose_model(x1, model1)
102 | y1 = tf.nn.softmax(y1)
103 | saver1 = tf.train.Saver()
104 | sess1 = tf.Session(graph=g1)
105 | print("Reading checkpoints of model1")
106 | ckpt = tf.train.get_checkpoint_state(log_dir1)
107 | if ckpt and ckpt.model_checkpoint_path:
108 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
109 | saver1.restore(sess1, ckpt.model_checkpoint_path)
110 | print('Loading success, global_step is %s' % global_step)
111 | else:
112 | print('Error: no checkpoint file found')
113 | return
114 |
115 | # model 2
116 | g2 = tf.Graph()
117 | with g2.as_default():
118 | x2 = tf.placeholder(tf.float32, [None, IMG_SIZE2, IMG_SIZE2, 3], name="x_input_2")
119 | y2, log_dir2 = choose_model(x2, model2)
120 | y2 = tf.nn.softmax(y2)
121 | saver2 = tf.train.Saver()
122 | sess2 = tf.Session(graph=g2)
123 | print("Reading checkpoints of model2")
124 | ckpt = tf.train.get_checkpoint_state(log_dir2)
125 | if ckpt and ckpt.model_checkpoint_path:
126 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
127 | saver2.restore(sess2, ckpt.model_checkpoint_path)
128 | print('Loading success, global_step is %s' % global_step)
129 | else:
130 | print('Error: model2 checkpoint file not found')
131 | return
132 |
133 | for image_name in image_name_list:
134 | label = image_labels[image_labels.iloc[:, 0].isin([image_name])].iloc[0, 1]
135 | labels.append(word_class[label])
136 | image_path = os.path.join(test_dir, image_name)
137 | img = Image.open(image_path)
138 | img1 = img.resize((IMG_SIZE1, IMG_SIZE1))
139 | img1 = np.array(img1, np.float32)
140 | img1 = np.expand_dims(img1, axis=0)
141 | pre1 = sess1.run(y1, feed_dict={x1: img1})
142 | max_index1 = np.argmax(pre1)
143 | max_value1 = np.max(pre1)
144 | res1.append([max_index1, max_value1])
145 |
146 | img2 = img.resize((IMG_SIZE2, IMG_SIZE2))
147 | img2 = np.array(img2, np.float32)
148 | img2 = np.expand_dims(img2, axis=0)
149 | pre2 = sess2.run(y2, feed_dict={x2: img2})
150 | max_index2 = np.argmax(pre2)
151 | max_value2 = np.max(pre2)
152 | res2.append([max_index2, max_value2])
153 |
154 | sess1.close()
155 | sess2.close()
156 | # print(res1, '\n', res2)
157 | img_num = len(res1)
158 | for i in range(img_num):
159 | if res1[i][0] == res2[i][0]:
160 | res.append(res1[i][0])
161 | else:
162 | if res1[i][1] > res2[i][1]:
163 | res.append(res1[i][0])
164 | else:
165 | res.append(res2[i][0])
166 | print(res1[i], res2[i], res[i], labels[i])
167 | print(res, '\n', labels)
168 |
169 | # Evaluate the result
170 | pre_pos = np.zeros(12)
171 | true_pos = np.zeros(12)
172 | for i in range(img_num):
173 | pre_pos[res[i]] += 1
174 | if res[i] == labels[i]:
175 | true_pos[res[i]] += 1
176 | precision = true_pos/pre_pos
177 | print(pre_pos, '\n', true_pos, '\n', precision)
178 | print('Accuracy:', np.sum(true_pos)/img_num, '\n', 'Mean precision:', np.mean(precision))
179 |
180 | # Optionally save the result as a csv file
181 | if submit:
182 | for i in res:
183 | res_name.append(class_label[i])
184 | print('image: ', image_name_list, '\n', 'class:', res_name)
185 | # Save as a csv file with two columns: image name, predicted class
186 | data = pd.DataFrame({'0': image_name_list, '1': res_name})
187 | data.to_csv("merge.csv", index=False, header=False)
188 |
189 |
190 | if __name__ == '__main__':
191 | evaluate_model_merge()
192 |
--------------------------------------------------------------------------------
/Detect/mynet.py:
--------------------------------------------------------------------------------
1 | """
2 | Defect detection, custom network structure
3 | author: 王建坤
4 | date: 2018-12-5
5 | """
6 | import tensorflow as tf
7 | import tensorflow.contrib.slim as slim
8 | import numpy as np
9 | import time
10 |
11 | IMG_SIZE = 128
12 |
13 |
14 | def mynet_v1(inputs, num_classes=10, is_training=True, scope=None):
15 | with tf.variable_scope(scope, 'Mynet_v1', [inputs]):
16 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d], normalizer_fn=slim.batch_norm):
17 | with slim.arg_scope([slim.batch_norm], is_training=is_training):
18 | net = slim.conv2d(inputs, 16, [11, 11], stride=4, padding='VALID', scope='Conv_1')
19 | # net = slim.batch_norm(net, scope='Conv_1_bn')
20 | net = slim.conv2d(net, 32, [1, 1], stride=1, scope='Conv_2')
21 | net = slim.separable_conv2d(net, None, [5, 5], depth_multiplier=1.0, stride=3, scope='Conv_3_dw')
22 | net = slim.conv2d(net, 64, [1, 1], stride=1, scope='Conv_3_pw')
23 | net = slim.separable_conv2d(net, None, [3, 3], depth_multiplier=1.0, stride=2, scope='Conv_4_dw')
24 | net = slim.conv2d(net, 64, [1, 1], stride=1, scope='Conv_4_pw')
25 | net = slim.conv2d(net, 64, [5, 5], padding='VALID', scope='Conv_5')
26 | pre = slim.conv2d(net, num_classes, [1, 1], scope='Conv_6')
27 | # pre = tf.squeeze(pre, [1, 2], name='squeezed')
28 | # pre = slim.softmax(pre, scope='Predictions')
29 |
30 | return pre
31 |
32 |
33 | def train():
34 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1], name="x_input")
35 | y_ = tf.placeholder(tf.float32, [None, 3], name="y_input") # float labels for the cross-entropy op
36 |
37 | images = np.random.randint(255, size=(100, IMG_SIZE, IMG_SIZE, 1))
38 | labels = np.random.randint(3, size=(100, 3))
39 |
40 | y = tf.squeeze(mynet_v1(x, num_classes=3, scope=None), [1, 2]) # logits come out 1x1xC; squeeze to (N, C)
41 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_, name='entropy')
42 | loss = tf.reduce_mean(cross_entropy, name='loss')
43 | optimizer = tf.train.AdamOptimizer(0.01)
44 | train_op = optimizer.minimize(loss)
45 |
46 | init = tf.global_variables_initializer()
47 |
48 | with tf.Session() as sess:
49 | sess.run(init)
50 | step = 0
51 |
52 | while step < 10:
53 | sess.run(train_op, feed_dict={x: images, y_: labels})
54 | step += 1
55 | print(step)
56 |
57 |
58 | def predict():
59 | x = tf.placeholder(tf.float32, [None, 2000, 1000, 1], name="x_input")
60 | y = mynet_v1(x, num_classes=3, scope=None)
61 |
62 | image = np.random.randint(255, size=(1, 2000, 1000, 1))
63 |
64 | init = tf.global_variables_initializer()
65 |
66 | with tf.Session() as sess:
67 | sess.run(init)
68 | sess.run(y, feed_dict={x: image})
69 |
70 | start_time = time.clock()
71 | res = sess.run(y, feed_dict={x: image})
72 | end_time = time.clock()
73 | print('run time: ', end_time - start_time)
74 | print(res.shape)
75 |
76 |
77 | if __name__ == '__main__':
78 | images = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1], name="x_input")
79 | res = mynet_v1(images, 2)
80 | print(res)
81 | # train()
82 | # predict()
83 |
84 |
85 |
--------------------------------------------------------------------------------
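A worked sketch (not repository code) of the spatial-size arithmetic behind mynet_v1 for the default IMG_SIZE = 128, assuming TensorFlow's padding rules (VALID: floor((n-k)/s)+1, SAME: ceil(n/s)); the 1x1 and stride-1 SAME convolutions leave the size unchanged. It shows why the logits come out 1x1xnum_classes, hence the commented-out squeeze.

    import math

    n = 128
    n = (n - 11) // 4 + 1   # Conv_1: 11x11, stride 4, VALID -> 30
    n = math.ceil(n / 3)    # Conv_3_dw: 5x5, stride 3, SAME -> 10
    n = math.ceil(n / 2)    # Conv_4_dw: 3x3, stride 2, SAME -> 5
    n = n - 5 + 1           # Conv_5: 5x5, stride 1, VALID -> 1
    print(n)                # 1
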
/Detect/phone_seg.py:
--------------------------------------------------------------------------------
1 | """
2 | Phone image segmentation: learn the target rectangle (four corner points)
3 | author: 王建坤
4 | date: 2018-10-25
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 | import os
9 | import time
10 | import tensorflow.contrib.slim as slim
11 | from sklearn.model_selection import train_test_split
12 | from nets import alexnet
13 | from Detect import utils
14 | from PIL import Image
15 | import cv2
16 |
17 |
18 | def lenet5(inputs):
19 | inputs = tf.reshape(inputs, [-1, IMG_SIZE, IMG_SIZE, 1]) # match the grayscale input placeholder
20 | net = slim.conv2d(inputs, 32, [5, 5], padding='SAME', scope='conv1')
21 | net = slim.max_pool2d(net, 2, stride=2, scope='pool1')
22 | net = slim.conv2d(net, 64, [5, 5], padding='SAME', scope='conv2')
23 | net = slim.max_pool2d(net, 2, stride=2, scope='pool2')
24 | net = slim.flatten(net, scope='flatten')
25 | net = slim.fully_connected(net, 500, scope='fully4')
26 | net = slim.fully_connected(net, CLASSES, scope='fully5') # 8 outputs: four (x, y) corner points
27 | return net
28 |
29 |
30 | MAX_STEP = 100
31 | LEARNING_RATE_BASE = 0.01
32 | LEARNING_RATE_DECAY = 0.96
33 | IMG_SIZE = 224
34 | SHOW_SIZE = 800
35 | CLASSES = 8
36 | BATCH_SIZE = 32
37 |
38 | INFO_STEP = 20
39 | SAVE_STEP = 200
40 | GLOBAL_POOL = False
41 |
42 |
43 | def train(inherit=False, model='Alex'):
44 | # Load the data set
45 | images = np.load('../data/card_data.npy')
46 | labels = np.load('../data/card_label.npy')
47 | train_data, val_data, train_label, val_label = train_test_split(images, labels, test_size=0.2, random_state=222)
48 |
49 | # Placeholders
50 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1], name='x-input')
51 | y_ = tf.placeholder(tf.float32, [None, CLASSES], name='y-input')
52 | my_global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int64)
53 |
54 | # Forward pass
55 | if model == 'Alex':
56 | log_path = "../log/Alex"
57 | model_name = 'alex.ckpt'
58 | y, _ = alexnet.alexnet_v2(x,
59 | num_classes=CLASSES, # number of classes
60 | is_training=True, # whether we are training
61 | dropout_keep_prob=1.0, # dropout keep probability
62 | spatial_squeeze=True, # squeeze out the singleton spatial dimensions
63 | global_pool=GLOBAL_POOL) # needed when the input is not the canonical size
64 | else:
65 | log_path = '../log/My'
66 | model_name = 'my.ckpt'
67 | y = lenet5(x) # lenet5 returns a single tensor, so there is no second value to unpack
68 | 
69 | # Regression loss (smoothed root of squared error), learning-rate decay, optimizer
70 | loss = tf.reduce_mean(tf.sqrt(tf.square(y_ - y)+0.0000001))
71 | learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, my_global_step, 100, LEARNING_RATE_DECAY)
72 | train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=my_global_step)
73 |
74 | # Model saver and variable initialization
75 | saver = tf.train.Saver()
76 | init = tf.global_variables_initializer()
77 |
78 | tf.summary.scalar("loss", loss)
79 | # tf.summary.histogram("W1", W)
80 | merged_summary_op = tf.summary.merge_all()
81 |
82 | # Training loop
83 | with tf.Session() as sess:
84 | summary_writer1 = tf.summary.FileWriter('../log/curve/train', sess.graph)
85 | summary_writer2 = tf.summary.FileWriter('../log/curve/test')
86 | step = 0
87 | if inherit:
88 | ckpt = tf.train.get_checkpoint_state(log_path)
89 | if ckpt and ckpt.model_checkpoint_path:
90 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
91 | saver.restore(sess, ckpt.model_checkpoint_path)
92 | print(model, 'continue train from %s:' % global_step)
93 | step = int(global_step)
94 | else:
95 | print('Error: no checkpoint file found')
96 | return
97 | else:
98 | print(model, 'restart train:')
99 | sess.run(init)
100 |
101 | # Iterate
102 | while step < MAX_STEP:
103 | start_time = time.clock()
104 | image_batch, label_batch = utils.get_batch(train_data, train_label, BATCH_SIZE)
105 |
106 | # One training step; record the loss
107 | _, train_loss = sess.run([train_op, loss], feed_dict={x: image_batch, y_: label_batch})
108 | end_time = time.clock()
109 | runtime = end_time - start_time
110 |
111 | step += 1
112 | # Training info, summary curves, and model saving
113 | if step % INFO_STEP == 0 or step == MAX_STEP:
114 | summary_str = sess.run(merged_summary_op, feed_dict={x: image_batch, y_: label_batch})
115 | summary_writer1.add_summary(summary_str, step)
116 | test_loss, summary_str = sess.run([loss, merged_summary_op], feed_dict={x: val_data, y_: val_label})
117 | summary_writer2.add_summary(summary_str, step)
118 | print('step: %d, runtime: %.2f, train loss: %.4f, test loss: %.4f' %
119 | (step, runtime, train_loss, test_loss))
120 |
121 | if step % SAVE_STEP == 0:
122 | checkpoint_path = os.path.join(log_path, model_name)
123 | saver.save(sess, checkpoint_path, global_step=step)
124 |
125 |
126 | def predict(root_path, model='Alex'):
127 | # Placeholder
128 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1])
129 |
130 | # Model save path, forward pass
131 | if model == 'Alex':
132 | log_path = "../log/Alex"
133 | y, _ = alexnet.alexnet_v2(x,
134 | num_classes=CLASSES, # number of classes
135 | is_training=True, # whether we are training
136 | dropout_keep_prob=1.0, # dropout keep probability
137 | spatial_squeeze=True, # squeeze out the singleton spatial dimensions
138 | global_pool=GLOBAL_POOL) # needed when the input is not the canonical size
139 | else:
140 | log_path = '../log/My'
141 | y = lenet5(x) # lenet5 returns a single tensor
142 |
143 | saver = tf.train.Saver()
144 |
145 | with tf.Session() as sess:
146 | # Restore the model weights
147 | print('Reading checkpoints: ', model)
148 | ckpt = tf.train.get_checkpoint_state(log_path)
149 | if ckpt and ckpt.model_checkpoint_path:
150 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
151 | saver.restore(sess, ckpt.model_checkpoint_path)
152 | print('Loading success, global_step is %s' % global_step)
153 | else:
154 | print('Error: no checkpoint file found')
155 | return
156 |
157 | file_list = os.listdir(root_path)
158 | for file_name in file_list:
159 | if file_name.split('.')[-1] == 'jpg':
160 | img = Image.open(os.path.join(root_path, file_name))
161 | img_show = img.resize((SHOW_SIZE, SHOW_SIZE))
162 | img_show = np.array(img_show, np.uint8)
163 | img = img.resize((IMG_SIZE, IMG_SIZE))
164 | img = np.array(img, np.uint8)
165 | img = np.expand_dims(img, axis=0)
166 | img = np.expand_dims(img, axis=3)
167 | pre = np.squeeze(sess.run(y, feed_dict={x: img}))
168 | print('predict:', pre)
169 |
170 | points = [[pre[0] * SHOW_SIZE, pre[1] * SHOW_SIZE], [pre[2] * SHOW_SIZE, pre[3] * SHOW_SIZE],
171 | [pre[4] * SHOW_SIZE, pre[5] * SHOW_SIZE], [pre[6] * SHOW_SIZE, pre[7] * SHOW_SIZE]]
172 | points = np.int0(points)
173 | cv2.polylines(img_show, [points], True, (0, 0, 255), 1)
174 | cv2.imshow('img_show', img_show)
175 | cv2.waitKey()
176 | cv2.destroyAllWindows()
177 |
178 |
179 | if __name__ == '__main__':
180 | train()
181 | # predict('../data/card')
182 |
--------------------------------------------------------------------------------
/Detect/predict.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect prediction
3 | author: 王建坤
4 | date: 2018-8-10
5 | """
6 | import sys
7 | sys.path.append("..")
8 | from Detect.config import *
9 | from PIL import Image
10 | import time
11 |
12 | # Image directory path
13 | # IMG_DIR = '../data/crop/pos/'
14 | # IMG_DIR = '../data/phone/'
15 | IMG_DIR = '../data/backup/cigarette_ys/wire_fail/' # normal wire_fail
16 | IS_TRAINING = False
17 | TS = []
18 |
19 |
20 | def predict(img_path, model=MODEL_NAME):
21 | """
22 | Predict images
23 | :param img_path: list of image file names (falls back to listing IMG_DIR if empty)
24 | :param model: model name
25 | :return: none
26 | """
27 | # Placeholder
28 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL])
29 |
30 | # Model save path, forward pass
31 | if model == 'Alex':
32 | log_dir = "../log/Alex"
33 | y, _ = alexnet.alexnet_v2(x,
34 | num_classes=CLASSES, # number of classes
35 | is_training=IS_TRAINING, # whether we are training
36 | dropout_keep_prob=1.0, # dropout keep probability
37 | spatial_squeeze=True, # squeeze out the singleton spatial dimensions
38 | global_pool=GLOBAL_POOL) # needed when the input is not the canonical size
39 | elif model == 'My':
40 | log_dir = "../log/My"
41 | y, _ = mobilenet_v1.mobilenet_v1(x,
42 | num_classes=CLASSES,
43 | dropout_keep_prob=1.0,
44 | is_training=IS_TRAINING,
45 | min_depth=8,
46 | depth_multiplier=0.1,
47 | conv_defs=None,
48 | prediction_fn=None,
49 | spatial_squeeze=True,
50 | reuse=None,
51 | scope='MobilenetV1',
52 | global_pool=GLOBAL_POOL)
53 | elif model == 'Mobile':
54 | log_dir = "../log/Mobile"
55 | y, _ = mobilenet_v1.mobilenet_v1(x,
56 | num_classes=CLASSES,
57 | dropout_keep_prob=1.0,
58 | is_training=IS_TRAINING,
59 | min_depth=8,
60 | depth_multiplier=1.0,
61 | conv_defs=None,
62 | prediction_fn=None,
63 | spatial_squeeze=True,
64 | reuse=None,
65 | scope='MobilenetV1',
66 | global_pool=GLOBAL_POOL)
67 | elif model == 'VGG':
68 | log_dir = "../log/VGG"
69 | y, _ = vgg.vgg_16(x,
70 | num_classes=CLASSES,
71 | is_training=IS_TRAINING,
72 | dropout_keep_prob=1.0,
73 | spatial_squeeze=True,
74 | global_pool=GLOBAL_POOL)
75 | elif model == 'Incep4':
76 | log_dir = "../log/Incep4"
77 | y, _ = inception_v4.inception_v4(x, num_classes=CLASSES,
78 | is_training=IS_TRAINING,
79 | dropout_keep_prob=1.0,
80 | reuse=None,
81 | scope='InceptionV4',
82 | create_aux_logits=True)
83 | elif model == 'Res':
84 | log_dir = "../log/Res"
85 | y, _ = resnet_v2.resnet_v2_50(x,
86 | num_classes=CLASSES,
87 | is_training=IS_TRAINING,
88 | global_pool=GLOBAL_POOL,
89 | output_stride=None,
90 | spatial_squeeze=True,
91 | reuse=None,
92 | scope='resnet_v2_50')
93 | else:
94 | print('Error: model name not exist')
95 | return
96 |
97 | y = tf.nn.softmax(y)
98 | saver = tf.train.Saver()
99 |
100 | with tf.Session() as sess:
101 | # Restore the model weights
102 | print('Reading checkpoints: ', model)
103 | ckpt = tf.train.get_checkpoint_state(log_dir)
104 | if ckpt and ckpt.model_checkpoint_path:
105 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
106 | saver.restore(sess, ckpt.model_checkpoint_path)
107 | # check_tensor_name(sess)
108 | print('Loading success, global_step is %s' % global_step)
109 | else:
110 | print('Error: no checkpoint file found')
111 | return
112 |
113 | # img = read_img(IMG_DIR + img_path[0])
114 | # # The first run takes noticeably longer (graph warm-up)
115 | # sess.run(y, feed_dict={x: np.expand_dims(img, 3)})
116 | if img_path:
117 | img_list = img_path
118 | else:
119 | img_list = os.listdir(IMG_DIR)
120 | for img_name in img_list:
121 | # img_name = '1.jpg'
122 | print(img_name)
123 |
124 | start_time = time.clock()
125 | img_path = IMG_DIR + img_name
126 | img = read_img(img_path)
127 | # If the input is a grayscale image, add a channel dimension
128 | if CHANNEL == 1:
129 | img = np.expand_dims(img, axis=3)
130 | if MODEL_NAME == 'My':
131 | std_img_path = '../data/' + 'std.jpg'
132 | std_img = read_img(std_img_path)
133 | img = np.stack((std_img, img), axis=-1) # stack the standard and test images into a 2-channel input
134 | predictions = sess.run(y, feed_dict={x: img})
135 | pre = np.argmax(predictions, 1)
136 | end_time = time.clock()
137 | runtime = end_time - start_time
138 | TS.append(round(runtime*1000, 3))
139 | print('prediction is:', predictions)
140 | print('predict class is:', pre)
141 | print('run time:', runtime)
142 | if len(TS) > 1: print(min(TS), max(TS[1:]), sum(TS[1:]) / (len(TS)-1)) # timing stats, excluding the slow first run
143 |
144 |
145 | def check_tensor_name(sess):
146 | """
147 | 查看模型的 tensor,在调用模型之后使用
148 | :return: none
149 | """
150 | import tensorflow.contrib.slim as slim
151 | variables_to_restore = slim.get_variables_to_restore()
152 | for var in variables_to_restore:
153 | # tensor name
154 | print(var.name)
155 | # tensor value
156 | print(sess.run(var))
157 |
158 |
159 | def read_img(img_path):
160 | """
161 | Read the image at the given path as grayscale
162 | :param img_path: 图片的路径
163 | :return: numpy array of image
164 | """
165 | img = Image.open(img_path).convert('L')
166 | img = img.resize((IMG_SIZE, IMG_SIZE))
167 | # img = np.array(img)
168 | img = np.expand_dims(img, axis=0)
169 | return img
170 |
171 |
172 | if __name__ == '__main__':
173 | # normal, nothing, lack_cotton, lack_piece, wire_fail
174 | img_list = os.listdir('../data/backup/cigarette_ys/wire_fail')
175 | # predict(['1.jpg', '1.jpg'])
176 | predict(img_list)
177 |
--------------------------------------------------------------------------------
/Detect/submit.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect aluminum profile surface defect detection -- submission results
3 | author: 王建坤
4 | date: 2018-9-17
5 | """
6 | import numpy as np
7 | import tensorflow as tf
8 | import pandas as pd
9 | import os
10 | from PIL import Image, ImageFile
11 | from nets import alexnet, vgg, inception_v4, resnet_v2
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 | CLASSES = 12
15 | IMG_SIZE = 299
16 | GLOBAL_POOL = True
17 |
18 |
19 | def submit(test_dir, model='VGG'):
20 | """
21 | Save the test-set predictions as a csv file
22 | :param test_dir: test set root path
23 | :param model: model name
24 | :return: none
25 | """
26 | # Class names
27 | print('running submit')
28 | tf.reset_default_graph()
29 | class_label = ['norm', 'defect1', 'defect2', 'defect3', 'defect4', 'defect5',
30 | 'defect6', 'defect7', 'defect8', 'defect9', 'defect10', 'defect11']
31 | img_index, res = [], []
32 |
33 | # Placeholder
34 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 3])
35 |
36 | # Model save path, forward pass
37 | if model == 'Alex':
38 | log_dir = "../log/Alex"
39 | y, _ = alexnet.alexnet_v2(x,
40 | num_classes=CLASSES, # number of classes
41 | is_training=True, # whether we are training
42 | dropout_keep_prob=1.0, # dropout keep probability
43 | spatial_squeeze=True, # squeeze out the singleton spatial dimensions
44 | global_pool=GLOBAL_POOL) # needed when the input is not the canonical size
45 | elif model == 'VGG':
46 | log_dir = "../log/VGG"
47 | y, _ = vgg.vgg_16(x,
48 | num_classes=CLASSES,
49 | is_training=True,
50 | dropout_keep_prob=1.0,
51 | spatial_squeeze=True,
52 | global_pool=GLOBAL_POOL)
53 | elif model == 'Incep4':
54 | log_dir = "E:/alum/log/Incep4"
55 | y, _ = inception_v4.inception_v4(x, num_classes=CLASSES,
56 | is_training=True,
57 | dropout_keep_prob=1.0,
58 | reuse=None,
59 | scope='InceptionV4',
60 | create_aux_logits=True)
61 | elif model == 'Res':
62 | log_dir = "E:/alum/log/Res"
63 | y, _ = resnet_v2.resnet_v2_50(x,
64 | num_classes=CLASSES,
65 | is_training=True,
66 | global_pool=GLOBAL_POOL,
67 | output_stride=None,
68 | spatial_squeeze=True,
69 | reuse=None,
70 | scope='resnet_v2_50')
71 | else:
72 | print('Error: model name not exist')
73 | return
74 |
75 | saver = tf.train.Saver()
76 |
77 | with tf.Session() as sess:
78 | # Restore the model weights
79 | print("Reading checkpoints...")
80 | # ckpt has two attributes: model_checkpoint_path and all_model_checkpoint_paths
81 | ckpt = tf.train.get_checkpoint_state(log_dir)
82 | if ckpt and ckpt.model_checkpoint_path:
83 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
84 | saver.restore(sess, ckpt.model_checkpoint_path)
85 | print('Loading success, global_step is %s' % global_step)
86 | else:
87 | print('Error: no checkpoint file found')
88 | return
89 |
90 | img_list = os.listdir(test_dir)
91 | for img_name in img_list:
92 | # print(img_name)
93 | # Read the image and resize it
94 | img = Image.open(os.path.join(test_dir, img_name))
95 | img = img.resize((IMG_SIZE, IMG_SIZE))
96 |
97 | # Convert the image to a numpy array
98 | img = np.array(img, np.float32)
99 | img = np.expand_dims(img, axis=0)
100 |
101 | predictions = sess.run(y, feed_dict={x: img})
102 | pre = np.argmax(predictions)
103 | img_index.append(img_name)
104 | res.append(class_label[int(pre)])
105 |
106 | print('image: ', img_index, '\n', 'class:', res)
107 | # Save as a csv file with two columns: image name, predicted class
108 | data = pd.DataFrame({'0': img_index, '1': res})
109 | # data.sort_index(axis=0)
110 | data.to_csv("../submit/vgg_2.csv", index=False, header=False)
111 |
112 |
113 | if __name__ == '__main__':
114 | submit('../data/guangdong_round1_test_b_20181009')
115 |
--------------------------------------------------------------------------------
/Detect/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions
3 | author: 王建坤
4 | date: 2018-10-15
5 | """
6 | import numpy as np
7 | import os
8 | import pandas as pd
9 |
10 |
11 | def load_npy(data_path, label_path):
12 | """
13 | Load npy file
14 | :param data_path: the path of the data npy file
15 | :param label_path: the path of the label npy file
16 | :return: numpy arrays of data and label
17 | """
18 | data = np.load(data_path)
19 | label = np.load(label_path)
20 | print('data_set shape: ', np.shape(data))
21 | return data, label
22 |
23 |
24 | def get_batch(data, label, batch_size):
25 | """
26 | Get a batch from numpy array.
27 | :param data: numpy array of data
28 | :param label: numpy array of label
29 | :param batch_size: the number of samples in a batch
30 | (the batch is drawn by uniform random sampling, with replacement)
31 | :return: a batch of the data set
32 | """
33 | # Sequential traversal of the data set (kept for reference)
34 | # start = start % total
35 | # end = start + batch_size
36 | # if end > total:
37 | # res = end % total
38 | # end = total
39 | # data1 = data[start:end]
40 | # data2 = data[0: res]
41 | # return np.vstack((data1, data2))
42 | # return data[start:end]
43 |
44 | # Randomly sample a batch
45 | total = data.shape[0]
46 | index_list = np.random.randint(total, size=batch_size)
47 | return data[index_list, :], label[index_list]
48 |
49 |
50 | def shuffle_data(data, label):
51 | """
52 | Shuffle the data and labels together so that training batches are well mixed.
53 | """
54 | permutation = np.random.permutation(label.shape[0])
55 | batch_data = data[permutation, :]
56 | batch_label = label[permutation]
57 | return batch_data, batch_label
58 |
59 |
60 | def normalize_data(data):
61 | """
62 | Normalize the data by subtracting the per-channel mean of R, G, B to accelerate training.
63 | """
64 | r = data[:, :, :, 2]
65 | r_mean = np.average(r)
66 | b = data[:, :, :, 0]
67 | b_mean = np.average(b)
68 | g = data[:, :, :, 1]
69 | g_mean = np.average(g)
70 | r = r - r_mean
71 | b = b - b_mean
72 | g = g - g_mean
73 | data = np.zeros([r.shape[0], r.shape[1], r.shape[2], 3])
74 | data[:, :, :, 0] = b
75 | data[:, :, :, 1] = g
76 | data[:, :, :, 2] = r
77 | return data
78 |
79 |
80 | def check_tensor():
81 | """
82 | Inspect the tensors stored in a saved ckpt file
83 | """
84 | from tensorflow.python import pywrap_tensorflow
85 |
86 | # ckpt file path
87 | checkpoint_path = 'E:/alum/weight/inception_v4_2016_09_09/inception_v4.ckpt'
88 |
89 | # Read data from checkpoint file
90 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
91 | var_to_shape_map = reader.get_variable_to_shape_map()
92 |
93 | # Print tensor name and values
94 | for key in var_to_shape_map:
95 | print("tensor_name: ", key)
96 | print(reader.get_tensor(key).dtype)
97 |
98 |
99 | def image_to_csv(data_path, save_path):
100 | """
101 | Save image paths and labels as a csv file
102 | :param data_path: image root path
103 | :param save_path: save path of csv file
104 | :return: none
105 | """
106 | img_path, label = [], []
107 | per_class_num = {}
108 | # Image classes (folder name -> label)
109 | class_dic = {'正常': 0, '不导电': 1, '擦花': 2, '角位漏底': 3, '桔皮': 4, '漏底': 5, '起坑': 6, '脏点': 7}
110 | # Iterate over the folders
111 | folder_list = os.listdir(data_path)
112 | for folder in folder_list:
113 | if folder not in class_dic.keys():
114 | continue
115 |
116 | num = 0
117 | class_path = os.path.join(data_path, folder)
118 | img_list = os.listdir(class_path)
119 | for img_name in img_list:
120 | img_path.append(os.path.join(class_path, img_name))
121 | label.append(class_dic[folder])
122 | num += 1
123 | per_class_num[folder] = num
124 |
125 | img_path_label = pd.DataFrame({'img_path': img_path, 'label': label})
126 | img_path_label.to_csv(save_path+'/data_label.csv', index=False)
127 | print('per class number: ', per_class_num)
128 |
129 |
130 | def csv_show_image():
131 | """
132 | Read an image path from the csv file and display the image
133 | :return: none
134 | """
135 | from PIL import Image
136 | table = pd.read_csv('../data/data_label.csv')
137 | img_path = table.iloc[0, 0]
138 | img = Image.open(img_path)
139 | img.show()
140 |
141 |
142 | if __name__ == "__main__":
143 | print('running utils:')
144 | # train_data, train_label = load_npy('../data/data_299.npy', '../data/label_299.npy')
145 | # image_to_csv('../data/alum', '../data')
146 |
147 |
148 |
--------------------------------------------------------------------------------
/Image_preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Image_preprocessing/__init__.py
--------------------------------------------------------------------------------
/Image_preprocessing/create_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a synthetic data set of phone back covers -- only for testing the algorithm
3 | author: 王建坤
4 | date: 2018-8-15
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | from PIL import ImageEnhance, Image
10 |
11 |
12 | def img_to_gray(img_path):
13 | """
14 | 彩色图片转为灰度图存储
15 | """
16 | # src_img = cv2.imread(img_path, 0)
17 | # print(src_img.shape)
18 | # cv2.imwrite('F:/DefectDetection/img_data/zc.jpg', src_img)
19 | # # cv2.cvtColor(src_img, gray_img, cv2.COLOR_BGR2GRAY);
20 |
21 | src_img = Image.open(img_path)
22 | # Print the image size, mode, and format
23 | print(src_img.size, src_img.mode, src_img.format)
24 | dst_img = src_img.convert("L")
25 | # dst_img.save(os.path.join(os.path.dirname(img_path), 'zc.jpg'))
26 | dst_img.save(os.path.join(os.path.dirname(img_path), 'wz.jpg'))
27 |
28 |
29 | def img_enhance(img_dir, save_path, num):
30 | """
31 | 随机扰动图像的亮度和对比度
32 | """
33 | # img_path_list = [img_dir]
34 | img_path_list = os.listdir(img_dir)
35 | for img_path in img_path_list:
36 | # src_img = Image.open(img_path)
37 | # filename_pre = img_path.split('/')[-1].split('.')[0]
38 |
39 | src_img = Image.open(os.path.join(img_dir, img_path))
40 | # print(src_img.mode, src_img.size)
41 | filename_pre = img_path.split('.')[0]
42 |
43 | for i in range(num):
44 | brightness_factor = np.random.randint(97, 103) / 100.0
45 | brightness_image = ImageEnhance.Brightness(src_img).enhance(brightness_factor)
46 | contrast_factor = np.random.randint(98, 102) / 100.0
47 | contrast_image = ImageEnhance.Contrast(brightness_image).enhance(contrast_factor)
48 | save_name = save_path + '/' + filename_pre + '_' + str(i) + '.jpg'
49 | contrast_image.save(save_name)
50 |
51 |
52 | def draw_line(img_path, save_path, num):
53 | """
54 | 随机画线
55 | """
56 | src_img = cv2.imread(img_path, -1)
57 | filename_pre = save_path.split('/')[-1]
58 |
59 | for i in range(num):
60 | dst_img = src_img.copy()
61 | x1, y1 = np.random.randint(80, dst_img.shape[1]-120), np.random.randint(70, dst_img.shape[0]-70)
62 | x2, y2 = x1, y1
63 | while x1-10 < x2 < x1+10:
64 | x2 = np.random.randint(x1-50, x1+50)
65 | while y1-10 < y2 < y1+10:
66 | y2 = np.random.randint(y1-50, y1+50)
67 | # print(x1, y1, x2, y2)
68 | thickness = np.random.randint(1, 2)
69 | pixel = np.random.randint(160, 200)
70 | cv2.line(dst_img, (x1, y1), (x2, y2), pixel, thickness)
71 |
72 | save_name = save_path + '/' + filename_pre + '_' + str(i) + '.jpg'
73 | cv2.imwrite(save_name, dst_img)
74 |
75 |
76 | def draw_wz(img_path, save_path, num):
77 | """
78 | 随机贴上污渍
79 | """
80 | src_img = cv2.imread(img_path, -1)
81 | filename_pre = save_path.split('/')[-1]
82 |
83 | for i in range(num):
84 | dst_img = src_img.copy()
85 | x1, y1 = np.random.randint(50, dst_img.shape[1] - 60), np.random.randint(40, dst_img.shape[0] - 40)
86 | # print(x1, y1)
87 | radius = np.random.randint(5, 20)
88 | pixel = np.random.randint(180, 200)
89 | cv2.circle(dst_img, (x1, y1), radius, pixel, -1)
90 |
91 | save_name = save_path + '/' + filename_pre + '_' + str(i) + '.jpg'
92 | cv2.imwrite(save_name, dst_img)
93 |
94 |
95 | if __name__ == '__main__':
96 | print('create_data.py is running')
97 | # img_to_gray('F:/DefectDetection/img_data/src_zc.jpg')
98 | # img_enhance('F:/DefectDetection/img_data/zc.jpg', 'F:/DefectDetection/img_data/zc', 2000)
99 | # draw_line('F:/DefectDetection/img_data/zc.jpg', 'F:/DefectDetection/img_data/hh', 500)
100 | # img_enhance('F:/DefectDetection/img_data/hh', 'F:/DefectDetection/img_data/hh', 3)
101 | # img_to_gray('F:/DefectDetection/img_data/src_wz.jpg')
102 | # draw_wz('F:/DefectDetection/img_data/zc.jpg', 'F:/DefectDetection/img_data/wz', 500)
103 | # img_enhance('F:/DefectDetection/img_data/wz', 'F:/DefectDetection/img_data/wz', 3)
104 |
--------------------------------------------------------------------------------
/Image_preprocessing/image2npy.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert an image data set to npy files
3 | author: 王建坤
4 | date: 2018-10-15
5 | """
6 | import numpy as np
7 | import os
8 | from PIL import Image, ImageFile
9 | from matplotlib import pyplot as plt
10 | import random as rd
11 |
12 | # Workaround for truncated-image read errors
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | IMG_SIZE = 224 # 299
15 |
16 |
17 | def multi_class_to_npy(data_path, save_path):
18 | """
19 | 多种类别的图片数据集保存为 npy 文件
20 | :param data_path: the root path of data set
21 | :param save_path: the save path of npy
22 | :return: none
23 | """
24 | data, label = [], []
25 | per_class_num = {}
26 | save_name = 'b_cig_'
27 |
28 | # Image classes
29 | # Aluminum profile classes
30 | # class_dic = {'正常': 0, '不导电': 1, '擦花': 2, '角位漏底': 3, '桔皮': 4, '漏底': 5, '起坑': 6, '脏点': 7}
31 | # E-cigarette assembly classes
32 | class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4}
33 | # Iterate over the folders
34 | folder_list = os.listdir(data_path)
35 | for folder in folder_list:
36 | # Skip unwanted folders
37 | if folder not in class_dic.keys():
38 | continue
39 | # Iterate over the files in the folder
40 | num = 0
41 | class_path = os.path.join(data_path, folder)
42 | img_list = os.listdir(class_path)
43 | for img_name in img_list:
44 | print(img_name)
45 | # Read the image and resize it
46 | img = Image.open(os.path.join(class_path, img_name))
47 | # img.convert('L')
48 | # print(img.mode)
49 | img = img.resize((IMG_SIZE, IMG_SIZE))
50 | # Append data and label; the folder name determines the sample's label
51 | img = np.array(img)
52 | data.append(img)
53 | label.append(class_dic[folder])
54 | num += 1
55 | per_class_num[folder] = num
56 |
57 | # Convert the data set to numpy arrays
58 | data = np.array(data, np.uint8)
59 | label = np.array(label, np.uint8)
60 | print('Per class number is: ', per_class_num)
61 | print('Data set shape is: ', np.shape(data), np.shape(label))
62 |
63 | # Save the arrays as npy files
64 | data_save_path = save_path+'/'+save_name+'data_'+str(IMG_SIZE)+'.npy'
65 | label_save_path = save_path+'/'+save_name+'label_'+str(IMG_SIZE)+'.npy'
66 | np.save(data_save_path, data)
67 | np.save(label_save_path, label)
68 |
69 |
70 | def siamese_sample_to_npy(data_path, save_path):
71 | """
72 | 构造 Siamese 数据集
73 | :param data_path: the root path of data set
74 | :param save_path: the save path of npy
75 | :return: none
76 | """
77 | data, label = [], []
78 | per_class_num = {'normal': 0, 'nothing': 0, 'lack_cotton': 0, 'lack_piece': 0, 'wire_fail': 0}
79 | save_name = 's_sia_cig_'
80 |
81 | # E-cigarette assembly classes
82 | class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4}
83 |
84 | # Standard (reference) images
85 | std_class_path = os.path.join(data_path, 'std')
86 | std_img_list = os.listdir(std_class_path)
87 | # std_img_list = rd.sample(std_img_list, 300)
88 | for std_img_name in std_img_list:
89 | std_img = Image.open(os.path.join(std_class_path, std_img_name))
90 | # img.convert('L')
91 | # print(img.mode)
92 | std_img = std_img.resize((IMG_SIZE, IMG_SIZE))
93 | # Convert the standard image to a numpy array
94 | std_img = np.array(std_img)
95 |
96 | # Iterate over the folders
97 | folder_list = os.listdir(data_path)
98 | for folder in folder_list:
99 | # Skip unwanted folders
100 | if folder not in class_dic.keys():
101 | continue
102 | # Iterate over the files in the folder
103 | class_path = os.path.join(data_path, folder)
104 | img_list = os.listdir(class_path)
105 | # Optionally down-sample the normal class
106 | # if folder == 'normal':
107 | # img_list = rd.sample(img_list, 200)
108 | for img_name in img_list:
109 | print(img_name)
110 | # Read the image and resize it
111 | img = Image.open(os.path.join(class_path, img_name))
112 | # img.convert('L')
113 | # print(img.mode)
114 | img = img.resize((IMG_SIZE, IMG_SIZE))
115 | # Stack the standard and sample images into a 2-channel sample; the folder name gives the label
116 | sample = np.stack((std_img, img), axis=-1)
117 | data.append(sample)
118 | label.append(class_dic[folder])
119 | per_class_num[folder] += 1
120 |
121 | # Convert the data set to numpy arrays
122 | data = np.array(data, np.uint8)
123 | label = np.array(label, np.uint8)
124 | print('Per class number is: ', per_class_num)
125 | print('Data set shape is: ', np.shape(data), np.shape(label))
126 |
127 | # Save the arrays as npy files
128 | data_save_path = save_path+'/'+save_name+'data_'+str(IMG_SIZE)+'.npy'
129 | label_save_path = save_path+'/'+save_name+'label_'+str(IMG_SIZE)+'.npy'
130 | np.save(data_save_path, data)
131 | np.save(label_save_path, label)
132 |
133 |
134 | def load_npy(data_path, label_path):
135 | """
136 | Load npy files
137 | :param data_path: path of the data npy file
138 | :param label_path: path of the label npy file
139 | :return: data and label arrays
140 | """
141 | data = np.load(data_path)
142 | label = np.load(label_path)
143 | print('data set shape is: ', np.shape(data))
144 | return data, label
145 |
146 |
147 | def array_to_image(sample, sample_label):
148 | """
149 | Show one sample: its image and label
150 | :param sample: image array
151 | :param sample_label: label of the sample
152 | :return: none
153 | """
154 | print('sample shape: ', np.shape(sample))
155 | sample = sample/255
156 | # plt.imshow expects float pixel values in [0, 1]
157 | plt.imshow(sample)
158 | plt.show()
159 | print('sample label: ', sample_label)
160 |
161 |
162 | if __name__ == '__main__':
163 | print('running image2npy:')
164 | # alum_to_npy('../data/alum', '../data')
165 | multi_class_to_npy('../data/cigarette', '../data')
166 | # siamese_sample_to_npy('../data/cigarette', '../data')
167 | # siamese_sample_to_npy('E:/backup/cigarette', '../data')
168 |
169 |
170 |
171 |
--------------------------------------------------------------------------------
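A hedged sketch (not repository code) of consuming the npy pair written by multi_class_to_npy, split the same way the training and evaluation scripts do (test_size=0.2, random_state=222). The file names assume save_name = 'b_cig_' and IMG_SIZE = 224.

    import numpy as np
    from sklearn.model_selection import train_test_split

    data = np.load('../data/b_cig_data_224.npy')     # (N, 224, 224, 3) uint8
    label = np.load('../data/b_cig_label_224.npy')   # (N,) uint8
    train_x, val_x, train_y, val_y = train_test_split(
        data, label, test_size=0.2, random_state=222)
    print(train_x.shape, val_x.shape)
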
/Image_preprocessing/image_crop.py:
--------------------------------------------------------------------------------
1 | """
2 | Image cropping: cut small patches out of a large image
3 | author: 王建坤
4 | date: 2018-12-6
5 | """
6 | import cv2
7 | import os
8 |
9 | IMG_H_SIZE = 2000
10 | IMG_V_SIZE = 1000
11 |
12 | IMG_DIR = '../data/phone'
13 | SAVE_DIR = '../data/crop'
14 |
15 |
16 | def open_img(img_path):
17 | """
18 | Open a single image (grayscale) at the given path and resize it
19 | :param img_path: image path
20 | :return: the resized grayscale image
21 | """
22 | img = cv2.imread(img_path, 0)
23 | img = cv2.resize(img, (IMG_H_SIZE, IMG_V_SIZE))
24 | return img
25 |
26 |
27 | def crop_img(img):
28 | for x in range(20, 101, 8):
29 | for y in range(60, 701, 8):
30 | print(x, '-', y)
31 | small_image = img[x:x+128, y:y+128]
32 | cv2.imwrite(os.path.join(SAVE_DIR, str(x)+'_'+str(y)+'.jpg'), small_image)
33 |
34 |
35 | def main():
36 | img_list = os.listdir(IMG_DIR)
37 | img_name = img_list[0]
38 | print('image name:', img_name)
39 | img = open_img(os.path.join(IMG_DIR, img_name))
40 | crop_img(img)
41 |
42 |
43 | if __name__ == '__main__':
44 | main()
45 |
--------------------------------------------------------------------------------
/Image_preprocessing/json2npy.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert labelme json files into a data set and save it as npy
3 | author: 王建坤
4 | date: 2018-10-26
5 | """
6 | # labelme_json_to_dataset E:/1.json
7 | # labelme_draw_json E:/1.json
8 | import json
9 | import os
10 | import numpy as np
11 | from PIL import Image
12 |
13 | IMG_SIZE = 224
14 | col = 2592  # width of the original image in pixels
15 | row = 1944  # height of the original image in pixels
16 |
17 |
18 | def json_to_npy(root_path, save_path='../data/'):
19 | """
20 | Convert json files into a data set and save it as npy.
21 | :param root_path: directory containing the images and their json files
22 | :param save_path: directory where the npy files are written
23 | :return:
24 | """
25 | data, labels = [], []
26 | f = open('../data/label.txt', 'a')
27 | file_list = os.listdir(root_path)
28 | for file_name in file_list:
29 | if file_name.split('.')[-1] == 'jpg':
30 | img = Image.open(os.path.join(root_path, file_name))
31 | img = img.resize((IMG_SIZE, IMG_SIZE))
32 | img = np.array(img)
33 | img = np.expand_dims(img, axis=-1)  # append a channel axis; axis=3 is out of range for a 2-D grayscale array
34 | data.append(img)
35 | json_path = os.path.join(root_path, file_name.split('.')[0]) + '.json'
36 | label = json_to_label(json_path)
37 | f.write(str(label) + '\n')
38 | temp = [label[0][0]/col, label[0][1]/row, label[1][0]/col, label[1][1]/row,
39 | label[2][0]/col, label[2][1]/row, label[3][0]/col, label[3][1]/row]
40 | labels.append(temp)
41 |
42 | f.close()
43 | # Convert the data set to numpy arrays
44 | data = np.array(data, np.uint8)
45 | labels = np.array(labels, np.float16)
46 | print('data set shape is: ', np.shape(data), np.shape(labels))
47 |
48 | # Save the arrays as npy files
49 | data_save_path = save_path + 'card_data.npy'
50 | label_save_path = save_path + 'card_label.npy'
51 | np.save(data_save_path, data)
52 | np.save(label_save_path, labels)
53 |
54 |
55 | def json_to_label(json_path):
56 | """
57 | Extract the label (corner points) from a json file
58 | :param json_path: path of the labelme json file
59 | :return:
60 | """
61 | with open(json_path, 'r', encoding='utf-8') as load_f:
62 | data = json.load(load_f)
63 | label = data['shapes'][0]['points']
64 | # print(label)
65 | return label
66 |
67 |
68 | if __name__ == '__main__':
69 | print('running json2npy: ')
70 | json_to_npy('../data/card')
71 | # json_to_label('E:/1.json')
72 |
73 |
74 |
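For context, json_to_label reads only data['shapes'][0]['points']. A minimal labelme-style file it can parse looks like this (a sketch with hypothetical corner coordinates, not the full labelme schema):

    demo = {'shapes': [{'label': 'card',
                        'points': [[100, 200], [2400, 180], [2450, 1800], [120, 1820]]}]}
    with open('demo.json', 'w', encoding='utf-8') as f:
        json.dump(demo, f)
    print(json_to_label('demo.json'))  # -> the four corner points as nested lists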
--------------------------------------------------------------------------------
/Siamese/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__init__.py
--------------------------------------------------------------------------------
/Siamese/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/Siamese/__pycache__/inference.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__pycache__/inference.cpython-36.pyc
--------------------------------------------------------------------------------
/Siamese/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/Siamese/bin_classify.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese training -- binary classification
3 | author: 王建坤
4 | date: 2018-10-15
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 | import os
9 | from Siamese import inference, utils
10 |
11 | MAX_STEP = 600
12 | LEARNING_RATE = 0.01
13 | # Step intervals for printing training info and saving weights
14 | INFO_STEP = 20
15 | SAVE_STEP = 200
16 |
17 | BATCH_SIZE = 128
18 | IMG_SIZE = 299
19 |
20 |
21 | def train(inherit=False):
22 | """
23 | Train the Siamese network with a binary-classification head
24 | """
25 | # Load the data set
26 | images = np.load('../data/data_'+str(IMG_SIZE)+'.npy')
27 | labels = np.load('../data/label_'+str(IMG_SIZE)+'.npy')
28 | images, labels = utils.shuffle_data(images, labels)
29 |
30 | # Placeholders
31 | with tf.variable_scope('bin_input_x1') as scope:
32 | x1 = tf.placeholder(tf.float32, shape=[None, IMG_SIZE, IMG_SIZE, 3])
33 | with tf.variable_scope('bin_input_x2') as scope:
34 | x2 = tf.placeholder(tf.float32, shape=[None, IMG_SIZE, IMG_SIZE, 3])
35 | with tf.variable_scope('bin_y') as scope:
36 | y = tf.placeholder(tf.float32, shape=[None, 1])
37 | with tf.name_scope('bin_keep_prob') as scope:
38 | keep_prob = tf.placeholder(tf.float32)
39 |
40 | # Forward pass
41 | with tf.variable_scope('bin_siamese') as scope:
42 | out1 = inference.inference(x1, keep_prob)
43 | # Share parameters; a second set of variables is not created
44 | scope.reuse_variables()
45 | out2 = inference.inference(x2, keep_prob)
46 |
47 | # Add the binary-classification layer
48 | with tf.name_scope('bin_c') as scope:
49 | w_bc = tf.Variable(tf.truncated_normal(shape=[64, 1], stddev=0.05, mean=0), name='w_bc')
50 | b_bc = tf.Variable(tf.zeros(1), name='b_bc')
51 | out12 = tf.concat((out1, out2), 1, name='out12')
52 | pre = tf.add(tf.matmul(out12, w_bc), b_bc)
53 |
54 | # Loss function and optimizer
55 | with tf.variable_scope('metrics') as scope:
56 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=pre, name='entropy')
57 | loss = tf.reduce_mean(cross_entropy, name='loss')
58 | train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)
59 |
60 | init = tf.global_variables_initializer()
61 | saver = tf.train.Saver(tf.global_variables())
62 |
63 | # Model save path
64 | log_path = '../log/Siamese'
65 |
66 | with tf.Session() as sess:
67 | step = 0
68 | if inherit:
69 | ckpt = tf.train.get_checkpoint_state(log_path)
70 | if ckpt and ckpt.model_checkpoint_path:
71 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
72 | saver.restore(sess, ckpt.model_checkpoint_path)
73 | print('Siamese continues training from step %s:' % global_step)
74 | step = int(global_step)
75 | else:
76 | print('Error: no checkpoint file found')
77 | return
78 | else:
79 | print('Siamese training from scratch:')
80 | sess.run(init)
81 |
82 | while step < MAX_STEP:
83 | # Get a pair of batches
84 | x_1, y_1 = utils.get_batch(images, labels, BATCH_SIZE)
85 | x_2, y_2 = utils.get_batch(images, labels, BATCH_SIZE)
86 | # Check whether the two corresponding labels are equal
87 | y_s = np.array(y_1 == y_2, dtype=np.uint8)
88 | y_s = np.expand_dims(y_s, axis=1)
89 |
90 | _, train_loss = sess.run([train_op, loss], feed_dict={x1: x_1, x2: x_2, y: y_s, keep_prob: 0.8})
91 |
92 | step += 1
93 | # Print training info and save the model
94 | if step % INFO_STEP == 0 or step == MAX_STEP:
95 | print('step: %d, loss: %.4f' % (step, train_loss))
96 |
97 | if step % SAVE_STEP == 0:
98 | checkpoint_path = os.path.join(log_path, 'sia_bin.ckpt')
99 | saver.save(sess, checkpoint_path, global_step=step)
100 |
101 |
102 | if __name__ == '__main__':
103 | train()
104 |
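At test time, the binary head's logit can be mapped to a probability that the pair belongs to the same class. A minimal sketch, assuming the tensors x1, x2, keep_prob, and pre built in train() are in scope:

    prob_same = tf.nn.sigmoid(pre)  # probability that the two inputs are of the same class
    # p = sess.run(prob_same, feed_dict={x1: imgs_a, x2: imgs_b, keep_prob: 1.0})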
--------------------------------------------------------------------------------
/Siamese/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese configuration file: global parameters
3 | author: 王建坤
4 | date: 2018-12-20
5 | """
6 | import numpy as np
7 | import tensorflow as tf
8 | from nets import alexnet, mobilenet_v1, vgg, inception_v4, resnet_v2, inception_resnet_v2
9 | from Detect import mynet
10 | import os
11 |
12 | # Number of classes
13 | CLASSES = 5
14 | # Image size
15 | IMG_SIZE = 224
16 | # Number of image channels
17 | CHANNEL = 1
18 | # Whether the image size is non-standard (if so, use global pooling)
19 | GLOBAL_POOL = True
20 | # Whether to use the GPU
21 | USE_GPU = True
22 | if not USE_GPU:
23 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
24 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
25 | # Model name
26 | MODEL_NAME = 'Mobile'
27 | # Data set paths
28 | images_path = '../data/cig_data_' + str(IMG_SIZE) + '.npy'
29 | labels_path = '../data/cig_label_' + str(IMG_SIZE) + '.npy'
30 |
31 |
32 |
33 |
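These globals are consumed elsewhere in the package through a star import, as Siamese/train.py does; a minimal sketch:

    from Siamese.config import *
    images = np.load(images_path)  # np, tf, and os are also re-exported by this module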
--------------------------------------------------------------------------------
/Siamese/data_cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese data caching -- cache the feature vector extracted by the Siamese network for every sample
3 | author: 王建坤
4 | date: 2018-8-14
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 | from Siamese import inference
9 |
10 |
11 | def cache_data():
12 | """
13 | Cache the feature vector extracted by the Siamese network for every sample
14 | """
15 | # Load the data set
16 | images = np.load('F:/DefectDetection/npy_data/train_data.npy')
17 | # Adjust the shape to your own image size (note: inference.inference as written expects 3 channels)
18 | images_input = tf.reshape(images, [-1, 28, 28, 1])
19 |
20 | # Forward pass
21 | result = inference.inference(images_input, 1.0)
22 |
23 | saver = tf.train.Saver()
24 |
25 | with tf.Session() as sess:
26 | saver.restore(sess, tf.train.latest_checkpoint('F:/DefectDetection/log/Siamese'))  # restore the newest checkpoint in the log directory
27 | feature_vectors = sess.run(result)
28 |
29 | # Save as an npy file
30 | np.save('F:/DefectDetection/npy_data/feature_vectors.npy', feature_vectors)
31 |
32 |
33 | if __name__ == '__main__':
34 | cache_data()
35 |
--------------------------------------------------------------------------------
/Siamese/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese evaluation
3 | author: 王建坤
4 | date: 2018-8-16
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 | from Siamese import inference, utils
9 |
10 | BATCH_SIZE = 132
11 | CLASSES = 3
12 | CODE_LEN = 32
13 |
14 |
15 | def evaluate():
16 | # Load the data set
17 | images = np.load('E:/dataset/npy/train_data.npy')
18 | labels = np.load('E:/dataset/npy/train_label.npy')
19 | total_test = images.shape[0]
20 |
21 | # Placeholders
22 | with tf.variable_scope('input_x1') as scope:
23 | x1 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1])
24 | with tf.variable_scope('input_x2') as scope:
25 | x2 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1])
26 | with tf.variable_scope('y') as scope:
27 | y = tf.placeholder(tf.float32, shape=[None])
28 |
29 | with tf.name_scope('keep_prob') as scope:
30 | keep_prob = tf.placeholder(tf.float32)
31 |
32 | # Forward pass
33 | with tf.variable_scope('siamese') as scope:
34 | out1 = inference.inference(x1, keep_prob)
35 | # Share parameters; a second set of variables is not created
36 | scope.reuse_variables()
37 | out2 = inference.inference(x2, keep_prob)
38 |
39 | # Loss function (defined for graph completeness; not evaluated below)
40 | with tf.variable_scope('metrics') as scope:
41 | loss = inference.loss_spring(out1, out2, y)
42 |
43 | saver = tf.train.Saver(tf.global_variables())
44 |
45 | # Model save path
46 | log_dir = "E:/alum/log/Siamese"
47 |
48 | with tf.Session() as sess:
49 | print("Reading checkpoints...")
50 | ckpt = tf.train.get_checkpoint_state(log_dir)
51 | if ckpt and ckpt.model_checkpoint_path:
52 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
53 | saver.restore(sess, ckpt.model_checkpoint_path)
54 | print('Loading success, global_step is %s' % global_step)
55 | else:
56 | print('No checkpoint file found')
57 | return
58 |
59 | # Per-class sample counts and mean codes
60 | sample_num = np.zeros(CLASSES)
61 | code_sum = np.zeros((CLASSES, CODE_LEN))
62 | for i in range(total_test):
63 | xs_1 = images[i:i+1]
64 | xs_1 = np.expand_dims(xs_1, axis=3)
65 | l_1 = np.argmax(labels[i])
66 | y1 = sess.run(out1, feed_dict={x1: xs_1, keep_prob: 1})
67 | sample_num[l_1] += 1
68 | code_sum[l_1] += np.squeeze(y1)
69 | code_mean = code_sum/np.expand_dims(sample_num, axis=1)
70 | # print('sample_num: ', '\n', sample_num, '\n', 'code_mean: ', code_mean)
71 |
72 | # Distances between class mean codes
73 | class_diff = np.zeros((CLASSES, CLASSES))
74 | for i in range(CLASSES):
75 | for j in range(i, CLASSES):
76 | class_diff[i][j] = class_diff[j][i] = np.mean(np.square(code_mean[i]-code_mean[j]))
77 | print('code diff between class: ', '\n', class_diff)
78 |
79 |
80 | if __name__ == '__main__':
81 | evaluate()
82 |
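The mean codes also enable a simple nearest-centroid classifier; a sketch of that use (not part of the original script), with the same mean-squared-difference metric as class_diff:

    def nearest_class(code, code_mean):
        # Return the class whose mean code is closest to the given embedding
        d = np.mean(np.square(code_mean - code), axis=1)
        return int(np.argmin(d))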
--------------------------------------------------------------------------------
/Siamese/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese model network structure -- based on LeNet-5
3 | author: 王建坤
4 | date: 2018-8-14
5 | """
6 | import tensorflow as tf
7 |
8 |
9 | def inference(inputs, keep_prob):
10 | initer = tf.truncated_normal_initializer(stddev=0.1)
11 |
12 | with tf.name_scope('conv1') as scope:
13 | w1 = tf.get_variable('w1', dtype=tf.float32, shape=[11, 11, 3, 4], initializer=initer)
14 | b1 = tf.get_variable('b1', dtype=tf.float32, initializer=tf.constant(0.01, shape=[4], dtype=tf.float32))
15 | conv1 = tf.nn.conv2d(inputs, w1, strides=[1, 1, 1, 1], padding='VALID', name='conv1')
16 | with tf.name_scope('relu1') as scope:
17 | relu1 = tf.nn.relu(tf.add(conv1, b1), name='relu1')
18 | with tf.name_scope('max_pool1') as scope:
19 | pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='max_pool1')
20 |
21 | with tf.name_scope('conv2') as scope:
22 | w2 = tf.get_variable('w2', dtype=tf.float32, shape=[5, 5, 4, 8], initializer=initer)
23 | b2 = tf.get_variable('b2', dtype=tf.float32, initializer=tf.constant(0.01, shape=[8], dtype=tf.float32))
24 | conv2 = tf.nn.conv2d(pool1, w2, strides=[1, 1, 1, 1], padding='VALID', name='conv2')
25 | with tf.name_scope('relu2') as scope:
26 | relu2 = tf.nn.relu(conv2 + b2, name='relu2')
27 | with tf.name_scope('max_pool2') as scope:
28 | pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='max_pool2')
29 |
30 | with tf.name_scope('conv3') as scope:
31 | w3 = tf.get_variable('w3', dtype=tf.float32, shape=[3, 3, 8, 8], initializer=initer)
32 | b3 = tf.get_variable('b3', dtype=tf.float32, initializer=tf.constant(0.01, shape=[8], dtype=tf.float32))
33 | conv3 = tf.nn.conv2d(pool2, w3, strides=[1, 1, 1, 1], padding='VALID', name='conv3')
34 | with tf.name_scope('relu3') as scope:
35 | relu3 = tf.nn.relu(conv3 + b3, name='relu3')
36 | with tf.name_scope('max_pool3') as scope:
37 | pool3 = tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='max_pool3')
38 |
39 | dim = pool3.get_shape()[1]*pool3.get_shape()[2]*pool3.get_shape()[3]
40 | # print(pool3.get_shape()[1], pool3.get_shape()[2], pool3.get_shape()[3])
41 |
42 | with tf.name_scope('fc1') as scope:
43 | x_flat = tf.reshape(pool3, shape=[-1, dim])
44 | w_fc1 = tf.get_variable('w_fc1', dtype=tf.float32, shape=[int(dim), 128], initializer=initer)
45 | b_fc1 = tf.get_variable('b_fc1', dtype=tf.float32, initializer=tf.constant(0.01, shape=[128], dtype=tf.float32))
46 | fc1 = tf.add(tf.matmul(x_flat, w_fc1), b_fc1)
47 | with tf.name_scope('relu_fc1') as scope:
48 | relu_fc1 = tf.nn.relu(fc1, name='relu_fc1')
49 | # with tf.name_scope('bn_fc1') as scope:
50 | # bn_fc1 = tf.layers.batch_normalization(relu_fc1, name='bn_fc1')
51 | with tf.name_scope('drop_1') as scope:
52 | drop_1 = tf.nn.dropout(relu_fc1, keep_prob=keep_prob, name='drop_1')
53 |
54 | with tf.name_scope('fc2') as scope:
55 | w_fc2 = tf.get_variable('w_fc2', dtype=tf.float32, shape=[128, 32], initializer=initer)
56 | b_fc2 = tf.get_variable('b_fc2', dtype=tf.float32, initializer=tf.constant(0.01, shape=[32], dtype=tf.float32))
57 | fc2 = tf.add(tf.matmul(drop_1, w_fc2), b_fc2)
58 |
59 | return fc2
60 |
61 |
62 | # Loss function 1
63 | # y = 1 means the pair belongs to the same class
64 | def siamese_loss(out1, out2, y, margin=5.0):
65 | Q = tf.constant(margin, name="Q", dtype=tf.float32)
66 | E_w = tf.sqrt(tf.reduce_sum(tf.square(out1-out2), 1))
67 | # same class
68 | pos = tf.multiply(tf.multiply(y, 2/Q), tf.square(E_w))
69 | # different class
70 | neg = tf.multiply(tf.multiply(1-y, 2*Q), tf.exp(-2.77/Q*E_w))
71 | loss = pos + neg
72 | loss = tf.reduce_mean(loss)
73 | return loss
74 |
75 |
76 | # Loss function 2
77 | def loss_spring(out1, out2, y, margin=5.0):
78 | eucd2 = tf.reduce_sum(tf.square(out1-out2), 1)
79 | eucd = tf.sqrt(eucd2+1e-6, name="eucd")
80 | C = tf.constant(margin, name="C")
81 | # yi*||CNN(p1i)-CNN(p2i)|| + (1-yi)*max(0, C-||CNN(p1i)-CNN(p2i)||)
82 | # same class
83 | pos = tf.multiply(y, eucd, name="pos_loss")
84 | # different class
85 | # neg = tf.multiply(1 - y, tf.pow(tf.maximum(C - eucd, 0.0), 2), name="neg_loss")
86 | neg = tf.multiply(1-y, tf.maximum(C-eucd, 0.0), name="neg_loss")
87 | losses = tf.add(pos, neg, name="losses")
88 | loss = tf.reduce_mean(losses, name="loss")
89 | return loss
90 |
--------------------------------------------------------------------------------
/Siamese/inference_2.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese -- network structure for multi-class classification
3 | author: 王建坤
4 | date: 2018-8-14
5 | """
6 | import tensorflow as tf
7 |
8 |
9 | def inference(inputs, keep_prob):
10 | """
11 | A two-layer neural network for multi-class classification
12 | """
13 | with tf.name_scope('mul_fc1') as scope:
14 | w_fc1 = tf.Variable(tf.truncated_normal(shape=[int(inputs.shape[1]), 64], stddev=0.05, mean=0), name='w_fc1')
15 | b_fc1 = tf.Variable(tf.zeros(64), name='b_fc1')
16 | fc1 = tf.add(tf.matmul(inputs, w_fc1), b_fc1)
17 | with tf.name_scope('mul_relu_fc1') as scope:
18 | relu_fc1 = tf.nn.relu(fc1, name='relu_fc1')
19 | with tf.name_scope('mul_drop_1') as scope:
20 | drop_1 = tf.nn.dropout(relu_fc1, keep_prob=keep_prob, name='drop_1')
21 | with tf.name_scope('mul_bn_fc1') as scope:
22 | bn_fc1 = tf.layers.batch_normalization(drop_1, name='bn_fc1')
23 |
24 | with tf.name_scope('mul_fc2') as scope:
25 | w_fc2 = tf.Variable(tf.truncated_normal(shape=[64, 10], stddev=0.05, mean=0), name='w_fc2')
26 | b_fc2 = tf.Variable(tf.zeros(10), name='b_fc2')  # bias size must match the output width of w_fc2
27 | fc2 = tf.add(tf.matmul(bn_fc1, w_fc2), b_fc2)
28 |
29 | return fc2
30 |
--------------------------------------------------------------------------------
/Siamese/predict.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese prediction
3 | author: 王建坤
4 | date: 2018-8-16
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 | from Siamese import inference
9 |
10 | # Build the graph and restore the model
11 | # Placeholders
12 | with tf.variable_scope('input_x1') as scope:
13 | x1 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1])
14 | with tf.variable_scope('input_x2') as scope:
15 | x2 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1])
16 |
17 | # Forward pass
18 | with tf.variable_scope('siamese') as scope:
19 | out1 = inference.inference(x1, 1)
20 | # Share parameters; a second set of variables is not created
21 | scope.reuse_variables()
22 | out2 = inference.inference(x2, 1)
23 |
24 | diff = tf.sqrt(tf.reduce_sum(tf.square(out1 - out2)))
25 |
26 | saver = tf.train.Saver(tf.global_variables())
27 | # Model save path
28 | log_dir = "E:/alum/log/Siamese"
29 |
30 | sess = tf.Session()
31 |
32 | # Reuse variables
33 | # tf.get_variable_scope().reuse_variables()
34 |
35 | print("Reading checkpoints...")
36 | ckpt = tf.train.get_checkpoint_state(log_dir)
37 | if ckpt and ckpt.model_checkpoint_path:
38 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
39 | saver.restore(sess, ckpt.model_checkpoint_path)
40 | print('Loading success, global_step is %s' % global_step)
41 | else:
42 | print('No checkpoint file found')
43 |
44 |
45 | # Stand-alone prediction; the model is not reloaded on every call (sess is a global)
46 | def predict(img1, img2):
47 | x_1 = np.expand_dims(img1, axis=3)
48 | x_2 = np.expand_dims(img2, axis=3)
49 | y1, y2, diff_loss = sess.run([out1, out2, diff], feed_dict={x1: x_1, x2: x_2})
50 | print(y1, '\n', y2)
51 | print(diff_loss)
52 |
53 | # Alternative: loading the model and predicting inside a single function
54 | # def predict(img1, img2):
55 | # x_1 = np.expand_dims(img1, axis=3)
56 | # x_2 = np.expand_dims(img2, axis=3)
57 | #
58 | # # Placeholders
59 | # with tf.variable_scope('input_x1') as scope:
60 | # x1 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1])
61 | # with tf.variable_scope('input_x2') as scope:
62 | # x2 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1])
63 | #
64 | # # Forward pass
65 | # with tf.variable_scope('siamese') as scope:
66 | # out1 = inference.inference(x1, 1)
67 | # # Share parameters; a second set of variables is not created
68 | # scope.reuse_variables()
69 | # out2 = inference.inference(x2, 1)
70 | #
71 | # diff = tf.sqrt(tf.reduce_sum(tf.square(out1 - out2)))
72 | #
73 | # saver = tf.train.Saver(tf.global_variables())
74 | # # Model save path
75 | # log_dir = "E:/alum/log/Siamese"
76 | #
77 | # with tf.Session() as sess:
78 | # tf.get_variable_scope().reuse_variables()
79 | # print("Reading checkpoints...")
80 | # ckpt = tf.train.get_checkpoint_state(log_dir)
81 | # if ckpt and ckpt.model_checkpoint_path:
82 | # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
83 | # saver.restore(sess, ckpt.model_checkpoint_path)
84 | # print('Loading success, global_step is %s' % global_step)
85 | # else:
86 | # print('No checkpoint file found')
87 | # return
88 | #
89 | # # Inspect the tensors in the graph
90 | # # for var in tf.global_variables():
91 | # # print(var.name)
92 | #
93 | # # Inspect a tensor's value
94 | # # print(sess.run(tf.get_default_graph().get_tensor_by_name("siamese/b2:0")))
95 | #
96 | # y1, y2, diff_loss = sess.run([out1, out2, diff], feed_dict={x1: x_1, x2: x_2})
97 | # # print(y1, '\n', y2)
98 | # print(diff_loss)
99 |
100 |
101 | if __name__ == '__main__':
102 | images = np.load('E:/dataset/npy/train_data.npy')
103 | labels = np.load('E:/dataset/npy/train_label.npy')
104 | index1 = 0
105 | img1 = images[index1:index1+1]
106 | index2 = 1
107 | img2 = images[index2:index2+1]
108 | print(img1.shape)
109 | predict(img1, img2)
110 |
111 | # for index2 in range(labels.shape[0]):
112 | # img2 = images[index2:index2 + 1]
113 | # predict(img1, img2)
114 |
115 | # true relationship between the two samples
116 | # correct = np.equal(np.argmax(labels[index1]), np.argmax(labels[index2]))
117 | # print('true relationship is: ', correct)
118 |
119 |
120 | # first = 0
121 | # second = 0
122 | # third = 0
123 | # for i in range(120):
124 | # for j in range(240, 400):
125 | # ou1, ou2 = predict(images[i:i+1], images[j:j+1])
126 | # # ou2 = predict(images[301:302], images[301:302])
127 | # diffe = np.sqrt(np.sum(np.square(ou1 - ou2), 1))
128 | # # print('difference is: ', diffe)
129 | # if diffe < 1:
130 | # first+=1
131 | # elif 1<= diffe <5:
132 | # second +=1
133 | # else:
134 | # third += 1
135 | # print(first, second, third)
136 |
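To turn the embedding distance into a same/different decision, the simplest rule is a threshold; a sketch reusing the global sess, diff, x1 and x2 above (the threshold value is hypothetical and should be tuned on labelled validation pairs):

    THRESHOLD = 2.5  # hypothetical value; tune on validation pairs

    def is_same(im1, im2, thr=THRESHOLD):
        d = sess.run(diff, feed_dict={x1: np.expand_dims(im1, axis=3),
                                      x2: np.expand_dims(im2, axis=3)})
        return d < thr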
--------------------------------------------------------------------------------
/Siamese/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese training -- similarity (contrastive) loss
3 | author: 王建坤
4 | date: 2018-9-30
5 | """
6 | from sklearn.model_selection import train_test_split
7 | from Siamese import inference, utils
8 | from Siamese.config import *
9 |
10 | MAX_STEP = 2000
11 | LEARNING_RATE_BASE = 0.001
12 | LEARNING_RATE_DECAY = 0.98
13 | # Step intervals for printing training info and saving weights
14 | INFO_STEP = 20
15 | SAVE_STEP = 200
16 | # Batch size
17 | BATCH_SIZE = 16
18 |
19 |
20 | def train(inherit=False):
21 | # Load the data set
22 | images = np.load(images_path)
23 | labels = np.load(labels_path)
24 | train_data, val_data, train_label, val_label = train_test_split(images, labels, test_size=0.2, random_state=222)
25 | # If the input is grayscale, add a channel dimension
26 | if CHANNEL == 1:
27 | train_data = np.expand_dims(train_data, axis=3)
28 | val_data = np.expand_dims(val_data, axis=3)
29 |
30 | # Placeholders
31 | x1 = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL], name="x_input1")
32 | x2 = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL], name="x_input2")
33 | y = tf.placeholder(tf.float32, [None], name="y_input")
34 | # keep_prob = tf.placeholder(tf.float32, name='keep_prob')
35 | my_global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int64)
36 |
37 | # Forward pass
38 | with tf.variable_scope('siamese') as scope:
39 | out1, _ = mobilenet_v1.mobilenet_v1(x1,
40 | num_classes=CLASSES,
41 | dropout_keep_prob=1.0,
42 | is_training=True,
43 | min_depth=8,
44 | depth_multiplier=1.0,
45 | conv_defs=None,
46 | prediction_fn=None,
47 | spatial_squeeze=True,
48 | reuse=tf.AUTO_REUSE,
49 | scope='MobilenetV1',
50 | global_pool=GLOBAL_POOL)
51 | # Parameters are shared via reuse=tf.AUTO_REUSE; note that variables must be created with get_variable()
52 | # scope.reuse_variables()
53 | out2, _ = mobilenet_v1.mobilenet_v1(x2,
54 | num_classes=CLASSES,
55 | dropout_keep_prob=1.0,
56 | is_training=True,
57 | min_depth=8,
58 | depth_multiplier=1.0,
59 | conv_defs=None,
60 | prediction_fn=None,
61 | spatial_squeeze=True,
62 | reuse=tf.AUTO_REUSE,
63 | scope='MobilenetV1',
64 | global_pool=GLOBAL_POOL)
65 |
66 | # Loss function and optimizer
67 | with tf.variable_scope('metrics') as scope:
68 | loss = inference.loss_spring(out1, out2, y)
69 | learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, my_global_step, 100, LEARNING_RATE_DECAY)
70 | train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=my_global_step)
71 |
72 | saver = tf.train.Saver(tf.global_variables())
73 |
74 | # Model save path and file name
75 | log_dir = "../log/Siamese"
76 | model_name = 'siamese.ckpt'
77 |
78 | with tf.Session() as sess:
79 | step = 0
80 | if inherit:
81 | ckpt = tf.train.get_checkpoint_state(log_dir)
82 | if ckpt and ckpt.model_checkpoint_path:
83 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
84 | saver.restore(sess, ckpt.model_checkpoint_path)
85 | print('Siamese continues training from step %s' % global_step)
86 | step = int(global_step)
87 | else:
88 | print('No checkpoint file found')
89 | return
90 | else:
91 | print('training from scratch')
92 | sess.run(tf.global_variables_initializer())
93 |
94 | while step < MAX_STEP:
95 | # Get a pair of batches
96 | xs_1, ys_1 = utils.get_batch(train_data, train_label, BATCH_SIZE)
97 | xs_2, ys_2 = utils.get_batch(train_data, train_label, BATCH_SIZE)
98 | # Check whether the two corresponding labels are equal
99 | y_s = np.array(ys_1 == ys_2, dtype=np.float32)
100 |
101 | _, y1, y2, train_loss = sess.run([train_op, out1, out2, loss],
102 | feed_dict={x1: xs_1, x2: xs_2, y: y_s})
103 |
104 | # Print training info and save the model
105 | step += 1
106 | if step % INFO_STEP == 0 or step == MAX_STEP:
107 | print('step: %d, loss: %.4f' % (step, train_loss))
108 |
109 | if step % SAVE_STEP == 0 or step == MAX_STEP:
110 | checkpoint_path = os.path.join(log_dir, model_name)
111 | saver.save(sess, checkpoint_path, global_step=step)
112 |
113 |
114 | if __name__ == '__main__':
115 | train()
116 |
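For reference, the learning-rate schedule built above has the closed form lr(step) = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (step / 100) (staircase is False, so the exponent is fractional); for example, lr(0) = 0.001 and lr(1000) = 0.001 * 0.98**10 ≈ 8.17e-4.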
--------------------------------------------------------------------------------
/Siamese/train_2.py:
--------------------------------------------------------------------------------
1 | """
2 | Siamese 训练 -- 多分类训练
3 | author: 王建坤
4 | date: 2018-8-14
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 | import os
9 | from Siamese_multi import utils
10 | from Siamese import inference_2
11 |
12 | learning_rate = 0.01
13 | iterations = 100
14 | batch_size = 64
15 |
16 |
17 | def multi_classify_train():
18 | """
19 | Multi-class classification on features extracted by the Siamese network
20 | """
21 | # Load the data set
22 | feature_vectors = np.load('F:/DefectDetection/npy_data/feature_vectors.npy')
23 | labels = np.load('F:/DefectDetection/npy_data/train_label.npy')
24 |
25 | total_train = feature_vectors.shape[0]
26 |
27 | # Placeholders
28 | input_vectors = tf.placeholder(tf.float32, shape=[None, 64], name='input_vectors')
29 | y = tf.placeholder(tf.float32, shape=[None, 3])
30 | keep_prob = tf.placeholder(tf.float32, name='multi_keep_prob')
31 |
32 | # Forward pass
33 | result = inference_2.inference(input_vectors, keep_prob)
34 |
35 | # Loss function and optimizer
36 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=result, name='multi_entropy')
37 | loss = tf.reduce_mean(cross_entropy, name='multi_loss')
38 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
39 |
40 | # Model save path
41 | log_dir = "F:/DefectDetection/log/Siamese_multi"
42 | saver = tf.train.Saver(tf.global_variables())
43 |
44 | with tf.Session() as sess:
45 | sess.run(tf.global_variables_initializer())
46 |
47 | start = 0
48 | for step in range(iterations):
49 | # Get a batch of the data set
50 | vectors_batch = utils.get_batch(feature_vectors, start, batch_size, total_train)
51 | label_batch = utils.get_batch(labels, start, batch_size, total_train)
52 | _, train_loss = sess.run([train_op, loss], feed_dict={input_vectors: vectors_batch,
53 | y: label_batch, keep_prob: 0.6})
54 | start += batch_size
55 | # if step % 100 == 1:
56 | # print('step {}, train loss {}'.format(step, train_loss))
57 |
58 | # Print training info and save the model every 10 steps
59 | if (step + 1) % 10 == 0 or (step + 1) == iterations:
60 | print('step: %d, loss: %.4f' % (step, train_loss))
61 | checkpoint_path = os.path.join(log_dir, 'Siamese_multi_model.ckpt')
62 | saver.save(sess, checkpoint_path, global_step=step)
63 |
64 |
65 | if __name__ == '__main__':
66 | multi_classify_train()
67 |
--------------------------------------------------------------------------------
/Siamese/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions
3 | author: 王建坤
4 | date: 2018-10-15
5 | """
6 | import numpy as np
7 | import os
8 | import pandas as pd
9 |
10 |
11 | def load_npy(data_path, label_path):
12 | """
13 | Load npy file
14 | :param data_path: the path of data directory
15 | :param label_path: the path of label directory
16 | :return: numpy format of data and label
17 | """
18 | data = np.load(data_path)
19 | label = np.load(label_path)
20 | print('data_set shape: ', np.shape(data))
21 | return data, label
22 |
23 |
24 | def get_batch(data, label, batch_size):
25 | """
26 | Get a batch from numpy array.
27 | :param data: numpy array of data
28 | :param label: numpy array of label
29 | :param batch_size: number of samples in a batch
30 | (samples are drawn randomly, with replacement)
31 | :return: a random batch of data and the corresponding labels
32 | """
33 | # Sequential traversal of the data set (kept for reference)
34 | # start = start % total
35 | # end = start + batch_size
36 | # if end > total:
37 | # res = end % total
38 | # end = total
39 | # data1 = data[start:end]
40 | # data2 = data[0: res]
41 | # return np.vstack((data1, data2))
42 | # return data[start:end]
43 |
44 | # Randomly sample a batch
45 | total = data.shape[0]
46 | index_list = np.random.randint(total, size=batch_size)
47 | return data[index_list, :], label[index_list]
48 |
49 |
50 | def shuffle_data(data, label):
51 | """
52 | Shuffle the data and labels together so that training batches are not ordered by class.
53 | """
54 | permutation = np.random.permutation(label.shape[0])
55 | batch_data = data[permutation, :]
56 | batch_label = label[permutation]
57 | return batch_data, batch_label
58 |
59 |
60 | def normalize_data(data):
61 | """
62 | Normalize the data by subtracting the mean of each of the B, G, R channels to accelerate training.
63 | """
64 | r = data[:, :, :, 2]
65 | r_mean = np.average(r)
66 | b = data[:, :, :, 0]
67 | b_mean = np.average(b)
68 | g = data[:, :, :, 1]
69 | g_mean = np.average(g)
70 | r = r - r_mean
71 | b = b - b_mean
72 | g = g - g_mean
73 | data = np.zeros([r.shape[0], r.shape[1], r.shape[2], 3])
74 | data[:, :, :, 0] = b
75 | data[:, :, :, 1] = g
76 | data[:, :, :, 2] = r
77 | return data
78 |
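The per-channel mean subtraction above can also be written as a single vectorized expression; an equivalent sketch:

    def normalize_data_vectorized(data):
        # Subtract the per-channel mean over the whole data set (channel order unchanged)
        return data - data.mean(axis=(0, 1, 2), keepdims=True)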
79 |
80 | def check_tensor():
81 | """
82 | Inspect the tensors stored in a saved ckpt file
83 | """
84 | from tensorflow.python import pywrap_tensorflow
85 |
86 | # Path of the ckpt file
87 | checkpoint_path = 'E:/alum/weight/inception_v4_2016_09_09/inception_v4.ckpt'
88 |
89 | # Read data from checkpoint file
90 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
91 | var_to_shape_map = reader.get_variable_to_shape_map()
92 |
93 | # Print tensor name and values
94 | for key in var_to_shape_map:
95 | print("tensor_name: ", key)
96 | print(reader.get_tensor(key).dtype)
97 |
98 |
99 | def image_to_csv(data_path, save_path):
100 | """
101 | Save image paths and labels as a csv file
102 | :param data_path: image root path
103 | :param save_path: save path of csv file
104 | :return: none
105 | """
106 | img_path, label = [], []
107 | per_class_num = {}
108 | # Defect classes (keys are kept in Chinese because they match the folder names on disk)
109 | class_dic = {'正常': 0, '不导电': 1, '擦花': 2, '角位漏底': 3, '桔皮': 4, '漏底': 5, '起坑': 6, '脏点': 7}
110 | # Traverse the folders
111 | folder_list = os.listdir(data_path)
112 | for folder in folder_list:
113 | if folder not in class_dic.keys():
114 | continue
115 |
116 | num = 0
117 | class_path = os.path.join(data_path, folder)
118 | img_list = os.listdir(class_path)
119 | for img_name in img_list:
120 | img_path.append(os.path.join(class_path, img_name))
121 | label.append(class_dic[folder])
122 | num += 1
123 | per_class_num[folder] = num
124 |
125 | img_path_label = pd.DataFrame({'img_path': img_path, 'label': label})
126 | img_path_label.to_csv(save_path+'/data_label.csv', index=False)
127 | print('per-class sample count: ', per_class_num)
128 |
129 |
130 | def csv_show_image():
131 | """
132 | Read an image path from the csv file and display the image
133 | :return: none
134 | """
135 | from PIL import Image
136 | table = pd.read_csv('../data/data_label.csv')
137 | img_path = table.iloc[0, 0]
138 | img = Image.open(img_path)
139 | img.show()
140 |
141 |
142 | if __name__ == "__main__":
143 | print('running utils:')
144 | # train_data, train_label = load_npy('../data/data_299.npy', '../data/label_299.npy')
145 | # image_to_csv('../data/alum', '../data')
146 |
147 |
148 |
--------------------------------------------------------------------------------
/Siamese_multi/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese_multi/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/Siamese_multi/__pycache__/inference.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese_multi/__pycache__/inference.cpython-36.pyc
--------------------------------------------------------------------------------
/camera/DllTest.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/camera/DllTest.dll
--------------------------------------------------------------------------------
/camera/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/camera/__init__.py
--------------------------------------------------------------------------------
/camera/test.py:
--------------------------------------------------------------------------------
1 | from ctypes import CDLL
2 |
3 | dll = CDLL("DllTest.dll")  # load DllTest.dll from the current working directory
4 | print(dll.add(10, 102))
5 |
--------------------------------------------------------------------------------
/mv/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/mv/__init__.py
--------------------------------------------------------------------------------
/mv/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | E-cigarette atomizer assembly inspection -- performance evaluation
3 | author: 王建坤
4 | date: 2019-3-6
5 | """
6 | import os
7 | from mv import cig_detection
8 | import numpy as np
9 |
10 | ROOT_DIR = 'E:/backup/cigarette/'
11 | CLASSES = 5
12 |
13 |
14 | def evaluate():
15 | """
16 | Evaluate the rule-based assembly detector on the labelled image folders.
17 | :return:
18 | """
19 | img_path_list, labels = [], []
20 |
21 | class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4}
22 | # Traverse the folders
23 | folder_list = os.listdir(ROOT_DIR)
24 | for folder in folder_list:
25 | # Skip folders that are not wanted
26 | if folder not in class_dic.keys():
27 | continue
28 | folder_path = os.path.join(ROOT_DIR, folder)
29 | img_list = os.listdir(folder_path)
30 | for img_name in img_list:
31 | img_path = os.path.join(folder_path, img_name)
32 | img_path_list.append(img_path)
33 | labels.append(class_dic[folder])
34 |
35 | res = []  # predicted class ids; detector flag i maps to class: 0->1, 1->2, 2->4, 3->3
36 | for img, label in zip(img_path_list, labels):
37 | print(img)
38 | detector = cig_detection.AssembleDetection(img)
39 | detector.detect()
40 | if sum(detector.res) == 0:
41 | res.append(0)
42 | else:
43 | for i in range(len(detector.res)):
44 | if detector.res[i] == 1:
45 | if i < 2:
46 | res.append(i+1)
47 | break
48 | elif i == 2:
49 | res.append(4)
50 | break
51 | else:
52 | res.append(i)
53 | break
54 |
55 | sample_labels = np.zeros(CLASSES, dtype=np.uint16)
56 | pre_pos = np.zeros(CLASSES, dtype=np.uint16)
57 | true_pos = np.zeros(CLASSES, dtype=np.uint16)
58 |
59 | test_num = len(res)
60 | normal_num = 0
61 | for i in range(test_num):
62 | pre_pos[res[i]] += 1
63 | sample_labels[labels[i]] += 1
64 | if res[i] == labels[i]:
65 | true_pos[res[i]] += 1
66 | if labels[i] == 0:
67 | normal_num += 1
68 |
69 | precision = true_pos/pre_pos
70 | recall = true_pos/sample_labels
71 | f1 = 2*precision*recall/(precision+recall)
72 | error = (pre_pos - true_pos) / (test_num - sample_labels)  # per-class false positives over negatives
73 | print('Number of test samples:', test_num)
74 | print('Samples per class:', sample_labels)
75 | print('Predictions per class:', pre_pos)
76 | print('Correct per class:', true_pos)
77 | print('Precision:', precision)
78 | print('Recall:', recall)
79 | print('F1:', f1)
80 | print('False-alarm rate:', error)
81 | print('Accuracy:', np.sum(true_pos)/test_num)
82 | print('Overall miss rate:', (pre_pos[0]-true_pos[0])/(test_num-normal_num))
83 | print('Overall overkill rate:', (normal_num-true_pos[0])/normal_num)
84 |
85 |
86 | if __name__ == '__main__':
87 | evaluate()
88 |
89 |
90 |
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/nets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/alexnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/alexnet.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/inception_resnet_v2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/inception_resnet_v2.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/inception_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/inception_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/inception_v4.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/inception_v4.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/mobilenet_v1.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/mobilenet_v1.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/resnet_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/resnet_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/resnet_v2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/resnet_v2.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/vgg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/vgg.cpython-36.pyc
--------------------------------------------------------------------------------
/nets/alexnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a model definition for AlexNet.
16 |
17 | This work was first described in:
18 | ImageNet Classification with Deep Convolutional Neural Networks
19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
20 |
21 | and later refined in:
22 | One weird trick for parallelizing convolutional neural networks
23 | Alex Krizhevsky, 2014
24 |
25 | Here we provide the implementation proposed in "One weird trick" and not
26 | "ImageNet Classification", as per the paper, the LRN layers have been removed.
27 |
28 | Usage:
29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
30 | outputs, end_points = alexnet.alexnet_v2(inputs)
31 |
32 | @@alexnet_v2
33 | """
34 |
35 | from __future__ import absolute_import
36 | from __future__ import division
37 | from __future__ import print_function
38 |
39 | import tensorflow as tf
40 |
41 | import tensorflow.contrib.slim as slim
42 |
43 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
44 |
45 |
46 | def alexnet_v2_arg_scope(weight_decay=0.0005):
47 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
48 | activation_fn=tf.nn.relu,
49 | biases_initializer=tf.constant_initializer(0.1),
50 | weights_regularizer=slim.l2_regularizer(weight_decay)):
51 | with slim.arg_scope([slim.conv2d], padding='SAME'):
52 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
53 | return arg_sc
54 |
55 |
56 | def alexnet_v2(inputs,
57 | num_classes=1000,
58 | is_training=True,
59 | dropout_keep_prob=0.5,
60 | spatial_squeeze=True,
61 | scope='alexnet_v2',
62 | global_pool=False):
63 | """AlexNet version 2.
64 |
65 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf
66 | Parameters from:
67 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
68 | layers-imagenet-1gpu.cfg
69 |
70 | Note: All the fully_connected layers have been transformed to conv2d layers.
71 | To use in classification mode, resize input to 224x224 or set
72 | global_pool=True. To use in fully convolutional mode, set
73 | spatial_squeeze to false.
74 | The LRN layers have been removed and the initializers changed from
75 | random_normal_initializer to xavier_initializer.
76 |
77 | Args:
78 | inputs: a tensor of size [batch_size, height, width, channels].
79 | num_classes: the number of predicted classes. If 0 or None, the logits layer
80 | is omitted and the input features to the logits layer are returned instead.
81 | is_training: whether or not the model is being trained.
82 | dropout_keep_prob: the probability that activations are kept in the dropout
83 | layers during training.
84 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the
85 | logits. Useful to remove unnecessary dimensions for classification.
86 | scope: Optional scope for the variables.
87 | global_pool: Optional boolean flag. If True, the input to the classification
88 | layer is avgpooled to size 1x1, for any input size. (This is not part
89 | of the original AlexNet.)
90 |
91 | Returns:
92 | net: the output of the logits layer (if num_classes is a non-zero integer),
93 | or the non-dropped-out input to the logits layer (if num_classes is 0
94 | or None).
95 | end_points: a dict of tensors with intermediate activations.
96 | """
97 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
98 | end_points_collection = sc.original_name_scope + '_end_points'
99 | # Collect outputs for conv2d, fully_connected and max_pool2d.
100 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
101 | outputs_collections=[end_points_collection]):
102 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
103 | scope='conv1')
104 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
105 | net = slim.conv2d(net, 192, [5, 5], scope='conv2')
106 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
107 | net = slim.conv2d(net, 384, [3, 3], scope='conv3')
108 | net = slim.conv2d(net, 384, [3, 3], scope='conv4')
109 | net = slim.conv2d(net, 256, [3, 3], scope='conv5')
110 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')
111 |
112 | # Use conv2d instead of fully_connected layers.
113 | with slim.arg_scope([slim.conv2d],
114 | weights_initializer=trunc_normal(0.005),
115 | biases_initializer=tf.constant_initializer(0.1)):
116 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
117 | scope='fc6')
118 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
119 | scope='dropout6')
120 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
121 | # Convert end_points_collection into a end_point dict.
122 | end_points = slim.utils.convert_collection_to_dict(
123 | end_points_collection)
124 | if global_pool:
125 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
126 | end_points['global_pool'] = net
127 | if num_classes:
128 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
129 | scope='dropout7')
130 | net = slim.conv2d(net, num_classes, [1, 1],
131 | activation_fn=None,
132 | normalizer_fn=None,
133 | biases_initializer=tf.zeros_initializer(),
134 | scope='fc8')
135 | if spatial_squeeze:
136 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
137 | end_points[sc.name + '/fc8'] = net
138 | return net, end_points
139 |
140 |
141 | alexnet_v2.default_image_size = 224
142 |
--------------------------------------------------------------------------------
/nets/cifarnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a variant of the CIFAR-10 model definition."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev)
26 |
27 |
28 | def cifarnet(images, num_classes=10, is_training=False,
29 | dropout_keep_prob=0.5,
30 | prediction_fn=slim.softmax,
31 | scope='CifarNet'):
32 | """Creates a variant of the CifarNet model.
33 |
34 | Note that since the output is a set of 'logits', the values fall in the
35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a
36 | probability distribution over the characters, one will need to convert them
37 | using the softmax function:
38 |
39 | logits = cifarnet.cifarnet(images, is_training=False)
40 | probabilities = tf.nn.softmax(logits)
41 | predictions = tf.argmax(logits, 1)
42 |
43 | Args:
44 | images: A batch of `Tensors` of size [batch_size, height, width, channels].
45 | num_classes: the number of classes in the dataset. If 0 or None, the logits
46 | layer is omitted and the input features to the logits layer are returned
47 | instead.
48 | is_training: specifies whether or not we're currently training the model.
49 | This variable will determine the behaviour of the dropout layer.
50 | dropout_keep_prob: the percentage of activation values that are retained.
51 | prediction_fn: a function to get predictions out of logits.
52 | scope: Optional variable_scope.
53 |
54 | Returns:
55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
56 | is a non-zero integer, or the input to the logits layer if num_classes
57 | is 0 or None.
58 | end_points: a dictionary from components of the network to the corresponding
59 | activation.
60 | """
61 | end_points = {}
62 |
63 | with tf.variable_scope(scope, 'CifarNet', [images]):
64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1')
65 | end_points['conv1'] = net
66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
67 | end_points['pool1'] = net
68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')
69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2')
70 | end_points['conv2'] = net
71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')
72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
73 | end_points['pool2'] = net
74 | net = slim.flatten(net)
75 | end_points['Flatten'] = net
76 | net = slim.fully_connected(net, 384, scope='fc3')
77 | end_points['fc3'] = net
78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
79 | scope='dropout3')
80 | net = slim.fully_connected(net, 192, scope='fc4')
81 | end_points['fc4'] = net
82 | if not num_classes:
83 | return net, end_points
84 | logits = slim.fully_connected(net, num_classes,
85 | biases_initializer=tf.zeros_initializer(),
86 | weights_initializer=trunc_normal(1/192.0),
87 | weights_regularizer=None,
88 | activation_fn=None,
89 | scope='logits')
90 |
91 | end_points['Logits'] = logits
92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
93 |
94 | return logits, end_points
95 | cifarnet.default_image_size = 32
96 |
97 |
98 | def cifarnet_arg_scope(weight_decay=0.004):
99 | """Defines the default cifarnet argument scope.
100 |
101 | Args:
102 | weight_decay: The weight decay to use for regularizing the model.
103 |
104 | Returns:
105 | An `arg_scope` to use for the inception v3 model.
106 | """
107 | with slim.arg_scope(
108 | [slim.conv2d],
109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2),
110 | activation_fn=tf.nn.relu):
111 | with slim.arg_scope(
112 | [slim.fully_connected],
113 | biases_initializer=tf.constant_initializer(0.1),
114 | weights_initializer=trunc_normal(0.04),
115 | weights_regularizer=slim.l2_regularizer(weight_decay),
116 | activation_fn=tf.nn.relu) as sc:
117 | return sc
118 |
--------------------------------------------------------------------------------
/nets/cyclegan_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for tensorflow.contrib.slim.nets.cyclegan."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from nets import cyclegan
24 |
25 |
26 | # TODO(joelshor): Add a test to check generator endpoints.
27 | class CycleganTest(tf.test.TestCase):
28 |
29 | def test_generator_inference(self):
30 | """Check one inference step."""
31 | img_batch = tf.zeros([2, 32, 32, 3])
32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch)
33 | with self.test_session() as sess:
34 | sess.run(tf.global_variables_initializer())
35 | sess.run(model_output)
36 |
37 | def _test_generator_graph_helper(self, shape):
38 | """Check that generator can take small and non-square inputs."""
39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape))
40 | self.assertAllEqual(shape, output_imgs.shape.as_list())
41 |
42 | def test_generator_graph_small(self):
43 | self._test_generator_graph_helper([4, 32, 32, 3])
44 |
45 | def test_generator_graph_medium(self):
46 | self._test_generator_graph_helper([3, 128, 128, 3])
47 |
48 | def test_generator_graph_nonsquare(self):
49 | self._test_generator_graph_helper([2, 80, 400, 3])
50 |
51 | def test_generator_unknown_batch_dim(self):
52 | """Check that generator can take unknown batch dimension inputs."""
53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3])
54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img)
55 |
56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list())
57 |
58 | def _input_and_output_same_shape_helper(self, kernel_size):
59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet(
61 | img_batch, kernel_size=kernel_size)
62 |
63 | self.assertAllEqual(img_batch.shape.as_list(),
64 | output_img_batch.shape.as_list())
65 |
66 | def test_input_and_output_same_shape_kernel3(self):
67 | self._input_and_output_same_shape_helper(3)
68 |
69 | def test_input_and_output_same_shape_kernel4(self):
70 | self._input_and_output_same_shape_helper(4)
71 |
72 | def test_input_and_output_same_shape_kernel5(self):
73 | self._input_and_output_same_shape_helper(5)
74 |
75 | def test_input_and_output_same_shape_kernel6(self):
76 | self._input_and_output_same_shape_helper(6)
77 |
78 | def _error_if_height_not_multiple_of_four_helper(self, height):
79 | self.assertRaisesRegexp(
80 | ValueError,
81 | 'The input height must be a multiple of 4.',
82 | cyclegan.cyclegan_generator_resnet,
83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3]))
84 |
85 | def test_error_if_height_not_multiple_of_four_height29(self):
86 | self._error_if_height_not_multiple_of_four_helper(29)
87 |
88 | def test_error_if_height_not_multiple_of_four_height30(self):
89 | self._error_if_height_not_multiple_of_four_helper(30)
90 |
91 | def test_error_if_height_not_multiple_of_four_height31(self):
92 | self._error_if_height_not_multiple_of_four_helper(31)
93 |
94 | def _error_if_width_not_multiple_of_four_helper(self, width):
95 | self.assertRaisesRegexp(
96 | ValueError,
97 | 'The input width must be a multiple of 4.',
98 | cyclegan.cyclegan_generator_resnet,
99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3]))
100 |
101 | def test_error_if_width_not_multiple_of_four_width29(self):
102 | self._error_if_width_not_multiple_of_four_helper(29)
103 |
104 | def test_error_if_width_not_multiple_of_four_width30(self):
105 | self._error_if_width_not_multiple_of_four_helper(30)
106 |
107 | def test_error_if_width_not_multiple_of_four_width31(self):
108 | self._error_if_width_not_multiple_of_four_helper(31)
109 |
110 |
111 | if __name__ == '__main__':
112 | tf.test.main()
113 |
--------------------------------------------------------------------------------
/nets/dcgan_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for dcgan."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from six.moves import xrange # pylint: disable=redefined-builtin
22 | import tensorflow as tf
23 |
24 | from nets import dcgan
25 |
26 |
27 | class DCGANTest(tf.test.TestCase):
28 |
29 | def test_generator_run(self):
30 | tf.set_random_seed(1234)
31 | noise = tf.random_normal([100, 64])
32 | image, _ = dcgan.generator(noise)
33 | with self.test_session() as sess:
34 | sess.run(tf.global_variables_initializer())
35 | image.eval()
36 |
37 | def test_generator_graph(self):
38 | tf.set_random_seed(1234)
39 | # Check graph construction for a number of image size/depths and batch
40 | # sizes.
41 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)):
42 | tf.reset_default_graph()
43 | final_size = 2 ** i
44 | noise = tf.random_normal([batch_size, 64])
45 | image, end_points = dcgan.generator(
46 | noise,
47 | depth=32,
48 | final_size=final_size)
49 |
50 | self.assertAllEqual([batch_size, final_size, final_size, 3],
51 | image.shape.as_list())
52 |
53 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits']
54 | self.assertSetEqual(set(expected_names), set(end_points.keys()))
55 |
56 | # Check layer depths.
57 | for j in range(1, i):
58 | layer = end_points['deconv%i' % j]
59 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1])
60 |
61 | def test_generator_invalid_input(self):
62 | wrong_dim_input = tf.zeros([5, 32, 32])
63 | with self.assertRaises(ValueError):
64 | dcgan.generator(wrong_dim_input)
65 |
66 | correct_input = tf.zeros([3, 2])
67 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'):
68 | dcgan.generator(correct_input, final_size=30)
69 |
70 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'):
71 | dcgan.generator(correct_input, final_size=4)
72 |
73 | def test_discriminator_run(self):
74 | image = tf.random_uniform([5, 32, 32, 3], -1, 1)
75 | output, _ = dcgan.discriminator(image)
76 | with self.test_session() as sess:
77 | sess.run(tf.global_variables_initializer())
78 | output.eval()
79 |
80 | def test_discriminator_graph(self):
81 | # Check graph construction for a number of image size/depths and batch
82 | # sizes.
83 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)):
84 | tf.reset_default_graph()
85 | img_w = 2 ** i
86 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1)
87 | output, end_points = dcgan.discriminator(
88 | image,
89 | depth=32)
90 |
91 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list())
92 |
93 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits']
94 | self.assertSetEqual(set(expected_names), set(end_points.keys()))
95 |
96 | # Check layer depths.
97 | for j in range(1, i+1):
98 | layer = end_points['conv%i' % j]
99 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1])
100 |
101 | def test_discriminator_invalid_input(self):
102 | wrong_dim_img = tf.zeros([5, 32, 32])
103 | with self.assertRaises(ValueError):
104 | dcgan.discriminator(wrong_dim_img)
105 |
106 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3])
107 | with self.assertRaises(ValueError):
108 | dcgan.discriminator(spatially_undefined_shape)
109 |
110 | not_square = tf.zeros([5, 32, 16, 3])
111 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'):
112 | dcgan.discriminator(not_square)
113 |
114 | not_power_2 = tf.zeros([5, 30, 30, 3])
115 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'):
116 | dcgan.discriminator(not_power_2)
117 |
118 |
119 | if __name__ == '__main__':
120 | tf.test.main()
121 |
--------------------------------------------------------------------------------
/nets/inception.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Brings all inception models under one namespace."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # pylint: disable=unused-import
22 | from nets.inception_resnet_v2 import inception_resnet_v2
23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope
24 | from nets.inception_resnet_v2 import inception_resnet_v2_base
25 | from nets.inception_v1 import inception_v1
26 | from nets.inception_v1 import inception_v1_arg_scope
27 | from nets.inception_v1 import inception_v1_base
28 | from nets.inception_v2 import inception_v2
29 | from nets.inception_v2 import inception_v2_arg_scope
30 | from nets.inception_v2 import inception_v2_base
31 | from nets.inception_v3 import inception_v3
32 | from nets.inception_v3 import inception_v3_arg_scope
33 | from nets.inception_v3 import inception_v3_base
34 | from nets.inception_v4 import inception_v4
35 | from nets.inception_v4 import inception_v4_arg_scope
36 | from nets.inception_v4 import inception_v4_base
37 | # pylint: enable=unused-import
38 |
--------------------------------------------------------------------------------
/nets/inception_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains common code shared by all inception models.
16 |
17 | Usage of arg scope:
18 | with slim.arg_scope(inception_arg_scope()):
19 | logits, end_points = inception.inception_v3(images, num_classes,
20 | is_training=is_training)
21 |
22 | """
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import tensorflow as tf
28 |
29 | slim = tf.contrib.slim
30 |
31 |
32 | def inception_arg_scope(weight_decay=0.00004,
33 | use_batch_norm=True,
34 | batch_norm_decay=0.9997,
35 | batch_norm_epsilon=0.001,
36 | activation_fn=tf.nn.relu,
37 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS):
38 | """Defines the default arg scope for inception models.
39 |
40 | Args:
41 | weight_decay: The weight decay to use for regularizing the model.
42 | use_batch_norm: If `True`, batch_norm is applied after each convolution.
43 | batch_norm_decay: Decay for batch norm moving average.
44 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero
45 | in batch norm.
46 | activation_fn: Activation function for conv2d.
47 | batch_norm_updates_collections: Collection for the update ops for
48 | batch norm.
49 |
50 | Returns:
51 | An `arg_scope` to use for the inception models.
52 | """
53 | batch_norm_params = {
54 | # Decay for the moving averages.
55 | 'decay': batch_norm_decay,
56 | # epsilon to prevent 0s in variance.
57 | 'epsilon': batch_norm_epsilon,
58 | # collection containing update_ops.
59 | 'updates_collections': batch_norm_updates_collections,
60 | # use fused batch norm if possible.
61 | 'fused': None,
62 | }
63 | if use_batch_norm:
64 | normalizer_fn = slim.batch_norm
65 | normalizer_params = batch_norm_params
66 | else:
67 | normalizer_fn = None
68 | normalizer_params = {}
69 | # Set weight_decay for weights in Conv and FC layers.
70 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
71 | weights_regularizer=slim.l2_regularizer(weight_decay)):
72 | with slim.arg_scope(
73 | [slim.conv2d],
74 | weights_initializer=slim.variance_scaling_initializer(),
75 | activation_fn=activation_fn,
76 | normalizer_fn=normalizer_fn,
77 | normalizer_params=normalizer_params) as sc:
78 | return sc
79 |
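80 | # Example (a sketch grounded in the parameters above): the same scope can be
81 | # built without batch norm by toggling `use_batch_norm`:
82 | #   with slim.arg_scope(inception_arg_scope(use_batch_norm=False)):
83 | #     logits, end_points = inception.inception_v3(images, num_classes)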
--------------------------------------------------------------------------------
/nets/lenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a variant of the LeNet model definition."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 | def lenet(images, num_classes=10, is_training=False,
27 | dropout_keep_prob=0.5,
28 | prediction_fn=slim.softmax,
29 | scope='LeNet'):
30 | """Creates a variant of the LeNet model.
31 |
32 | Note that since the output is a set of 'logits', the values fall in the
33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a
34 | probability distribution over the characters, one will need to convert them
35 | using the softmax function:
36 |
37 | logits = lenet.lenet(images, is_training=False)
38 | probabilities = tf.nn.softmax(logits)
39 | predictions = tf.argmax(logits, 1)
40 |
41 | Args:
42 | images: A batch of `Tensors` of size [batch_size, height, width, channels].
43 | num_classes: the number of classes in the dataset. If 0 or None, the logits
44 | layer is omitted and the input features to the logits layer are returned
45 | instead.
46 | is_training: specifies whether or not we're currently training the model.
47 | This variable will determine the behaviour of the dropout layer.
48 | dropout_keep_prob: the percentage of activation values that are retained.
49 | prediction_fn: a function to get predictions out of logits.
50 | scope: Optional variable_scope.
51 |
52 | Returns:
53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
54 | is a non-zero integer, or the non-dropped-out input to the logits layer
55 | if num_classes is 0 or None.
56 | end_points: a dictionary from components of the network to the corresponding
57 | activation.
58 | """
59 | end_points = {}
60 |
61 | with tf.variable_scope(scope, 'LeNet', [images]):
62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2')
65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
66 | net = slim.flatten(net)
67 | end_points['Flatten'] = net
68 |
69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3')
70 | if not num_classes:
71 | return net, end_points
72 | net = end_points['dropout3'] = slim.dropout(
73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3')
74 | logits = end_points['Logits'] = slim.fully_connected(
75 | net, num_classes, activation_fn=None, scope='fc4')
76 |
77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
78 |
79 | return logits, end_points
80 | lenet.default_image_size = 28
81 |
82 |
83 | def lenet_arg_scope(weight_decay=0.0):
84 | """Defines the default lenet argument scope.
85 |
86 | Args:
87 | weight_decay: The weight decay to use for regularizing the model.
88 |
89 | Returns:
90 | An `arg_scope` to use for the lenet model.
91 | """
92 | with slim.arg_scope(
93 | [slim.conv2d, slim.fully_connected],
94 | weights_regularizer=slim.l2_regularizer(weight_decay),
95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
96 | activation_fn=tf.nn.relu) as sc:
97 | return sc
98 |
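99 | # Example usage (a sketch; `images` is assumed to be a batch of 28x28 inputs):
100 | #   with slim.arg_scope(lenet_arg_scope()):
101 | #     logits, end_points = lenet(images, num_classes=10, is_training=True)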
--------------------------------------------------------------------------------
/nets/mobilenet/README.md:
--------------------------------------------------------------------------------
1 | # MobileNetV2
2 | This folder contains building code for MobileNetV2, based on
3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381)
4 |
5 | # Performance
6 | ## Latency
7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using
8 | TF-Lite on the large core of Pixel 1 phone.
9 |
10 | 
11 |
12 | ## MACs
13 | MACs, also sometimes known as MADDs, count the number of multiply-accumulates needed
14 | to compute an inference on a single image; they are a common metric for measuring model efficiency.
15 |
16 | Below is a graph comparing MobileNetV2 against a few selected networks. The size
17 | of each blob represents the number of parameters. Note that for [ShuffleNet](https://arxiv.org/abs/1707.01083) there
18 | are no published size numbers; we estimate them to be comparable to MobileNetV2's.
19 |
20 | 
21 |
22 | # Pretrained models
23 | ## Imagenet Checkpoints
24 |
25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1
26 | ---------------------------|---------|---------------|---------|----|-------------
27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0
28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0
29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8
30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1
31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2
32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6
33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6
34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8
35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6
36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4
37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9
38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2
39 | | [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7
40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1
41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9
42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9
43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4
44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7
45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6
46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5
47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9
48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5
49 |
50 | # Training
51 | The numbers above can be reproduced using slim's `train_image_classifier`.
52 | Below is the set of parameters that achieves 72.0% accuracy for the full-size MobileNetV2, after about 700K steps when trained on 8 GPUs.
53 | If trained on a single GPU, full convergence takes about 5.5M steps. Also note that the learning rate and
54 | num_epochs_per_decay both need to be adjusted depending on how many GPUs are being
55 | used, due to slim's internal averaging.
56 |
57 | ```bash
58 | --model_name="mobilenet_v2"
59 | --learning_rate=0.045 * NUM_GPUS  # slim internally averages clones, so we compensate
60 | --preprocessing_name="inception_v2"
61 | --label_smoothing=0.1
62 | --moving_average_decay=0.9999
63 | --batch_size=96
64 | --num_clones=NUM_GPUS  # any number between 1 and 8, depending on your hardware setup
65 | --learning_rate_decay_factor=0.98
66 | --num_epochs_per_decay=2.5/NUM_GPUS  # train_image_classifier does per-clone epochs
67 | ```
68 |
69 | # Example
70 |
71 |
72 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb).
73 |
74 |
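75 | ## Loading a checkpoint in Python
76 |
77 | The following is a minimal sketch (not part of the upstream docs) of running
78 | inference with one of the checkpoints above. It assumes TensorFlow 1.x, this
79 | repository's `nets` package on the Python path, and an extracted
80 | `mobilenet_v2_1.0_224` checkpoint; the checkpoint path is a placeholder.
81 |
82 | ```python
83 | import tensorflow as tf
84 | from nets.mobilenet import mobilenet_v2
85 |
86 | slim = tf.contrib.slim
87 |
88 | images = tf.placeholder(tf.float32, [None, 224, 224, 3])
89 | # Build the network in inference mode (freezes batch norm, disables dropout).
90 | with slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
91 |   logits, end_points = mobilenet_v2.mobilenet(images)
92 |
93 | saver = tf.train.Saver()
94 | with tf.Session() as sess:
95 |   saver.restore(sess, '/tmp/checkpoints/mobilenet_v2_1.0_224.ckpt')
96 |   # end_points['Predictions'] holds the softmax probabilities.
97 | ```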
--------------------------------------------------------------------------------
/nets/mobilenet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet/__init__.py
--------------------------------------------------------------------------------
/nets/mobilenet/madds_top1_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet/madds_top1_accuracy.png
--------------------------------------------------------------------------------
/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png
--------------------------------------------------------------------------------
/nets/mobilenet_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet_v1.png
--------------------------------------------------------------------------------
/nets/mobilenet_v1_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Validate mobilenet_v1 with options for quantization."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import math
22 | import tensorflow as tf
23 |
24 | from datasets import dataset_factory
25 | from nets import mobilenet_v1
26 | from preprocessing import preprocessing_factory
27 |
28 | slim = tf.contrib.slim
29 |
30 | flags = tf.app.flags
31 |
32 | flags.DEFINE_string('master', '', 'Session master')
33 | flags.DEFINE_integer('batch_size', 250, 'Batch size')
34 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish')
35 | flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate')
36 | flags.DEFINE_integer('image_size', 224, 'Input image resolution')
37 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet')
38 | flags.DEFINE_bool('quantize', False, 'Quantize training')
39 | flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints')
40 | flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs')
41 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset')
42 |
43 | FLAGS = flags.FLAGS
44 |
45 |
46 | def imagenet_input(is_training):
47 | """Data reader for imagenet.
48 |
49 | Reads in imagenet data and performs pre-processing on the images.
50 |
51 | Args:
52 | is_training: bool specifying if train or validation dataset is needed.
53 | Returns:
54 | A batch of images and labels.
55 | """
56 | if is_training:
57 | dataset = dataset_factory.get_dataset('imagenet', 'train',
58 | FLAGS.dataset_dir)
59 | else:
60 | dataset = dataset_factory.get_dataset('imagenet', 'validation',
61 | FLAGS.dataset_dir)
62 |
63 | provider = slim.dataset_data_provider.DatasetDataProvider(
64 | dataset,
65 | shuffle=is_training,
66 | common_queue_capacity=2 * FLAGS.batch_size,
67 | common_queue_min=FLAGS.batch_size)
68 | [image, label] = provider.get(['image', 'label'])
69 |
70 | image_preprocessing_fn = preprocessing_factory.get_preprocessing(
71 | 'mobilenet_v1', is_training=is_training)
72 |
73 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
74 |
75 | images, labels = tf.train.batch(
76 | tensors=[image, label],
77 | batch_size=FLAGS.batch_size,
78 | num_threads=4,
79 | capacity=5 * FLAGS.batch_size)
80 | return images, labels
81 |
82 |
83 | def metrics(logits, labels):
84 | """Specify the metrics for eval.
85 |
86 | Args:
87 | logits: Logits output from the graph.
88 | labels: Ground truth labels for inputs.
89 |
90 | Returns:
91 | Eval Op for the graph.
92 | """
93 | labels = tf.squeeze(labels)
94 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
95 | 'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels),
96 | 'Recall_5': tf.metrics.recall_at_k(labels, logits, 5),
97 | })
98 | for name, value in names_to_values.items():  # .items() for Python 3 compatibility
99 | slim.summaries.add_scalar_summary(
100 | value, name, prefix='eval', print_summary=True)
101 | return names_to_updates.values()
102 |
103 |
104 | def build_model():
105 | """Build the mobilenet_v1 model for evaluation.
106 |
107 | Returns:
108 | g: graph with rewrites after insertion of quantization ops and batch norm
109 | folding.
110 | eval_ops: eval ops for inference.
111 |
112 | """
113 | g = tf.Graph()
114 | with g.as_default():
115 | inputs, labels = imagenet_input(is_training=False)
116 |
117 | scope = mobilenet_v1.mobilenet_v1_arg_scope(
118 | is_training=False, weight_decay=0.0)
119 | with slim.arg_scope(scope):
120 | logits, _ = mobilenet_v1.mobilenet_v1(
121 | inputs,
122 | is_training=False,
123 | depth_multiplier=FLAGS.depth_multiplier,
124 | num_classes=FLAGS.num_classes)
125 |
126 | if FLAGS.quantize:
127 | tf.contrib.quantize.create_eval_graph()
128 |
129 | eval_ops = metrics(logits, labels)
130 |
131 | return g, eval_ops
132 |
133 |
134 | def eval_model():
135 | """Evaluates mobilenet_v1."""
136 | g, eval_ops = build_model()
137 | with g.as_default():
138 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
139 | slim.evaluation.evaluate_once(
140 | FLAGS.master,
141 | FLAGS.checkpoint_dir,
142 | logdir=FLAGS.eval_dir,
143 | num_evals=num_batches,
144 | eval_op=eval_ops)
145 |
146 |
147 | def main(unused_arg):
148 | eval_model()
149 |
150 |
151 | if __name__ == '__main__':
152 | tf.app.run(main)
153 |
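154 | # Example invocation (a sketch; every path below is a placeholder, and the flags
155 | # are the ones defined at the top of this file):
156 | #   python mobilenet_v1_eval.py --checkpoint_dir=/tmp/checkpoints \
157 | #     --eval_dir=/tmp/eval --dataset_dir=/tmp/imagenet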
--------------------------------------------------------------------------------
/nets/nasnet/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints
2 | This directory contains the code for the NASNet-A model from the paper
3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al.
4 | In nasnet.py there are three different configurations of NASNet-A that are implemented. One of the models is the NASNet-A built for CIFAR-10, and the
5 | other two are variants of NASNet-A trained on ImageNet, which are listed below.
6 |
7 | # Pre-Trained Models
8 | Two NASNet-A checkpoints are available that have been trained on the
9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/)
10 | image classification dataset. Accuracies were computed by evaluating using a single image crop.
11 |
12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy |
13 | :----:|:------------:|:----------:|:-------:|:-------:|
14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6|
15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2|
16 |
17 |
18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint; downloading the NASNet-A_Large_331 checkpoint works the same way.
19 |
20 | ```shell
21 | CHECKPOINT_DIR=/tmp/checkpoints
22 | mkdir ${CHECKPOINT_DIR}
23 | cd ${CHECKPOINT_DIR}
24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz
25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz
26 | rm nasnet-a_mobile_04_10_2017.tar.gz
27 | ```
28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md).
29 |
30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
31 |
32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference
33 | -------
34 | Run eval with the NASNet-A mobile ImageNet model
35 |
36 | ```shell
37 | DATASET_DIR=/tmp/imagenet
38 | EVAL_DIR=/tmp/tfmodel/eval
39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt
40 | python tensorflow_models/research/slim/eval_image_classifier \
41 | --checkpoint_path=${CHECKPOINT_DIR} \
42 | --eval_dir=${EVAL_DIR} \
43 | --dataset_dir=${DATASET_DIR} \
44 | --dataset_name=imagenet \
45 | --dataset_split_name=validation \
46 | --model_name=nasnet_mobile \
47 | --eval_image_size=224
48 | ```
49 |
50 | Run eval with the NASNet-A large ImageNet model
51 |
52 | ```shell
53 | DATASET_DIR=/tmp/imagenet
54 | EVAL_DIR=/tmp/tfmodel/eval
55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt
56 | python tensorflow_models/research/slim/eval_image_classifier \
57 | --checkpoint_path=${CHECKPOINT_DIR} \
58 | --eval_dir=${EVAL_DIR} \
59 | --dataset_dir=${DATASET_DIR} \
60 | --dataset_name=imagenet \
61 | --dataset_split_name=validation \
62 | --model_name=nasnet_large \
63 | --eval_image_size=331
64 | ```
65 |
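66 | ## Using the Checkpoints from Python
67 |
68 | The snippet below is a minimal sketch (not part of the upstream docs) of building
69 | the mobile model through `nets_factory`; it assumes TensorFlow 1.x and the slim
70 | `nets` package on the Python path. Restoring weights from the downloaded
71 | checkpoint then works with a standard `tf.train.Saver`.
72 |
73 | ```python
74 | import tensorflow as tf
75 | from nets import nets_factory
76 |
77 | # The slim ImageNet checkpoints use 1001 classes (background class included).
78 | network_fn = nets_factory.get_network_fn(
79 |     'nasnet_mobile', num_classes=1001, is_training=False)
80 | size = network_fn.default_image_size  # 224 for the mobile model
81 | images = tf.placeholder(tf.float32, [None, size, size, 3])
82 | logits, end_points = network_fn(images)
83 | ```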
--------------------------------------------------------------------------------
/nets/nasnet/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/nets/nasnet/nasnet_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for slim.nets.nasnet.nasnet_utils."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from nets.nasnet import nasnet_utils
24 |
25 |
26 | class NasnetUtilsTest(tf.test.TestCase):
27 |
28 | def testCalcReductionLayers(self):
29 | num_cells = 18
30 | num_reduction_layers = 2
31 | reduction_layers = nasnet_utils.calc_reduction_layers(
32 | num_cells, num_reduction_layers)
33 | self.assertEqual(len(reduction_layers), 2)
34 | self.assertEqual(reduction_layers[0], 6)
35 | self.assertEqual(reduction_layers[1], 12)
36 |
37 | def testGetChannelIndex(self):
38 | data_formats = ['NHWC', 'NCHW']
39 | for data_format in data_formats:
40 | index = nasnet_utils.get_channel_index(data_format)
41 | correct_index = 3 if data_format == 'NHWC' else 1
42 | self.assertEqual(index, correct_index)
43 |
44 | def testGetChannelDim(self):
45 | data_formats = ['NHWC', 'NCHW']
46 | shape = [10, 20, 30, 40]
47 | for data_format in data_formats:
48 | dim = nasnet_utils.get_channel_dim(shape, data_format)
49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1]
50 | self.assertEqual(dim, correct_dim)
51 |
52 | def testGlobalAvgPool(self):
53 | data_formats = ['NHWC', 'NCHW']
54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10))
55 | for data_format in data_formats:
56 | output = nasnet_utils.global_avg_pool(
57 | inputs, data_format)
58 | self.assertEqual(output.shape, [5, 10])
59 |
60 |
61 | if __name__ == '__main__':
62 | tf.test.main()
63 |
--------------------------------------------------------------------------------
/nets/nets_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import functools
21 |
22 | import tensorflow as tf
23 |
24 | from nets import alexnet
25 | from nets import cifarnet
26 | from nets import inception
27 | from nets import lenet
28 | from nets import mobilenet_v1
29 | from nets import overfeat
30 | from nets import resnet_v1
31 | from nets import resnet_v2
32 | from nets import vgg
33 | from nets.mobilenet import mobilenet_v2
34 | from nets.nasnet import nasnet
35 | from nets.nasnet import pnasnet
36 |
37 | slim = tf.contrib.slim
38 |
39 | networks_map = {'alexnet_v2': alexnet.alexnet_v2,
40 | 'cifarnet': cifarnet.cifarnet,
41 | 'overfeat': overfeat.overfeat,
42 | 'vgg_a': vgg.vgg_a,
43 | 'vgg_16': vgg.vgg_16,
44 | 'vgg_19': vgg.vgg_19,
45 | 'inception_v1': inception.inception_v1,
46 | 'inception_v2': inception.inception_v2,
47 | 'inception_v3': inception.inception_v3,
48 | 'inception_v4': inception.inception_v4,
49 | 'inception_resnet_v2': inception.inception_resnet_v2,
50 | 'lenet': lenet.lenet,
51 | 'resnet_v1_50': resnet_v1.resnet_v1_50,
52 | 'resnet_v1_101': resnet_v1.resnet_v1_101,
53 | 'resnet_v1_152': resnet_v1.resnet_v1_152,
54 | 'resnet_v1_200': resnet_v1.resnet_v1_200,
55 | 'resnet_v2_50': resnet_v2.resnet_v2_50,
56 | 'resnet_v2_101': resnet_v2.resnet_v2_101,
57 | 'resnet_v2_152': resnet_v2.resnet_v2_152,
58 | 'resnet_v2_200': resnet_v2.resnet_v2_200,
59 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1,
60 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075,
61 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050,
62 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025,
63 | 'mobilenet_v2': mobilenet_v2.mobilenet,
64 | 'mobilenet_v2_140': mobilenet_v2.mobilenet_v2_140,
65 | 'mobilenet_v2_035': mobilenet_v2.mobilenet_v2_035,
66 | 'nasnet_cifar': nasnet.build_nasnet_cifar,
67 | 'nasnet_mobile': nasnet.build_nasnet_mobile,
68 | 'nasnet_large': nasnet.build_nasnet_large,
69 | 'pnasnet_large': pnasnet.build_pnasnet_large,
70 | 'pnasnet_mobile': pnasnet.build_pnasnet_mobile,
71 | }
72 |
73 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
74 | 'cifarnet': cifarnet.cifarnet_arg_scope,
75 | 'overfeat': overfeat.overfeat_arg_scope,
76 | 'vgg_a': vgg.vgg_arg_scope,
77 | 'vgg_16': vgg.vgg_arg_scope,
78 | 'vgg_19': vgg.vgg_arg_scope,
79 | 'inception_v1': inception.inception_v3_arg_scope,
80 | 'inception_v2': inception.inception_v3_arg_scope,
81 | 'inception_v3': inception.inception_v3_arg_scope,
82 | 'inception_v4': inception.inception_v4_arg_scope,
83 | 'inception_resnet_v2':
84 | inception.inception_resnet_v2_arg_scope,
85 | 'lenet': lenet.lenet_arg_scope,
86 | 'resnet_v1_50': resnet_v1.resnet_arg_scope,
87 | 'resnet_v1_101': resnet_v1.resnet_arg_scope,
88 | 'resnet_v1_152': resnet_v1.resnet_arg_scope,
89 | 'resnet_v1_200': resnet_v1.resnet_arg_scope,
90 | 'resnet_v2_50': resnet_v2.resnet_arg_scope,
91 | 'resnet_v2_101': resnet_v2.resnet_arg_scope,
92 | 'resnet_v2_152': resnet_v2.resnet_arg_scope,
93 | 'resnet_v2_200': resnet_v2.resnet_arg_scope,
94 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope,
95 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope,
96 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope,
97 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope,
98 | 'mobilenet_v2': mobilenet_v2.training_scope,
99 | 'mobilenet_v2_035': mobilenet_v2.training_scope,
100 | 'mobilenet_v2_140': mobilenet_v2.training_scope,
101 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope,
102 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope,
103 | 'nasnet_large': nasnet.nasnet_large_arg_scope,
104 | 'pnasnet_large': pnasnet.pnasnet_large_arg_scope,
105 | 'pnasnet_mobile': pnasnet.pnasnet_mobile_arg_scope,
106 | }
107 |
108 |
109 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
110 | """Returns a network_fn such as `logits, end_points = network_fn(images)`.
111 |
112 | Args:
113 | name: The name of the network.
114 | num_classes: The number of classes to use for classification. If 0 or None,
115 | the logits layer is omitted and its input features are returned instead.
116 | weight_decay: The l2 coefficient for the model weights.
117 | is_training: `True` if the model is being used for training and `False`
118 | otherwise.
119 |
120 | Returns:
121 | network_fn: A function that applies the model to a batch of images. It has
122 | the following signature:
123 | net, end_points = network_fn(images)
124 | The `images` input is a tensor of shape [batch_size, height, width, 3]
125 | with height = width = network_fn.default_image_size. (The permissibility
126 | and treatment of other sizes depends on the network_fn.)
127 | The returned `end_points` are a dictionary of intermediate activations.
128 | The returned `net` is the topmost layer, depending on `num_classes`:
129 | If `num_classes` was a non-zero integer, `net` is a logits tensor
130 | of shape [batch_size, num_classes].
131 | If `num_classes` was 0 or `None`, `net` is a tensor with the input
132 | to the logits layer of shape [batch_size, 1, 1, num_features] or
133 | [batch_size, num_features]. Dropout has not been applied to this
134 | (even if the network's original classifier does); it remains for
135 | the caller to do this or not.
136 |
137 | Raises:
138 | ValueError: If network `name` is not recognized.
139 | """
140 | if name not in networks_map:
141 | raise ValueError('Name of network unknown: %s' % name)
142 | func = networks_map[name]
143 | @functools.wraps(func)
144 | def network_fn(images, **kwargs):
145 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
146 | with slim.arg_scope(arg_scope):
147 | return func(images, num_classes, is_training=is_training, **kwargs)
148 | if hasattr(func, 'default_image_size'):
149 | network_fn.default_image_size = func.default_image_size
150 |
151 | return network_fn
152 |
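153 | # Example usage (a sketch of the signature documented above):
154 | #   network_fn = get_network_fn('vgg_16', num_classes=1000, is_training=False)
155 | #   size = network_fn.default_image_size
156 | #   logits, end_points = network_fn(tf.placeholder(tf.float32, [None, size, size, 3]))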
--------------------------------------------------------------------------------
/nets/nets_factory_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for slim.inception."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | import tensorflow as tf
24 |
25 | from nets import nets_factory
26 |
27 |
28 | class NetworksTest(tf.test.TestCase):
29 |
30 | def testGetNetworkFnFirstHalf(self):
31 | batch_size = 5
32 | num_classes = 1000
33 | for net in list(nets_factory.networks_map.keys())[:10]:
34 | with tf.Graph().as_default() as g, self.test_session(g):
35 | net_fn = nets_factory.get_network_fn(net, num_classes)
36 | # Most networks use 224 as their default_image_size
37 | image_size = getattr(net_fn, 'default_image_size', 224)
38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
39 | logits, end_points = net_fn(inputs)
40 | self.assertTrue(isinstance(logits, tf.Tensor))
41 | self.assertTrue(isinstance(end_points, dict))
42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size)
43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
44 |
45 | def testGetNetworkFnSecondHalf(self):
46 | batch_size = 5
47 | num_classes = 1000
48 | for net in list(nets_factory.networks_map.keys())[10:]:
49 | with tf.Graph().as_default() as g, self.test_session(g):
50 | net_fn = nets_factory.get_network_fn(net, num_classes)
51 | # Most networks use 224 as their default_image_size
52 | image_size = getattr(net_fn, 'default_image_size', 224)
53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
54 | logits, end_points = net_fn(inputs)
55 | self.assertTrue(isinstance(logits, tf.Tensor))
56 | self.assertTrue(isinstance(end_points, dict))
57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size)
58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
59 |
60 | if __name__ == '__main__':
61 | tf.test.main()
62 |
--------------------------------------------------------------------------------
/nets/overfeat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains the model definition for the OverFeat network.
16 |
17 | The definition for the network was obtained from:
18 | OverFeat: Integrated Recognition, Localization and Detection using
19 | Convolutional Networks
20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
21 | Yann LeCun, 2014
22 | http://arxiv.org/abs/1312.6229
23 |
24 | Usage:
25 | with slim.arg_scope(overfeat.overfeat_arg_scope()):
26 | outputs, end_points = overfeat.overfeat(inputs)
27 |
28 | @@overfeat
29 | """
30 | from __future__ import absolute_import
31 | from __future__ import division
32 | from __future__ import print_function
33 |
34 | import tensorflow as tf
35 |
36 | slim = tf.contrib.slim
37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
38 |
39 |
40 | def overfeat_arg_scope(weight_decay=0.0005):
41 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
42 | activation_fn=tf.nn.relu,
43 | weights_regularizer=slim.l2_regularizer(weight_decay),
44 | biases_initializer=tf.zeros_initializer()):
45 | with slim.arg_scope([slim.conv2d], padding='SAME'):
46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
47 | return arg_sc
48 |
49 |
50 | def overfeat(inputs,
51 | num_classes=1000,
52 | is_training=True,
53 | dropout_keep_prob=0.5,
54 | spatial_squeeze=True,
55 | scope='overfeat',
56 | global_pool=False):
57 | """Contains the model definition for the OverFeat network.
58 |
59 | The definition for the network was obtained from:
60 | OverFeat: Integrated Recognition, Localization and Detection using
61 | Convolutional Networks
62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
63 | Yann LeCun, 2014
64 | http://arxiv.org/abs/1312.6229
65 |
66 | Note: All the fully_connected layers have been transformed to conv2d layers.
67 | To use in classification mode, resize input to 231x231. To use in fully
68 | convolutional mode, set spatial_squeeze to false.
69 |
70 | Args:
71 | inputs: a tensor of size [batch_size, height, width, channels].
72 | num_classes: number of predicted classes. If 0 or None, the logits layer is
73 | omitted and the input features to the logits layer are returned instead.
74 | is_training: whether or not the model is being trained.
75 | dropout_keep_prob: the probability that activations are kept in the dropout
76 | layers during training.
77 | spatial_squeeze: whether or not to squeeze the spatial dimensions of the
78 | outputs. Useful to remove unnecessary dimensions for classification.
79 | scope: Optional scope for the variables.
80 | global_pool: Optional boolean flag. If True, the input to the classification
81 | layer is avgpooled to size 1x1, for any input size. (This is not part
82 | of the original OverFeat.)
83 |
84 | Returns:
85 | net: the output of the logits layer (if num_classes is a non-zero integer),
86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or
87 | None).
88 | end_points: a dict of tensors with intermediate activations.
89 | """
90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc:
91 | end_points_collection = sc.original_name_scope + '_end_points'
92 | # Collect outputs for conv2d, fully_connected and max_pool2d
93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
94 | outputs_collections=end_points_collection):
95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
96 | scope='conv1')
97 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2')
99 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3')
101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4')
102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5')
103 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
104 |
105 | # Use conv2d instead of fully_connected layers.
106 | with slim.arg_scope([slim.conv2d],
107 | weights_initializer=trunc_normal(0.005),
108 | biases_initializer=tf.constant_initializer(0.1)):
109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6')
110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
111 | scope='dropout6')
112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
113 | # Convert end_points_collection into an end_points dict.
114 | end_points = slim.utils.convert_collection_to_dict(
115 | end_points_collection)
116 | if global_pool:
117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
118 | end_points['global_pool'] = net
119 | if num_classes:
120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
121 | scope='dropout7')
122 | net = slim.conv2d(net, num_classes, [1, 1],
123 | activation_fn=None,
124 | normalizer_fn=None,
125 | biases_initializer=tf.zeros_initializer(),
126 | scope='fc8')
127 | if spatial_squeeze:
128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
129 | end_points[sc.name + '/fc8'] = net
130 | return net, end_points
131 | overfeat.default_image_size = 231
132 |
--------------------------------------------------------------------------------
/nets/pix2pix_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 | """Tests for pix2pix."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 | from nets import pix2pix
23 |
24 |
25 | class GeneratorTest(tf.test.TestCase):
26 |
27 | def _reduced_default_blocks(self):
28 | """Returns the default blocks, scaled down to make test run faster."""
29 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob)
30 | for b in pix2pix._default_generator_blocks()]
31 |
32 | def test_output_size_nn_upsample_conv(self):
33 | batch_size = 2
34 | height, width = 256, 256
35 | num_outputs = 4
36 |
37 | images = tf.ones((batch_size, height, width, 3))
38 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
39 | logits, _ = pix2pix.pix2pix_generator(
40 | images, num_outputs, blocks=self._reduced_default_blocks(),
41 | upsample_method='nn_upsample_conv')
42 |
43 | with self.test_session() as session:
44 | session.run(tf.global_variables_initializer())
45 | np_outputs = session.run(logits)
46 | self.assertListEqual([batch_size, height, width, num_outputs],
47 | list(np_outputs.shape))
48 |
49 | def test_output_size_conv2d_transpose(self):
50 | batch_size = 2
51 | height, width = 256, 256
52 | num_outputs = 4
53 |
54 | images = tf.ones((batch_size, height, width, 3))
55 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
56 | logits, _ = pix2pix.pix2pix_generator(
57 | images, num_outputs, blocks=self._reduced_default_blocks(),
58 | upsample_method='conv2d_transpose')
59 |
60 | with self.test_session() as session:
61 | session.run(tf.global_variables_initializer())
62 | np_outputs = session.run(logits)
63 | self.assertListEqual([batch_size, height, width, num_outputs],
64 | list(np_outputs.shape))
65 |
66 | def test_block_number_dictates_number_of_layers(self):
67 | batch_size = 2
68 | height, width = 256, 256
69 | num_outputs = 4
70 |
71 | images = tf.ones((batch_size, height, width, 3))
72 | blocks = [
73 | pix2pix.Block(64, 0.5),
74 | pix2pix.Block(128, 0),
75 | ]
76 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
77 | _, end_points = pix2pix.pix2pix_generator(
78 | images, num_outputs, blocks)
79 |
80 | num_encoder_layers = 0
81 | num_decoder_layers = 0
82 | for end_point in end_points:
83 | if end_point.startswith('encoder'):
84 | num_encoder_layers += 1
85 | elif end_point.startswith('decoder'):
86 | num_decoder_layers += 1
87 |
88 | self.assertEqual(num_encoder_layers, len(blocks))
89 | self.assertEqual(num_decoder_layers, len(blocks))
90 |
91 |
92 | class DiscriminatorTest(tf.test.TestCase):
93 |
94 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2):
95 | return (input_size + pad * 2 - kernel_size) // stride + 1
96 |
97 | def test_four_layers(self):
98 | batch_size = 2
99 | input_size = 256
100 |
101 | output_size = self._layer_output_size(input_size)
102 | output_size = self._layer_output_size(output_size)
103 | output_size = self._layer_output_size(output_size)
104 | output_size = self._layer_output_size(output_size, stride=1)
105 | output_size = self._layer_output_size(output_size, stride=1)
106 |
107 | images = tf.ones((batch_size, input_size, input_size, 3))
108 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
109 | logits, end_points = pix2pix.pix2pix_discriminator(
110 | images, num_filters=[64, 128, 256, 512])
111 | self.assertListEqual([batch_size, output_size, output_size, 1],
112 | logits.shape.as_list())
113 | self.assertListEqual([batch_size, output_size, output_size, 1],
114 | end_points['predictions'].shape.as_list())
115 |
116 | def test_four_layers_no_padding(self):
117 | batch_size = 2
118 | input_size = 256
119 |
120 | output_size = self._layer_output_size(input_size, pad=0)
121 | output_size = self._layer_output_size(output_size, pad=0)
122 | output_size = self._layer_output_size(output_size, pad=0)
123 | output_size = self._layer_output_size(output_size, stride=1, pad=0)
124 | output_size = self._layer_output_size(output_size, stride=1, pad=0)
125 |
126 | images = tf.ones((batch_size, input_size, input_size, 3))
127 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
128 | logits, end_points = pix2pix.pix2pix_discriminator(
129 | images, num_filters=[64, 128, 256, 512], padding=0)
130 | self.assertListEqual([batch_size, output_size, output_size, 1],
131 | logits.shape.as_list())
132 | self.assertListEqual([batch_size, output_size, output_size, 1],
133 | end_points['predictions'].shape.as_list())
134 |
135 | def test_four_layers_wrong_padding(self):
136 | batch_size = 2
137 | input_size = 256
138 |
139 | images = tf.ones((batch_size, input_size, input_size, 3))
140 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
141 | with self.assertRaises(TypeError):
142 | pix2pix.pix2pix_discriminator(
143 | images, num_filters=[64, 128, 256, 512], padding=1.5)
144 |
145 | def test_four_layers_negative_padding(self):
146 | batch_size = 2
147 | input_size = 256
148 |
149 | images = tf.ones((batch_size, input_size, input_size, 3))
150 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()):
151 | with self.assertRaises(ValueError):
152 | pix2pix.pix2pix_discriminator(
153 | images, num_filters=[64, 128, 256, 512], padding=-1)
154 |
155 | if __name__ == '__main__':
156 | tf.test.main()
157 |
--------------------------------------------------------------------------------
/ocr/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | @Author: JK_Wang
4 | @Time: 13-May-19
5 | """
--------------------------------------------------------------------------------
/ocr/ocr.py:
--------------------------------------------------------------------------------
1 | """
2 | OCR of product batch-number characters on e-cigarette atomizers
3 | @Author: JK_Wang
4 | @Time: 13-May-19
5 | """
6 | import cv2
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | ROI = {'x1': 130, 'y1': 90, 'x2': 440, 'y2': 195}
11 | THRESH = 80
12 |
13 |
14 | def c_roi(img_path):
15 | img = cv2.imread(img_path, 0)
16 | roi_img = img[ROI['y1']:ROI['y2'], ROI['x1']:ROI['x2']]
17 | cv2.imshow('1', roi_img)
18 | return roi_img
19 |
20 |
21 | def img_morphology(img_gray, method=0, kernel_size=3, kernel_type=0):
22 | """
23 | Morphological operations.
24 | :param img_gray: grayscale image
25 | :param method: operation: 0: erosion, 1: dilation, 2: opening, 3: closing
26 | :param kernel_size: size of the structuring element
27 | :param kernel_type: shape of the structuring element: 0: rectangle, 1: ellipse, 2: cross
28 | :return: the processed grayscale image
29 | """
30 | if method == 0:
31 | # erosion
32 | img_gray = cv2.erode(img_gray, np.ones((kernel_size, kernel_size), np.uint8), iterations=2)  # ndarray kernel; a bare tuple is not a valid structuring element
33 | elif method == 1:
34 | # dilation
35 | img_gray = cv2.dilate(img_gray, np.ones((kernel_size, kernel_size), np.uint8), iterations=2)
36 | else:
37 | if kernel_type == 0:
38 | k_type = cv2.MORPH_RECT
39 | elif kernel_type == 1:
40 | k_type = cv2.MORPH_ELLIPSE
41 | else:
42 | k_type = cv2.MORPH_CROSS
43 | kernel = cv2.getStructuringElement(k_type, (kernel_size, kernel_size))
44 | if method == 2:
45 | # opening
46 | img_gray = cv2.morphologyEx(img_gray, cv2.MORPH_OPEN, kernel)
47 | else:
48 | # closing
49 | img_gray = cv2.morphologyEx(img_gray, cv2.MORPH_CLOSE, kernel)
50 | return img_gray
51 |
52 |
53 | def c_location(roi_img):
54 | # blur_img = cv2.medianBlur(roi_img, 3)
55 | # blur_img = cv2.equalizeHist(blur_img)
56 | blur_img = cv2.bilateralFilter(roi_img, 9, 5, 5)
57 | cv2.imshow('2', blur_img)
58 |
59 | edge = cv2.Canny(blur_img, 50, 150)
60 | cv2.imshow('3', edge)
61 |
62 | _, th_img = cv2.threshold(roi_img, THRESH, 255, cv2.THRESH_OTSU) # cv2.THRESH_BINARY
63 | cv2.imshow('4', th_img)
64 |
65 | # loc_img = cv2.addWeighted(edge, 1, th_img, 1, 0)
66 | loc_img = img_morphology(th_img, 2, 5)
67 | loc_img = img_morphology(loc_img, 3, 9)
68 | # loc_img = th_img
69 |
70 | v_sum = np.sum(loc_img, 0, dtype=np.float)  # vertical projection: column-wise sum of white pixels
71 | seg_pos = []
72 | x, y = 0, 0
73 | for i in range(len(v_sum)-1):
74 | if v_sum[i] == 0 and v_sum[i+1] != 0:  # rising edge: a character segment starts
75 | x = i
76 | elif v_sum[i] != 0 and v_sum[i+1] == 0:  # falling edge: the segment ends
77 | y = i
78 | seg_pos.append([x, y])
79 |
80 | k = 10
81 | for i, j in seg_pos:
82 | c_img = roi_img[:, i:j]
83 | c_img = cv2.equalizeHist(c_img)
84 | _, c_img = cv2.threshold(c_img, THRESH, 255, cv2.THRESH_OTSU)
85 | c_img = img_morphology(c_img, 3, 3)
86 | k += 1
87 | cv2.imshow(str(k), c_img)
88 |
89 | plt.plot(v_sum)
90 | # plt.axis([0, 250, 0, 140])
91 | plt.xlabel('Column coordinate', fontsize=14)
92 | plt.ylabel('White pixel count', fontsize=14)
93 | plt.show()
94 |
95 | return loc_img
96 |
97 |
98 | def ocr(img_path):
99 | roi_img = c_roi(img_path)
100 | loc_img = c_location(roi_img)
101 | cv2.imshow('img', loc_img)
102 | cv2.waitKey()
103 | cv2.destroyAllWindows()
104 |
105 |
106 | if __name__ == '__main__':
107 | ocr('../data/ocr/2.png')
108 |
109 |
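110 | # Example (a sketch mirroring the calls in c_location above): an opening with a
111 | # 5x5 kernel to remove small specks, then a closing with a 9x9 kernel to fill gaps:
112 | #   mask = img_morphology(th_img, method=2, kernel_size=5)
113 | #   mask = img_morphology(mask, method=3, kernel_size=9)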
--------------------------------------------------------------------------------
/qt/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/1.jpg
--------------------------------------------------------------------------------
/qt/LogDialog.py:
--------------------------------------------------------------------------------
1 | """
2 | Defect-detection Qt app -- detection-log dialog classes
3 | author: 王建坤
4 | date: 2018-10-16
5 | """
6 | from PyQt5.QtWidgets import QDialog, QTableWidgetItem, QHeaderView, QFileDialog
7 | from PyQt5.uic import loadUi
8 | from PyQt5.QtCore import Qt
9 | import pandas as pd
10 | import pymysql
11 |
12 |
13 | class DetectLog(QDialog):
14 | """
15 |     Detection-log dialog
16 | """
17 | def __init__(self, *args):
18 | super(DetectLog, self).__init__(*args)
19 | loadUi('ui_detect_log.ui', self)
20 |
21 |         # connect to the database
22 | self.connection = pymysql.connect(host='localhost', user='root', password='123456',
23 | db='detection', charset='utf8')
24 | self.cursor = self.connection.cursor()
25 |         # fetch every detection record from the database
26 | sql = "SELECT path,time,detect_class FROM detection_log"
27 | self.cursor.execute(sql)
28 | self.detect_logs = self.cursor.fetchall()
29 |
30 | self.setWindowFlags(Qt.WindowMinimizeButtonHint | Qt.WindowCloseButtonHint)
31 | self.table_log.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
32 | # self.table_log.horizontalHeader().setStyleSheet("QHeaderView::section{background-color:lightgray;};")
33 | # self.table_log.verticalHeader().setStyleSheet("QHeaderView::section{background-color:lightgray;};")
34 | self.pb_clear.clicked.connect(self.slot_clear)
35 | self.pb_save.clicked.connect(self.slot_save)
36 |
37 |         # populate the detection-log table
38 | rows = len(self.detect_logs)
39 | self.table_log.setRowCount(rows)
40 | for i in range(rows):
41 | self.table_log.setItem(i, 0, QTableWidgetItem(str(self.detect_logs[i][0])))
42 | self.table_log.setItem(i, 1, QTableWidgetItem(str(self.detect_logs[i][1])))
43 | self.table_log.setItem(i, 2, QTableWidgetItem(self.detect_logs[i][2]))
44 |
45 | def insert_row(self):
46 | """
47 |         Insert one row into the table.
48 | :return:
49 | """
50 | row_count = self.table_log.rowCount()
51 | self.table_log.insertRow(row_count)
52 |
53 | def slot_clear(self):
54 | """
55 |         Clear the table.
56 | :return:
57 | """
58 | self.table_log.clearContents()
59 | self.detect_logs = []
60 |
61 | def slot_save(self):
62 | """
63 |         Save the detection log to a CSV file.
64 | :return:
65 | """
66 |         file_name = QFileDialog.getSaveFileName(self, 'save file', 'log', 'CSV (*.csv)')
67 |         detect_csv = pd.DataFrame({'图片路径': [x[0] for x in self.detect_logs],
68 |                                    '检测时间': [x[1] for x in self.detect_logs],
69 |                                    '检测结果': [x[2] for x in self.detect_logs]})
70 | if file_name[0] != '':
71 | detect_csv.to_csv(file_name[0])
72 |
73 |
74 | class DefectLog(QDialog):
75 | """
76 |     Defect-log dialog
77 | """
78 | def __init__(self):
79 | super(DefectLog, self).__init__()
80 | loadUi('ui_defect_log.ui', self)
81 |
82 | self.name_class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4}
83 |         # connect to the database
84 | self.connection = pymysql.connect(host='localhost', user='root', password='123456',
85 | db='detection', charset='utf8')
86 | self.cursor = self.connection.cursor()
87 |         # fetch every record whose detected class is a defect
88 | sql = "SELECT path,time,detect_class FROM detection_log WHERE detect_class != 'normal'"
89 | self.cursor.execute(sql)
90 | self.defect_logs = self.cursor.fetchall()
91 |
92 | self.setWindowFlags(Qt.WindowMinimizeButtonHint | Qt.WindowCloseButtonHint)
93 | self.table_defect.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
94 | # self.table_defect.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
95 | self.table_defect.resizeColumnsToContents()
96 | self.table_statistics.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents)
97 | self.pb_save.clicked.connect(self.slot_save)
98 |
99 |         # populate the defect-record table
100 | defect_rows = len(self.defect_logs)
101 | self.table_defect.setRowCount(defect_rows)
102 | for i in range(defect_rows):
103 | self.table_defect.setItem(i, 0, QTableWidgetItem(str(self.defect_logs[i][0])))
104 | self.table_defect.setItem(i, 1, QTableWidgetItem(str(self.defect_logs[i][1])))
105 | self.table_defect.setItem(i, 2, QTableWidgetItem(self.defect_logs[i][2]))
106 |
107 |         # build the defect-statistics table
108 |         self.class_num_list = [0] * 4
109 |         # count each defect class and fill the corresponding cells
110 |         for j in self.defect_logs:
111 |             defect_class = self.name_class_dic[j[2]]
112 |             if 0 < defect_class < 5:
113 |                 self.class_num_list[defect_class-1] += 1
114 |         # compute each defect's share and fill the corresponding cells
115 |         for k in range(4):
116 |             self.table_statistics.setItem(k, 1, QTableWidgetItem(str(self.class_num_list[k])))
117 |             if sum(self.class_num_list):
118 |                 self.table_statistics.setItem(k, 2, QTableWidgetItem('%.2f' % (self.class_num_list[k]/sum(self.class_num_list))))
119 |
120 | def slot_save(self):
121 | """
122 |         Save the defect log or the statistics table to a CSV file.
123 | :return:
124 | """
125 |         file_name = QFileDialog.getSaveFileName(self, 'save file', 'log', 'CSV (*.csv)')
126 | if file_name[0] == '':
127 | return
128 | if self.tab_table.currentIndex() == 0:
129 |             defect_csv = pd.DataFrame({'图片路径': [x[0] for x in self.defect_logs],
130 |                                        '检测时间': [x[1] for x in self.defect_logs],
131 |                                        '缺陷类别': [x[2] for x in self.defect_logs]})
132 | defect_csv.to_csv(file_name[0])
133 | else:
134 |             statistics_csv = pd.DataFrame({self.table_statistics.horizontalHeaderItem(0).text(): list(range(4)), self.table_statistics.horizontalHeaderItem(1).text(): self.class_num_list, self.table_statistics.horizontalHeaderItem(2).text(): [self.table_statistics.item(i, 2).text() if self.table_statistics.item(i, 2) else '0' for i in range(4)]})
135 | statistics_csv.to_csv(file_name[0], index=False)
136 |
137 |
138 | class DatabaseSet(QDialog):
139 | """
140 |     Database-configuration dialog
141 | """
142 | def __init__(self):
143 | super(DatabaseSet, self).__init__()
144 | loadUi('ui_database_set.ui', self)
145 |
146 | self.pb_db_ok.clicked.connect(self.slot_db_set)
147 | self.pb_db_test.clicked.connect(self.slot_db_test)
148 | self.pb_db_cancel.clicked.connect(self.close)
149 |
150 | def slot_db_set(self):
151 | """
152 |         Apply the database settings.
153 | :return:
154 | """
155 | print('数据库连接成功')
156 |
157 | def slot_db_test(self):
158 | """
159 |         Test the database connection.
160 | :return:
161 | """
162 | ip = self.le_ip.text()
163 | port = int(self.le_port.text())
164 | ac = self.le_ac.text()
165 | pw = self.le_pw.text()
166 | db = self.le_db.text()
167 |
168 | try:
169 | pymysql.connect(host=ip, port=port, user=ac, password=pw, db=db, charset='utf8')
170 | # pymysql.connect(host='localhost', user='root', password='123456', db='detection', charset='utf8')
171 | self.le_test_res.setText('连接测试成功')
172 | print('数据库连接测试成功')
173 |         except Exception:
174 | self.le_test_res.setText('连接测试失败')
175 | print('数据库连接测试失败')
176 |
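Both dialogs repeat the same connect-query-fill sequence. A condensed sketch of that pattern, assuming the detection_log schema used above (fill_table and the credentials are illustrative); the parameterized WHERE clause stands in for the hard-coded filter string:

import pymysql
from PyQt5.QtWidgets import QTableWidgetItem


def fill_table(table_widget, only_defects=False):
    connection = pymysql.connect(host='localhost', user='root', password='123456',
                                 db='detection', charset='utf8')
    try:
        with connection.cursor() as cursor:
            sql = "SELECT path, time, detect_class FROM detection_log"
            if only_defects:
                # parameterized filter instead of string concatenation
                sql += " WHERE detect_class != %s"
                cursor.execute(sql, ('normal',))
            else:
                cursor.execute(sql)
            rows = cursor.fetchall()
    finally:
        connection.close()                      # always release the connection

    table_widget.setRowCount(len(rows))
    for r, record in enumerate(rows):
        for c, value in enumerate(record):
            table_widget.setItem(r, c, QTableWidgetItem(str(value)))
    return rows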
--------------------------------------------------------------------------------
/qt/__pycache__/LogDialog.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/LogDialog.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/__pycache__/MainWindow.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/MainWindow.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/__pycache__/predict.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/predict.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/__pycache__/predict_dl.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/predict_dl.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/__pycache__/predict_ip.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/predict_ip.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/cv.py:
--------------------------------------------------------------------------------
1 | """
2 | Defect-detection Qt app -- image processing helpers
3 | author: 王建坤
4 | date: 2018-10-17
5 | """
6 | import cv2
7 |
8 |
9 | def put_text(img_path, pre_class):
10 | """
11 |     Open an image and draw the predicted class label on it.
12 |     :param img_path: path of the image to annotate
13 |     :param pre_class: predicted class id
14 |     :return:
15 | """
16 | # class_name_dic = {0: '正常', 1: '不导电', 2: '擦花', 3: '角位漏底', 4: '桔皮', 5: '漏底', 6: '起坑', 7: '脏点'}
17 | img = cv2.imread(img_path)
18 | img = cv2.resize(img, (500, 500))
19 | cv2.putText(img, str(pre_class), (img.shape[1]-100, img.shape[0]-20),
20 | cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
21 | cv2.imshow('1', img)
22 | cv2.waitKey()
23 | cv2.destroyAllWindows()
24 |
25 |
26 | if __name__ == '__main__':
27 | put_text('../data/test1.jpg', 1)
28 |
--------------------------------------------------------------------------------
/qt/hello-world-master/1.txt:
--------------------------------------------------------------------------------
1 | version 1.1
2 | link https://www.baidu.com/img/bd_logo1.png
3 |
--------------------------------------------------------------------------------
/qt/hello-world-master/README.md:
--------------------------------------------------------------------------------
1 | hello-world
2 | =======
3 | ok
4 |
--------------------------------------------------------------------------------
/qt/log/1:
--------------------------------------------------------------------------------
1 | ,缺陷类别,数量,占比
2 | 0,0,0,0
3 | 1,1,2,0
4 | 2,2,0,0
5 | 3,3,0,0
6 | 4,4,0,0
7 | 5,5,0,0
8 | 6,6,0,0
9 | 7,7,0,0
10 | 8,8,0,0
11 | 9,9,0,0
12 | 10,10,0,0
13 | 11,11,0,0
14 |
--------------------------------------------------------------------------------
/qt/log/detect_20181016_134111.csv:
--------------------------------------------------------------------------------
1 | ,图片路径,检测结果,检测时间
2 | 0,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:47
3 | 1,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:52
4 | 2,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:52
5 | 3,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:52
6 | 4,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:53
7 | 5,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:53
8 | 6,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:56
9 | 7,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:56
10 | 8,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:56
11 | 9,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57
12 | 10,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57
13 | 11,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57
14 | 12,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57
15 | 13,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57
16 | 14,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57
17 | 15,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58
18 | 16,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58
19 | 17,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58
20 | 18,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58
21 |
--------------------------------------------------------------------------------
/qt/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Defect-detection Qt app -- entry point
3 | author: 王建坤
4 | date: 2018-9-25
5 | """
6 | import sys
7 | sys.path.append("..")
8 |
9 |
10 | from PyQt5.QtWidgets import QApplication
11 | from qt import MainWindow
12 |
13 | # create the application
14 | app = QApplication(sys.argv)
15 | # create the main window
16 | win = MainWindow.Detect()
17 | # show the window
18 | win.show()
19 | # enter the main event loop
20 | sys.exit(app.exec())
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/qt/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/qt/nets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/nets/__pycache__/alexnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/alexnet.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/nets/__pycache__/mobilenet_v1.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/mobilenet_v1.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/nets/__pycache__/vgg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/vgg.cpython-36.pyc
--------------------------------------------------------------------------------
/qt/nets/alexnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a model definition for AlexNet.
16 |
17 | This work was first described in:
18 | ImageNet Classification with Deep Convolutional Neural Networks
19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
20 |
21 | and later refined in:
22 | One weird trick for parallelizing convolutional neural networks
23 | Alex Krizhevsky, 2014
24 |
25 | Here we provide the implementation proposed in "One weird trick" and not
26 | "ImageNet Classification", as per the paper, the LRN layers have been removed.
27 |
28 | Usage:
29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
30 | outputs, end_points = alexnet.alexnet_v2(inputs)
31 |
32 | @@alexnet_v2
33 | """
34 |
35 | from __future__ import absolute_import
36 | from __future__ import division
37 | from __future__ import print_function
38 |
39 | import tensorflow as tf
40 |
41 | import tensorflow.contrib.slim as slim
42 |
43 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
44 |
45 |
46 | def alexnet_v2_arg_scope(weight_decay=0.0005):
47 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
48 | activation_fn=tf.nn.relu,
49 | biases_initializer=tf.constant_initializer(0.1),
50 | weights_regularizer=slim.l2_regularizer(weight_decay)):
51 | with slim.arg_scope([slim.conv2d], padding='SAME'):
52 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
53 | return arg_sc
54 |
55 |
56 | def alexnet_v2(inputs,
57 | num_classes=1000,
58 | is_training=True,
59 | dropout_keep_prob=0.5,
60 | spatial_squeeze=True,
61 | scope='alexnet_v2',
62 | global_pool=False):
63 | """AlexNet version 2.
64 |
65 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf
66 | Parameters from:
67 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
68 | layers-imagenet-1gpu.cfg
69 |
70 | Note: All the fully_connected layers have been transformed to conv2d layers.
71 | To use in classification mode, resize input to 224x224 or set
72 | global_pool=True. To use in fully convolutional mode, set
73 | spatial_squeeze to false.
74 |   The LRN layers have been removed and the initializers changed from
75 |   random_normal_initializer to xavier_initializer.
76 |
77 | Args:
78 | inputs: a tensor of size [batch_size, height, width, channels].
79 | num_classes: the number of predicted classes. If 0 or None, the logits layer
80 | is omitted and the input features to the logits layer are returned instead.
81 | is_training: whether or not the model is being trained.
82 | dropout_keep_prob: the probability that activations are kept in the dropout
83 | layers during training.
84 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the
85 | logits. Useful to remove unnecessary dimensions for classification.
86 | scope: Optional scope for the variables.
87 | global_pool: Optional boolean flag. If True, the input to the classification
88 | layer is avgpooled to size 1x1, for any input size. (This is not part
89 | of the original AlexNet.)
90 |
91 | Returns:
92 | net: the output of the logits layer (if num_classes is a non-zero integer),
93 | or the non-dropped-out input to the logits layer (if num_classes is 0
94 | or None).
95 | end_points: a dict of tensors with intermediate activations.
96 | """
97 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
98 | end_points_collection = sc.original_name_scope + '_end_points'
99 | # Collect outputs for conv2d, fully_connected and max_pool2d.
100 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
101 | outputs_collections=[end_points_collection]):
102 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
103 | scope='conv1')
104 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
105 | net = slim.conv2d(net, 192, [5, 5], scope='conv2')
106 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
107 | net = slim.conv2d(net, 384, [3, 3], scope='conv3')
108 | net = slim.conv2d(net, 384, [3, 3], scope='conv4')
109 | net = slim.conv2d(net, 256, [3, 3], scope='conv5')
110 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')
111 |
112 | # Use conv2d instead of fully_connected layers.
113 | with slim.arg_scope([slim.conv2d],
114 | weights_initializer=trunc_normal(0.005),
115 | biases_initializer=tf.constant_initializer(0.1)):
116 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
117 | scope='fc6')
118 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
119 | scope='dropout6')
120 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
121 |       # Convert end_points_collection into an end_points dict.
122 | end_points = slim.utils.convert_collection_to_dict(
123 | end_points_collection)
124 | if global_pool:
125 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
126 | end_points['global_pool'] = net
127 | if num_classes:
128 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
129 | scope='dropout7')
130 | net = slim.conv2d(net, num_classes, [1, 1],
131 | activation_fn=None,
132 | normalizer_fn=None,
133 | biases_initializer=tf.zeros_initializer(),
134 | scope='fc8')
135 | if spatial_squeeze:
136 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
137 | end_points[sc.name + '/fc8'] = net
138 | return net, end_points
139 |
140 |
141 | alexnet_v2.default_image_size = 224
142 |
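A minimal usage sketch for this slim model, following the docstring above and assuming a TF 1.x environment with tf.contrib.slim (the batch shape and class count are illustrative):

import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import alexnet

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])   # default image size
with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
    logits, end_points = alexnet.alexnet_v2(inputs, num_classes=5,
                                            is_training=False)
probs = tf.nn.softmax(logits)                              # [batch, 5] class probabilities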
--------------------------------------------------------------------------------
/qt/predict_dl.py:
--------------------------------------------------------------------------------
1 | """
2 | Defect-detection Qt app -- image prediction
3 | author: 王建坤
4 | date: 2018-9-25
5 | """
6 | import numpy as np
7 | import tensorflow as tf
8 | from nets import alexnet, mobilenet_v1
9 | from PIL import Image
10 | import time
11 |
12 | CLASSES = 5
13 | IMG_SIZE = 224
14 | GLOBAL_POOL = False
15 |
16 |
17 | def load_model(model='Mobile'):
18 | global x, y, sess
19 |     # input placeholder
20 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1])
21 |
22 |     # checkpoint path and forward pass
23 | if model == 'Alex':
24 | log_path = 'weight/Alex'
25 |         y, _ = alexnet.alexnet_v2(x,
26 |                                   num_classes=CLASSES,  # number of classes
27 |                                   is_training=False,  # inference mode
28 |                                   dropout_keep_prob=1.0,  # keep everything at inference
29 |                                   spatial_squeeze=True,  # squeeze the singleton spatial dims
30 |                                   global_pool=GLOBAL_POOL)  # needed when the input size differs from the default
31 | elif model == 'Mobile':
32 | log_path = 'weight/Mobile'
33 | y, _ = mobilenet_v1.mobilenet_v1(x,
34 | num_classes=CLASSES,
35 | dropout_keep_prob=1.0,
36 | is_training=False,
37 | min_depth=8,
38 | depth_multiplier=1.0,
39 | conv_defs=None,
40 | prediction_fn=None,
41 | spatial_squeeze=True,
42 | reuse=None,
43 | scope='MobilenetV1',
44 | global_pool=GLOBAL_POOL)
45 | else:
46 |         print('Error: model name does not exist')
47 | return
48 | y = tf.nn.softmax(y)
49 |
50 | saver = tf.train.Saver()
51 |
52 | sess = tf.Session()
53 |     # restore the model weights
54 | # print('Reading checkpoints from: ', model)
55 | ckpt = tf.train.get_checkpoint_state(log_path)
56 | if ckpt and ckpt.model_checkpoint_path:
57 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
58 | saver.restore(sess, ckpt.model_checkpoint_path)
59 | # print('Loading success, global_step is %s' % global_step)
60 | else:
61 | print('Error: no checkpoint file found')
62 | return -1
63 |
64 |
65 | def close_sess():
66 | sess.close()
67 | tf.reset_default_graph()
68 | print('Close the session successfully')
69 |
70 |
71 | def predict(m=0, img_path=r'E:\std.jpg'):
72 |     start_time = time.perf_counter()  # time.clock() was removed in Python 3.8
73 | img = Image.open(img_path)
74 | if img.mode != 'L':
75 |         print('Error: the image format is not supported')
76 | return -1, -1
77 | img = img.resize((IMG_SIZE, IMG_SIZE))
78 | img = np.array(img, np.float32)
79 | img = np.expand_dims(img, axis=0)
80 | img = np.expand_dims(img, axis=3)
81 | predictions = sess.run(y, feed_dict={x: img})
82 | pre = np.argmax(predictions)
83 |     end_time = time.perf_counter()
84 | run_time = round((end_time - start_time)*1000, 1)
85 | if m == 0:
86 | print('class: %d running time: %s ms ' % (int(pre), run_time))
87 | # print('prediction is:', '\n', predictions)
88 | return pre, run_time
89 |
90 |
91 | if __name__ == '__main__':
92 | print('run predict:')
93 | # load_model()
94 |
--------------------------------------------------------------------------------
/qt/src/校标.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/src/校标.png
--------------------------------------------------------------------------------
/qt/src/欢迎.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/src/欢迎.jpg
--------------------------------------------------------------------------------
/qt/test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from PyQt5.QtWidgets import QApplication, QMainWindow, QDialog, QPushButton, QWidget
3 | from PyQt5.QtCore import QObject, Qt, pyqtSignal
4 |
5 |
6 | class MySignal(QObject):
7 | instance = None
8 | signal = pyqtSignal()
9 | status_signal = pyqtSignal(str)
10 |
11 | @classmethod
12 | def my_signal(cls):
13 | if cls.instance:
14 | return cls.instance
15 | else:
16 | obj = cls()
17 | cls.instance = obj
18 | return cls.instance
19 |
20 | def em(self):
21 | print(id(self.signal))
22 | self.signal.emit()
23 |
24 | def status_emit(self, s):
25 | self.status_signal.emit(s)
26 |
27 |
28 | class MyPushButton(QPushButton):
29 | def __init__(self, *args):
30 | super(MyPushButton, self).__init__(*args)
31 |
32 | self.setMouseTracking(True)
33 |
34 | def mouseMoveEvent(self, event):
35 | MySignal.my_signal().status_emit('X:'+str(event.pos().x())+' Y:'+str(event.pos().y()))
36 | self.update()
37 |
38 |
39 | class MainWindow(QMainWindow):
40 | """
41 |     Main window
42 | """
43 | Signal = MySignal.my_signal().signal
44 | Status_signal = MySignal.my_signal().status_signal
45 |
46 | print(id(Signal), '1')
47 |
48 | def __init__(self, *args):
49 | super(MainWindow, self).__init__(*args)
50 |
51 |         # set the main window title and size
52 | self.setWindowTitle('主窗口')
53 | self.resize(400, 300)
54 |
55 |         # create the button
56 | self.btn = MyPushButton(self)
57 | self.btn.setText('自定义按钮')
58 | self.btn.move(50, 50)
59 | self.btn.clicked.connect(self.show_dialog)
60 |
61 |         # connect the custom signals
62 | self.Signal.connect(self.test)
63 | self.Status_signal.connect(self.show_status)
64 |
65 | self.dialog = Dialog()
66 |
67 |     def show_dialog(self):
68 |         # exec() already shows the dialog modally; a separate show() is redundant
69 |         self.dialog.exec()
70 |
71 | def test(self):
72 | self.btn.setText('我改变了')
73 |
74 | def show_status(self, s):
75 | self.statusBar().showMessage(s)
76 |
77 | def keyPressEvent(self, event):
78 | if event.key() == Qt.Key_Home:
79 | print('Home')
80 | else:
81 | QWidget.keyPressEvent(self, event)
82 |
83 |
84 | class Dialog(QDialog):
85 | """
86 |     Dialog that emits the shared signal
87 | """
88 | def __init__(self, *args):
89 | super(Dialog, self).__init__(*args)
90 |
91 |         # set the dialog title and size
92 | self.setWindowTitle('对话框')
93 | self.resize(200, 200)
94 | self.setWindowModality(Qt.ApplicationModal)
95 | self.btn = QPushButton(self)
96 | self.btn.setText('改变主窗口按钮的名称')
97 | self.btn.move(50, 50)
98 | self.btn.clicked.connect(MySignal.my_signal().em)
99 | print(id(MySignal.my_signal().signal))
100 |
101 |
102 | if __name__ == '__main__':
103 | app = QApplication(sys.argv)
104 | demo = MainWindow()
105 | demo.show()
106 | sys.exit(app.exec())
107 |
108 |
109 |
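MySignal above is a hand-rolled singleton so that every widget shares the same bound signals and unrelated widgets can talk without holding references to each other. The pattern distilled into a minimal sketch (names are illustrative):

from PyQt5.QtCore import QObject, pyqtSignal


class Hub(QObject):
    _instance = None
    message = pyqtSignal(str)

    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = cls()   # bound signals live on this single instance
        return cls._instance


# any widget can subscribe...
Hub.get().message.connect(print)
# ...and any other code can publish without knowing the subscribers
Hub.get().message.emit('status: ready')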
--------------------------------------------------------------------------------
/qt/test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/test1.jpg
--------------------------------------------------------------------------------
/qt/ui_defect_log.ui:
--------------------------------------------------------------------------------
[Qt Designer XML; the markup was stripped in this export and only text nodes survive. Recoverable content: a "缺陷记录" (defect log) dialog, 520x500, font 微软雅黑 10pt, holding a tab widget with two pages -- "缺陷记录", a read-only 3-column table (图片路径 / 检测时间 / 缺陷类别), and "缺陷统计", a 4-row, 3-column statistics table (缺陷类别 / 数量 / 占比) -- plus a centered "保存" (save) button row.]
--------------------------------------------------------------------------------
/qt/ui_detect_log.ui:
--------------------------------------------------------------------------------
[Qt Designer XML; the markup was stripped in this export and only text nodes survive. Recoverable content: a non-modal "检查记录" (detection log) dialog, 500x500, font 微软雅黑 10pt, with a read-only 3-column table (图片路径 / 检测时间 / 检测结果) and a centered button row with "清空" (clear) and "保存" (save).]
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Scratch tests for local code snippets
3 | @Author: JK_Wang
4 | @Time: 20-May-19
5 | """
6 | # import pymysql
7 | #
8 | # connection = pymysql.connect(host='localhost', user='root', password='123456', db='detection', charset='utf8')
9 | # print(connection)
10 | # cursor = connection.cursor()
11 | # sql = "INSERT INTO detection_log(detect_class, path) VALUES(%s, %s)"
12 | # cursor.execute(sql, ('1', '2'))
13 | # connection.commit()
14 |
15 | # import cv2
16 | #
17 | # img = cv2.imread('D:/test.jpg')
18 | # print(img.shape)
19 | # _, im = cv2.imencode('.jpg', img)
20 | # res = cv2.imdecode(im, -1)
21 | # print(res)
22 | # # cv2.imshow('1', im)
23 | # print(im.shape)
24 | # cv2.imwrite('D:/test.jpg', im)
25 | #
26 | # cv2.waitKey()
27 |
28 | import pymysql
29 | import time
30 | import random as rd
31 | connection = pymysql.connect(host='localhost', user='root', password='123456',
32 | db='detection', charset='utf8')
33 | cursor = connection.cursor()
34 | class_name_dic = {0: 'normal', 1: 'nothing', 2: 'lack_cotton', 3: 'lack_piece', 4: 'wire_fail'}
35 | defect_num = 0
36 | t_sum = [0.0, 0.0, 0.0, 0.0]
37 | for i in range(1, 1000):
38 | print(i)
39 | time.sleep(rd.randint(15, 25)/10)
40 | a = rd.randint(0, 100)
41 | t1 = rd.randint(19, 23)/10
42 |     t2 = rd.randint(25, 29)/10
43 |     t3 = rd.randint(18, 22)/10
44 |     t4 = rd.randint(24, 26)/10
45 | t_sum[0] += t1
46 | t_sum[1] += t2
47 | t_sum[2] += t3
48 | t_sum[3] += t4
49 | avg1 = round(t_sum[0] / i, 1)
50 | avg2 = round(t_sum[1] / i, 1)
51 | avg3 = round(t_sum[2] / i, 1)
52 | avg4 = round(t_sum[3] / i, 1)
53 |
54 | if a < 2:
55 | defect_num += 1
56 | sql_1 = "UPDATE running_state set uph = %s, detection_num = %s, defect_num = %s where id = 1"
57 | sql_2 = "UPDATE chart_data set step_1 = %s, step_2 = %s, step_3 = %s, step_4 = %s where id = 1"
58 | sql_3 = "UPDATE chart_data set step_1 = %s, step_2 = %s, step_3 = %s, step_4 = %s where id = 2"
59 | cursor.execute(sql_1, (str(i), str(1536 + i), str(defect_num)))
60 | cursor.execute(sql_2, (str(t1), str(t2), str(t3), str(t4)))
61 | cursor.execute(sql_3, (str(avg1), str(avg2), str(avg3), str(avg4)))
62 |
63 |     # commit inside the loop so each simulated update is visible immediately
64 |     connection.commit()
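The simulator above never closes its connection. A condensed sketch of the same update loop with guaranteed cleanup (table and column names are taken from the script; the credentials are placeholders and the run is shortened):

import time
import random as rd
import pymysql

connection = pymysql.connect(host='localhost', user='root', password='123456',
                             db='detection', charset='utf8')
try:
    with connection.cursor() as cursor:
        for i in range(1, 11):                  # short demo run
            time.sleep(rd.randint(15, 25) / 10)
            cursor.execute(
                "UPDATE running_state SET uph = %s, detection_num = %s WHERE id = 1",
                (str(i), str(1536 + i)))
            connection.commit()                 # publish each update immediately
finally:
    connection.close()                          # always release the connection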
--------------------------------------------------------------------------------