├── Discuz
│   ├── README.md
│   ├── get_discuz.py
│   ├── test.py
│   └── train.py
├── Pytorch-Seg
│   ├── lesson-1
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── lesson-2
│   │   ├── best_model.pth
│   │   ├── data
│   │   │   ├── test
│   │   │   │   ├── 0.png
│   │   │   │   ├── 1.png
│   │   │   │   ├── 10.png
│   │   │   │   ├── 11.png
│   │   │   │   ├── 12.png
│   │   │   │   ├── 13.png
│   │   │   │   ├── 14.png
│   │   │   │   ├── 15.png
│   │   │   │   ├── 16.png
│   │   │   │   ├── 17.png
│   │   │   │   ├── 18.png
│   │   │   │   ├── 19.png
│   │   │   │   ├── 2.png
│   │   │   │   ├── 20.png
│   │   │   │   ├── 21.png
│   │   │   │   ├── 22.png
│   │   │   │   ├── 23.png
│   │   │   │   ├── 24.png
│   │   │   │   ├── 25.png
│   │   │   │   ├── 26.png
│   │   │   │   ├── 27.png
│   │   │   │   ├── 28.png
│   │   │   │   ├── 29.png
│   │   │   │   ├── 3.png
│   │   │   │   ├── 4.png
│   │   │   │   ├── 5.png
│   │   │   │   ├── 6.png
│   │   │   │   ├── 7.png
│   │   │   │   ├── 8.png
│   │   │   │   └── 9.png
│   │   │   └── train
│   │   │       ├── image
│   │   │       │   ├── 0.png
│   │   │       │   ├── 1.png
│   │   │       │   ├── 10.png
│   │   │       │   ├── 11.png
│   │   │       │   ├── 12.png
│   │   │       │   ├── 13.png
│   │   │       │   ├── 14.png
│   │   │       │   ├── 15.png
│   │   │       │   ├── 16.png
│   │   │       │   ├── 17.png
│   │   │       │   ├── 18.png
│   │   │       │   ├── 19.png
│   │   │       │   ├── 2.png
│   │   │       │   ├── 20.png
│   │   │       │   ├── 21.png
│   │   │       │   ├── 22.png
│   │   │       │   ├── 23.png
│   │   │       │   ├── 24.png
│   │   │       │   ├── 25.png
│   │   │       │   ├── 26.png
│   │   │       │   ├── 27.png
│   │   │       │   ├── 28.png
│   │   │       │   ├── 29.png
│   │   │       │   ├── 3.png
│   │   │       │   ├── 4.png
│   │   │       │   ├── 5.png
│   │   │       │   ├── 6.png
│   │   │       │   ├── 7.png
│   │   │       │   ├── 8.png
│   │   │       │   └── 9.png
│   │   │       └── label
│   │   │           ├── 0.png
│   │   │           ├── 1.png
│   │   │           ├── 10.png
│   │   │           ├── 11.png
│   │   │           ├── 12.png
│   │   │           ├── 13.png
│   │   │           ├── 14.png
│   │   │           ├── 15.png
│   │   │           ├── 16.png
│   │   │           ├── 17.png
│   │   │           ├── 18.png
│   │   │           ├── 19.png
│   │   │           ├── 2.png
│   │   │           ├── 20.png
│   │   │           ├── 21.png
│   │   │           ├── 22.png
│   │   │           ├── 23.png
│   │   │           ├── 24.png
│   │   │           ├── 25.png
│   │   │           ├── 26.png
│   │   │           ├── 27.png
│   │   │           ├── 28.png
│   │   │           ├── 29.png
│   │   │           ├── 3.png
│   │   │           ├── 4.png
│   │   │           ├── 5.png
│   │   │           ├── 6.png
│   │   │           ├── 7.png
│   │   │           ├── 8.png
│   │   │           └── 9.png
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── unet_model.py
│   │   │   └── unet_parts.py
│   │   ├── predict.py
│   │   ├── train.py
│   │   └── utils
│   │       ├── __pycache__
│   │       │   └── dataset.cpython-37.pyc
│   │       └── dataset.py
│   ├── lesson-3
│   │   ├── log.py
│   │   ├── logger.py
│   │   ├── show_loss.py
│   │   ├── tensorboardX_test.py
│   │   └── train_loss.txt
│   └── lesson-4
│       ├── dataset.py
│       ├── dir_label.txt
│       ├── infer.py
│       ├── test.txt
│       ├── train.py
│       ├── train.txt
│       └── val.txt
├── README.md
├── Tutorial
│   ├── README.md
│   ├── lesson-1
│   │   └── perceptron.py
│   ├── lesson-2
│   │   ├── linear_unit.py
│   │   └── perceptron.py
│   ├── lesson-3
│   │   ├── bp.py
│   │   ├── fc.py
│   │   └── mnist.py
│   ├── lesson-4
│   │   ├── activators.py
│   │   └── cnn.py
│   ├── lesson-5
│   │   ├── activators.py
│   │   ├── cnn.py
│   │   └── rnn.py
│   ├── lesson-6
│   │   ├── activators.py
│   │   ├── cnn.py
│   │   └── lstm.py
│   └── lesson-7
│       ├── activators.py
│       └── recursive.py
├── face
│   └── video_mosaic.py
└── mnist.py
/Discuz/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Learning in Action Tutorials
2 |
3 | #### Notes
4 |
5 | * The get_discuz.py endpoint has been shut down. For details on the shutdown and on how to deploy the captcha environment yourself, see: [TensorFlow in Action (2): Discuz Captcha Recognition](https://cuijiahua.com/blog/2018/01/dl_5.html "hover text")
--------------------------------------------------------------------------------
/Discuz/get_discuz.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | from urllib.request import urlretrieve
3 | import time, random, os
4 |
5 | class Discuz():
6 | def __init__(self):
7 |         # URL of the Discuz captcha image generation endpoint
8 | self.url = 'http://cuijiahua.com/tutrial/discuz/index.php?label='
9 |
10 | def random_captcha_text(self, captcha_size = 4):
11 | """
12 |         Captchas are usually case-insensitive; a captcha is 4 characters long
13 |         Parameters:
14 |             captcha_size: captcha length
15 |         Returns:
16 |             captcha_text: captcha string
17 | """
18 | number = ['0','1','2','3','4','5','6','7','8','9']
19 | alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z']
20 | char_set = number + alphabet
21 | captcha_text = []
22 | for i in range(captcha_size):
23 | c = random.choice(char_set)
24 | captcha_text.append(c)
25 | captcha_text = ''.join(captcha_text)
26 | return captcha_text
27 |
28 | def download_discuz(self, nums = 5000):
29 | """
30 |         Download captcha images
31 |         Parameters:
32 |             nums: number of captcha images to download
33 |         """
34 |         dirname = './Discuz'
35 |         if not os.path.exists(dirname):
36 |             os.mkdir(dirname)
37 |         for i in range(nums):
38 |             label = self.random_captcha_text()
39 |             print('Downloading image %d: %s' % (i + 1, label))
40 |             urlretrieve(url = self.url + label, filename = dirname + '/' + label + '.jpg')
41 |             # Please keep at least a 200 ms delay so the server is not overloaded; if this affects normal operation, the feature will be switched off.
42 |             # Be nice to the server and everyone benefits!
43 |             time.sleep(0.2)
44 |         print('All images downloaded!')
45 |
46 | if __name__ == '__main__':
47 | dz = Discuz()
48 | dz.download_discuz()
49 |
--------------------------------------------------------------------------------
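
The remote endpoint used by get_discuz.py above has been shut down (see /Discuz/README.md), so a local source of labelled images is needed before train.py can run. Below is a minimal sketch, assuming the third-party `captcha` package (`pip install captcha`); the rendering style and the 100x30 size are assumptions and will not match the original Discuz generator exactly.

    # Local captcha-generation sketch (assumption: the third-party `captcha` package).
    # It only mimics the labelled file naming that get_discuz.py produces;
    # the visual style differs from the real Discuz captchas.
    import os
    import random
    import string
    from captcha.image import ImageCaptcha

    def generate_local_captchas(nums=100, dirname='./Discuz'):
        os.makedirs(dirname, exist_ok=True)
        char_set = string.digits + string.ascii_lowercase
        generator = ImageCaptcha(width=100, height=30)  # size is an assumption
        for _ in range(nums):
            label = ''.join(random.choice(char_set) for _ in range(4))
            generator.write(label, os.path.join(dirname, label + '.png'))

    if __name__ == '__main__':
        generate_local_captchas(10)
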
/Discuz/test.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | import tensorflow as tf
3 | import numpy as np
4 | import train
5 |
6 | def crack_captcha(captcha_image, captcha_label):
7 | """
8 |     Make predictions with the trained model
9 |     Parameters:
10 |         captcha_image: input images
11 |         captcha_label: labels
12 | """
13 |
14 | output = dz.crack_captcha_cnn()
15 | saver = tf.train.Saver()
16 | with tf.Session(config=dz.config) as sess:
17 |
18 | saver.restore(sess, tf.train.latest_checkpoint('.'))
19 | for i in range(len(captcha_label)):
20 | img = captcha_image[i].flatten()
21 | label = captcha_label[i]
22 | predict = tf.argmax(tf.reshape(output, [-1, dz.max_captcha, dz.char_set_len]), 2)
23 | text_list = sess.run(predict, feed_dict={dz.X: [img], dz.keep_prob: 1})
24 | text = text_list[0].tolist()
25 | vector = np.zeros(dz.max_captcha*dz.char_set_len)
26 | i = 0
27 | for n in text:
28 | vector[i*dz.char_set_len + n] = 1
29 | i += 1
30 | prediction_text = dz.vec2text(vector)
31 |             print("Label: {}  Prediction: {}".format(dz.vec2text(label), prediction_text))
32 |
33 | if __name__ == '__main__':
34 | dz = train.Discuz()
35 | batch_x, batch_y = dz.get_next_batch(False, 5)
36 | crack_captcha(batch_x, batch_y)
37 |
--------------------------------------------------------------------------------
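
test.py decodes each prediction by first rebuilding a one-hot vector and then calling vec2text. The per-position argmax indices returned by the `predict` op can also be decoded directly; a small standalone sketch of that shortcut (hypothetical helper, using the same index ordering as train.Discuz.vec2text: 0-9, A-Z, a-z, then '_'):

    def indices2text(indices):
        """Decode per-position class indices (0-62) straight to text."""
        chars = []
        for idx in indices:
            if idx < 10:
                chars.append(chr(idx + ord('0')))       # digits
            elif idx < 36:
                chars.append(chr(idx - 10 + ord('A')))  # upper case
            elif idx < 62:
                chars.append(chr(idx - 36 + ord('a')))  # lower case
            else:
                chars.append('_')                       # padding character
        return ''.join(chars)

    # Example: indices2text([36, 37, 1, 2]) == 'ab12'
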
/Discuz/train.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | import tensorflow as tf
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import os, random, cv2
6 |
7 | class Discuz():
8 | def __init__(self):
9 |         # Select which GPU to use
10 |         os.environ["CUDA_VISIBLE_DEVICES"] = "0"
11 |         self.config = tf.ConfigProto(allow_soft_placement = True)
12 |         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = 1)
13 |         self.config.gpu_options.allow_growth = True
14 |         # Dataset path
15 |         self.data_path = './Discuz/'
16 |         # Directory the TensorBoard logs are written to
17 |         self.log_dir = '/home/Jack_Cui/Work/Crack_Discuz/Tb'
18 |         # Image size of the dataset
19 |         self.width = 30
20 |         self.heigth = 100
21 |         # Maximum number of training steps
22 |         self.max_steps = 1000000
23 |         # Load the dataset
24 |         self.test_imgs, self.test_labels, self.train_imgs, self.train_labels = self.get_imgs()
25 |         # Training set size
26 |         self.train_size = len(self.train_imgs)
27 |         # Test set size
28 |         self.test_size = len(self.test_imgs)
29 |         # Pointer to the current position in the training set for the next batch_size samples
30 |         self.train_ptr = 0
31 |         # Pointer to the current position in the test set for the next batch_size samples
32 |         self.test_ptr = 0
33 |         # Character set size: 0-9, a-z, A-Z and _ (captchas shorter than 4 are padded with _), 63 characters in total
34 |         self.char_set_len = 63
35 |         # Maximum captcha length is 4
36 |         self.max_captcha = 4
37 |         # Placeholder for the input data X
38 |         self.X = tf.placeholder(tf.float32, [None, self.heigth*self.width])
39 |         # Placeholder for the labels Y
40 |         self.Y = tf.placeholder(tf.float32, [None, self.char_set_len*self.max_captcha])
41 |         # Placeholder for the dropout keep probability
42 | self.keep_prob = tf.placeholder(tf.float32)
43 |
44 |     def test_show_img(self, fname, show = True):
45 |         """
46 |         Read an image, print its size and display its grayscale version
47 |         Parameters:
48 |             fname: image file name
49 |             show: whether to display the image
50 |         """
51 |         # Get the label
52 |         label = fname.split('.')
53 |         # Read the image
54 |         img = cv2.imread(fname)
55 |         # Get the image size
56 |         width, heigth, _ = img.shape
57 |         print("Image width: %s px" % width)
58 |         print("Image height: %s px" % heigth)
59 |
60 |         if show == True:
61 |             # plt.imshow(img)
62 |             # Split the figure canvas into subplots that do not share x or y axes; figsize sets the canvas size
63 |             # (with nrows=3, ncols=2 the canvas would have six panels and axs[0][0] would be row 1, column 1)
64 |             fig, axs = plt.subplots(nrows=2, ncols=1, sharex=False, sharey=False, figsize=(10,5))
65 |             axs[0].imshow(img)
66 |             axs0_title_text = axs[0].set_title(u'RGB img')
67 |             plt.setp(axs0_title_text, size=10)
68 |             # Convert to grayscale
69 |             gray = np.mean(img, axis=-1)
70 |             axs[1].imshow(gray, cmap='Greys_r')
71 |             axs1_title_text = axs[1].set_title(u'GRAY img')
72 |             plt.setp(axs1_title_text, size=10)
73 |             plt.show()
74 |
75 | def get_imgs(self, rate = 0.2):
76 | """
77 |         Collect the images and split them into a training set and a test set
78 |         Parameters:
79 |             rate: ratio of test set size to training set size, i.e. test count / train count
80 |         Returns:
81 |             test_imgs: test set
82 |             test_labels: test set labels
83 |             train_imgs: training set
84 |             train_labels: training set labels
85 |         """
86 |         # Read the image file names
87 |         imgs = os.listdir(self.data_path)
88 |         # Shuffle the images
89 |         random.shuffle(imgs)
90 |
91 |         # Total number of images in the dataset
92 |         imgs_num = len(imgs)
93 |         # Number of test images according to the ratio
94 |         test_num = int(imgs_num * rate / (1 + rate))
95 |         # Test set
96 |         test_imgs = imgs[:test_num]
97 |         # Derive the test labels from the file names
98 |         test_labels = list(map(lambda x: x.split('.')[0], test_imgs))
99 |         # Training set
100 |         train_imgs = imgs[test_num:]
101 |         # Derive the training labels from the file names
102 |         train_labels = list(map(lambda x: x.split('.')[0], train_imgs))
103 |
104 | return test_imgs, test_labels, train_imgs, train_labels
105 |
106 | def get_next_batch(self, train_flag=True, batch_size=100):
107 | """
108 |         Fetch a batch of batch_size samples
109 |         Parameters:
110 |             batch_size: batch size
111 |             train_flag: whether to fetch the data from the training set
112 |         Returns:
113 |             batch_x: batch_size input samples x
114 |             batch_y: batch_size labels y
115 |         """
116 |         # Fetch data from the training set
117 | if train_flag == True:
118 | if (batch_size + self.train_ptr) < self.train_size:
119 | trains = self.train_imgs[self.train_ptr:(self.train_ptr + batch_size)]
120 | labels = self.train_labels[self.train_ptr:(self.train_ptr + batch_size)]
121 | self.train_ptr += batch_size
122 | else:
123 | new_ptr = (self.train_ptr + batch_size) % self.train_size
124 | trains = self.train_imgs[self.train_ptr:] + self.train_imgs[:new_ptr]
125 | labels = self.train_labels[self.train_ptr:] + self.train_labels[:new_ptr]
126 | self.train_ptr = new_ptr
127 |
128 | batch_x = np.zeros([batch_size, self.heigth*self.width])
129 | batch_y = np.zeros([batch_size, self.max_captcha*self.char_set_len])
130 |
131 | for index, train in enumerate(trains):
132 | img = np.mean(cv2.imread(self.data_path + train), -1)
133 |                 # Flatten the multi-dimensional image into a 1-D vector
134 | batch_x[index,:] = img.flatten() / 255
135 | for index, label in enumerate(labels):
136 | batch_y[index,:] = self.text2vec(label)
137 |
138 |         # Fetch data from the test set
139 | else:
140 | if (batch_size + self.test_ptr) < self.test_size:
141 | tests = self.test_imgs[self.test_ptr:(self.test_ptr + batch_size)]
142 | labels = self.test_labels[self.test_ptr:(self.test_ptr + batch_size)]
143 | self.test_ptr += batch_size
144 | else:
145 | new_ptr = (self.test_ptr + batch_size) % self.test_size
146 | tests = self.test_imgs[self.test_ptr:] + self.test_imgs[:new_ptr]
147 | labels = self.test_labels[self.test_ptr:] + self.test_labels[:new_ptr]
148 | self.test_ptr = new_ptr
149 |
150 | batch_x = np.zeros([batch_size, self.heigth*self.width])
151 | batch_y = np.zeros([batch_size, self.max_captcha*self.char_set_len])
152 |
153 | for index, test in enumerate(tests):
154 | img = np.mean(cv2.imread(self.data_path + test), -1)
155 |                 # Flatten the multi-dimensional image into a 1-D vector
156 | batch_x[index,:] = img.flatten() / 255
157 | for index, label in enumerate(labels):
158 | batch_y[index,:] = self.text2vec(label)
159 |
160 | return batch_x, batch_y
161 |
162 | def text2vec(self, text):
163 | """
164 |         Convert text to a one-hot vector
165 |         Parameters:
166 |             text: text
167 |         Returns:
168 |             vector: one-hot vector
169 | """
170 | if len(text) > 4:
171 |             raise ValueError('A captcha has at most 4 characters')
172 |
173 | vector = np.zeros(4 * self.char_set_len)
174 | def char2pos(c):
175 | if c =='_':
176 | k = 62
177 | return k
178 | k = ord(c) - 48
179 | if k > 9:
180 | k = ord(c) - 55
181 | if k > 35:
182 | k = ord(c) - 61
183 | if k > 61:
184 | raise ValueError('No Map')
185 | return k
186 | for i, c in enumerate(text):
187 | idx = i * self.char_set_len + char2pos(c)
188 | vector[idx] = 1
189 | return vector
190 |
191 | def vec2text(self, vec):
192 | """
193 |         Convert a one-hot vector back to text
194 |         Parameters:
195 |             vec: vector
196 |         Returns:
197 |             the decoded text
198 | """
199 | char_pos = vec.nonzero()[0]
200 | text = []
201 | for i, c in enumerate(char_pos):
202 | char_at_pos = i #c/63
203 | char_idx = c % self.char_set_len
204 | if char_idx < 10:
205 | char_code = char_idx + ord('0')
206 | elif char_idx < 36:
207 | char_code = char_idx - 10 + ord('A')
208 | elif char_idx < 62:
209 | char_code = char_idx - 36 + ord('a')
210 | elif char_idx == 62:
211 | char_code = ord('_')
212 | else:
213 | raise ValueError('error')
214 | text.append(chr(char_code))
215 | return "".join(text)
216 |
217 | def crack_captcha_cnn(self, w_alpha=0.01, b_alpha=0.1):
218 | """
219 |         Define the CNN
220 |         Parameters:
221 |             w_alpha: weight scale factor
222 |             b_alpha: bias scale factor
223 |         Returns:
224 |             out: CNN output
225 | """
226 |         # Convolution input: a 4-D Tensor of shape [batch, in_height, in_width, in_channels]
227 |         # i.e. [batch size, image height, image width, number of image channels]
228 |         # The images are grayscale, so there is a single channel: [?, 100, 30, 1]
229 |         x = tf.reshape(self.X, shape=[-1, self.heigth, self.width, 1])
230 |         # Convolution filter: a 4-D Tensor of shape [filter_height, filter_width, in_channels, out_channels]
231 |         # i.e. [kernel height, kernel width, number of image channels, number of kernels]
232 |         w_c1 = tf.Variable(w_alpha*tf.random_normal([3, 3, 1, 32]))
233 |         # Bias term
234 |         b_c1 = tf.Variable(b_alpha*tf.random_normal([32]))
235 |         # conv2d inputs:
236 |         # strides: a 1-D array of length 4, the sliding-window stride for each dimension of the input
237 |         # padding: a string, either SAME or VALID; SAME keeps the spatial size, VALID shrinks it
238 |         # conv2d output:
239 |         # a 4-D Tensor of shape [batch, out_height, out_width, out_channels]
240 |         # [?, 100, 30, 32]
241 |         # Output size formula: H_out = (H - F + 2 * P) / S + 1
242 |         # For this layer the padding is SAME, so P is 1.
243 |         # H is the image height, F the kernel height, P the padding and S the stride
244 |         # Learnable parameters:
245 |         # 32*(3*3+1)=320
246 |         # Number of connections:
247 |         # (output width * output height) * (kernel height * kernel width + 1) * number of kernels = (100*30)*(3*3+1)*32 = 960000
248 |
249 |         # bias_add: adds bias to value. It can be seen as a special case of tf.add where bias must be 1-D.
250 |         # Broadcasting is supported, so value may have any rank, but unlike tf.add the bias has to match value's last dimension.
251 |         conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
252 |         # max_pool inputs:
253 |         # ksize: size of the pooling window, a 4-D vector, usually [1, height, width, 1]
254 |         # (we do not pool over the batch or channel dimensions, so those entries are 1)
255 |         # strides: as for convolution, the window stride in each dimension, usually [1, stride, stride, 1]
256 |         # padding: as for convolution, either 'VALID' or 'SAME'
257 |         # max_pool output:
258 |         # a Tensor of the same type, with shape [batch, out_height, out_width, channels]
259 |         # [?, 50, 15, 32]
260 |         # Learnable parameters:
261 |         # 2*32
262 |         # Number of connections:
263 |         # 15*50*32*(2*2+1)=120000
264 |         conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
265 |         # dropout layer
266 | # conv1 = tf.nn.dropout(conv1, self.keep_prob)
267 | w_c2 = tf.Variable(w_alpha*tf.random_normal([3, 3, 32, 64]))
268 | b_c2 = tf.Variable(b_alpha*tf.random_normal([64]))
269 | # [?, 50, 15, 64]
270 | conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
271 | # [?, 25, 8, 64]
272 | conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
273 | #conv2 = tf.nn.dropout(conv2, self.keep_prob)
274 | w_c3 = tf.Variable(w_alpha*tf.random_normal([3, 3, 64, 64]))
275 | b_c3 = tf.Variable(b_alpha*tf.random_normal([64]))
276 | # [?, 25, 8, 64]
277 | conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
278 | # [?, 13, 4, 64]
279 | conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
280 | #conv3 = tf.nn.dropout(conv3, self.keep_prob)
281 | # [3328, 1024]
282 | w_d = tf.Variable(w_alpha*tf.random_normal([4*13*64, 1024]))
283 | b_d = tf.Variable(b_alpha*tf.random_normal([1024]))
284 | # [?, 3328]
285 | dense = tf.reshape(conv3, [-1, w_d.get_shape().as_list()[0]])
286 | # [?, 1024]
287 | dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))
288 | dense = tf.nn.dropout(dense, self.keep_prob)
289 | # [1024, 63*4=252]
290 | w_out = tf.Variable(w_alpha*tf.random_normal([1024, self.max_captcha*self.char_set_len]))
291 |
292 | b_out = tf.Variable(b_alpha*tf.random_normal([self.max_captcha*self.char_set_len]))
293 | # [?, 252]
294 | out = tf.add(tf.matmul(dense, w_out), b_out)
295 | # out = tf.nn.softmax(out)
296 | return out
297 |
298 | def train_crack_captcha_cnn(self):
299 | """
300 |         Training routine
301 | """
302 | output = self.crack_captcha_cnn()
303 |
304 |         # Define the loss function
305 | # loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=self.Y))
306 | diff = tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=self.Y)
307 | loss = tf.reduce_mean(diff)
308 | tf.summary.scalar('loss', loss)
309 |
310 |         # Train with the Adam optimizer, minimizing the cross-entropy loss
311 | optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
312 |
313 |         # Compute the accuracy
314 | y = tf.reshape(output, [-1, self.max_captcha, self.char_set_len])
315 | y_ = tf.reshape(self.Y, [-1, self.max_captcha, self.char_set_len])
316 | correct_pred = tf.equal(tf.argmax(y, 2), tf.argmax(y_, 2))
317 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
318 | tf.summary.scalar('accuracy', accuracy)
319 |
320 | merged = tf.summary.merge_all()
321 | saver = tf.train.Saver()
322 | with tf.Session(config=self.config) as sess:
323 |             # Write the summaries to the specified log directory
324 | train_writer = tf.summary.FileWriter(self.log_dir + '/train', sess.graph)
325 | test_writer = tf.summary.FileWriter(self.log_dir + '/test')
326 | sess.run(tf.global_variables_initializer())
327 |
328 |             # Iterate for self.max_steps steps
329 | for i in range(self.max_steps):
330 |                 # Reshuffle the dataset roughly every 500 steps
331 | if i % 499 == 0:
332 | self.test_imgs, self.test_labels, self.train_imgs, self.train_labels = self.get_imgs()
333 |                 # Every 10 steps, evaluate the accuracy on the test set
334 | if i % 10 == 0:
335 | batch_x_test, batch_y_test = self.get_next_batch(False, 100)
336 | summary, acc = sess.run([merged, accuracy], feed_dict={self.X: batch_x_test, self.Y: batch_y_test, self.keep_prob: 1})
337 |                     print('Step %d, accuracy: %f' % (i+1, acc))
338 | test_writer.add_summary(summary, i)
339 |
340 |                     # If accuracy exceeds 95%, save the model and stop.
341 | if acc > 0.95:
342 | train_writer.close()
343 | test_writer.close()
344 | saver.save(sess, "crack_capcha.model", global_step=i)
345 | break
346 |                 # Otherwise, keep training
347 | else:
348 | batch_x, batch_y = self.get_next_batch(True, 100)
349 | loss_value, _ = sess.run([loss, optimizer], feed_dict={self.X: batch_x, self.Y: batch_y, self.keep_prob: 1})
350 |                     print('Step %d, loss: %f' % (i+1, loss_value))
351 | curve = sess.run(merged, feed_dict={self.X: batch_x_test, self.Y: batch_y_test, self.keep_prob: 1})
352 | train_writer.add_summary(curve, i)
353 |
354 | train_writer.close()
355 | test_writer.close()
356 | saver.save(sess, "crack_capcha.model", global_step=self.max_steps)
357 |
358 |
359 | if __name__ == '__main__':
360 | dz = Discuz()
361 | dz.train_crack_captcha_cnn()
362 |
--------------------------------------------------------------------------------
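
As a quick check of the label encoding defined by text2vec above: each of the four character positions owns a 63-wide block of the 252-dimensional label vector, and exactly one entry per block is set to 1. A standalone sketch of the same mapping (no Discuz instance needed), which for 'ab12' lights up indices 36, 100, 127 and 191:

    import numpy as np

    CHAR_SET_LEN = 63  # 0-9, A-Z, a-z and '_', same as Discuz.char_set_len

    def char2pos(c):
        # Same ordering as Discuz.text2vec: digits, upper case, lower case, '_'
        if c == '_':
            return 62
        k = ord(c) - 48            # '0'-'9' -> 0-9
        if k > 9:
            k = ord(c) - 55        # 'A'-'Z' -> 10-35
            if k > 35:
                k = ord(c) - 61    # 'a'-'z' -> 36-61
        return k

    text = 'ab12'
    vector = np.zeros(4 * CHAR_SET_LEN)
    for i, c in enumerate(text):
        vector[i * CHAR_SET_LEN + char2pos(c)] = 1
    print(vector.nonzero()[0])  # [ 36 100 127 191]
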
/Pytorch-Seg/lesson-1/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 | """Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
3 |
4 | import torch.nn.functional as F
5 |
6 | from unet_parts import *
7 |
8 |
9 | class UNet(nn.Module):
10 | def __init__(self, n_channels, n_classes, bilinear=False):
11 | super(UNet, self).__init__()
12 | self.n_channels = n_channels
13 | self.n_classes = n_classes
14 | self.bilinear = bilinear
15 |
16 | self.inc = DoubleConv(n_channels, 64)
17 | self.down1 = Down(64, 128)
18 | self.down2 = Down(128, 256)
19 | self.down3 = Down(256, 512)
20 | self.down4 = Down(512, 1024)
21 | self.up1 = Up(1024, 512, bilinear)
22 | self.up2 = Up(512, 256, bilinear)
23 | self.up3 = Up(256, 128, bilinear)
24 | self.up4 = Up(128, 64, bilinear)
25 | self.outc = OutConv(64, n_classes)
26 |
27 | def forward(self, x):
28 | x1 = self.inc(x)
29 | x2 = self.down1(x1)
30 | x3 = self.down2(x2)
31 | x4 = self.down3(x3)
32 | x5 = self.down4(x4)
33 | x = self.up1(x5, x4)
34 | x = self.up2(x, x3)
35 | x = self.up3(x, x2)
36 | x = self.up4(x, x1)
37 | logits = self.outc(x)
38 | return logits
39 |
40 | if __name__ == '__main__':
41 | net = UNet(n_channels=3, n_classes=1)
42 | print(net)
43 |
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-1/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 | """https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class DoubleConv(nn.Module):
10 | """(convolution => [BN] => ReLU) * 2"""
11 |
12 | def __init__(self, in_channels, out_channels):
13 | super().__init__()
14 | self.double_conv = nn.Sequential(
15 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
16 | nn.BatchNorm2d(out_channels),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
19 | nn.BatchNorm2d(out_channels),
20 | nn.ReLU(inplace=True)
21 | )
22 |
23 | def forward(self, x):
24 | return self.double_conv(x)
25 |
26 |
27 | class Down(nn.Module):
28 | """Downscaling with maxpool then double conv"""
29 |
30 | def __init__(self, in_channels, out_channels):
31 | super().__init__()
32 | self.maxpool_conv = nn.Sequential(
33 | nn.MaxPool2d(2),
34 | DoubleConv(in_channels, out_channels)
35 | )
36 |
37 | def forward(self, x):
38 | return self.maxpool_conv(x)
39 |
40 |
41 | class Up(nn.Module):
42 | """Upscaling then double conv"""
43 |
44 | def __init__(self, in_channels, out_channels, bilinear=True):
45 | super().__init__()
46 |
47 | # if bilinear, use the normal convolutions to reduce the number of channels
48 | if bilinear:
49 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
50 | else:
51 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
52 |
53 | self.conv = DoubleConv(in_channels, out_channels)
54 |
55 | def forward(self, x1, x2):
56 | x1 = self.up(x1)
57 | # input is CHW
58 | diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
59 | diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
60 |
61 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
62 | diffY // 2, diffY - diffY // 2])
63 | # if you have padding issues, see
64 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
65 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
66 | x = torch.cat([x2, x1], dim=1)
67 | return self.conv(x)
68 |
69 |
70 | class OutConv(nn.Module):
71 | def __init__(self, in_channels, out_channels):
72 | super(OutConv, self).__init__()
73 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
74 |
75 | def forward(self, x):
76 | return self.conv(x)
77 |
--------------------------------------------------------------------------------
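
A quick sanity check of the lesson-1 assembly (a sketch, not part of the lesson): instantiating the network from the two files above is enough to confirm the Down/Up channel bookkeeping, and counting parameters gives a feel for the model size. Note that this lesson's DoubleConv uses unpadded 3x3 convolutions (padding=0), so feature maps shrink at every block and Up re-pads them to match the skip connection before concatenation.

    # Run from Pytorch-Seg/lesson-1 so that `unet_model` is importable.
    from unet_model import UNet

    if __name__ == '__main__':
        net = UNet(n_channels=3, n_classes=1)
        total = sum(p.numel() for p in net.parameters() if p.requires_grad)
        print('UNet trainable parameters: %.1fM' % (total / 1e6))
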
/Pytorch-Seg/lesson-2/best_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/best_model.pth
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/0.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/1.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/10.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/11.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/12.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/13.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/14.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/15.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/16.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/17.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/18.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/19.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/2.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/20.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/21.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/22.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/23.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/24.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/25.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/26.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/27.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/28.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/29.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/3.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/4.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/5.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/6.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/7.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/8.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/test/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/test/9.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/0.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/1.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/10.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/11.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/12.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/13.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/14.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/15.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/16.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/17.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/18.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/19.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/2.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/20.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/21.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/22.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/23.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/24.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/25.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/26.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/27.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/28.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/29.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/3.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/4.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/5.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/6.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/7.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/8.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/image/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/image/9.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/0.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/1.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/10.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/11.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/12.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/13.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/14.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/15.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/16.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/17.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/18.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/19.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/2.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/20.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/21.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/22.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/23.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/24.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/25.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/26.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/27.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/28.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/29.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/3.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/4.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/5.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/6.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/7.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/8.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/data/train/label/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/data/train/label/9.png
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/model/__init__.py
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/model/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 | """Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
3 |
4 | import torch.nn.functional as F
5 |
6 | from .unet_parts import *
7 |
8 |
9 | class UNet(nn.Module):
10 | def __init__(self, n_channels, n_classes, bilinear=True):
11 | super(UNet, self).__init__()
12 | self.n_channels = n_channels
13 | self.n_classes = n_classes
14 | self.bilinear = bilinear
15 |
16 | self.inc = DoubleConv(n_channels, 64)
17 | self.down1 = Down(64, 128)
18 | self.down2 = Down(128, 256)
19 | self.down3 = Down(256, 512)
20 | self.down4 = Down(512, 512)
21 | self.up1 = Up(1024, 256, bilinear)
22 | self.up2 = Up(512, 128, bilinear)
23 | self.up3 = Up(256, 64, bilinear)
24 | self.up4 = Up(128, 64, bilinear)
25 | self.outc = OutConv(64, n_classes)
26 |
27 | def forward(self, x):
28 | x1 = self.inc(x)
29 | x2 = self.down1(x1)
30 | x3 = self.down2(x2)
31 | x4 = self.down3(x3)
32 | x5 = self.down4(x4)
33 | x = self.up1(x5, x4)
34 | x = self.up2(x, x3)
35 | x = self.up3(x, x2)
36 | x = self.up4(x, x1)
37 | logits = self.outc(x)
38 | return logits
39 |
40 | if __name__ == '__main__':
41 | net = UNet(n_channels=3, n_classes=1)
42 | print(net)
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/model/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 | """https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class DoubleConv(nn.Module):
10 | """(convolution => [BN] => ReLU) * 2"""
11 |
12 | def __init__(self, in_channels, out_channels):
13 | super().__init__()
14 | self.double_conv = nn.Sequential(
15 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
16 | nn.BatchNorm2d(out_channels),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
19 | nn.BatchNorm2d(out_channels),
20 | nn.ReLU(inplace=True)
21 | )
22 |
23 | def forward(self, x):
24 | return self.double_conv(x)
25 |
26 |
27 | class Down(nn.Module):
28 | """Downscaling with maxpool then double conv"""
29 |
30 | def __init__(self, in_channels, out_channels):
31 | super().__init__()
32 | self.maxpool_conv = nn.Sequential(
33 | nn.MaxPool2d(2),
34 | DoubleConv(in_channels, out_channels)
35 | )
36 |
37 | def forward(self, x):
38 | return self.maxpool_conv(x)
39 |
40 |
41 | class Up(nn.Module):
42 | """Upscaling then double conv"""
43 |
44 | def __init__(self, in_channels, out_channels, bilinear=True):
45 | super().__init__()
46 |
47 | # if bilinear, use the normal convolutions to reduce the number of channels
48 | if bilinear:
49 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
50 | else:
51 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
52 |
53 | self.conv = DoubleConv(in_channels, out_channels)
54 |
55 | def forward(self, x1, x2):
56 | x1 = self.up(x1)
57 | # input is CHW
58 |         diffY = x2.size()[2] - x1.size()[2]
59 |         diffX = x2.size()[3] - x1.size()[3]
60 |
61 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
62 | diffY // 2, diffY - diffY // 2])
63 |
64 | x = torch.cat([x2, x1], dim=1)
65 | return self.conv(x)
66 |
67 |
68 | class OutConv(nn.Module):
69 | def __init__(self, in_channels, out_channels):
70 | super(OutConv, self).__init__()
71 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
72 |
73 | def forward(self, x):
74 | return self.conv(x)
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/predict.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import torch
4 | import os
5 | import cv2
6 | from model.unet_model import UNet
7 |
8 | if __name__ == "__main__":
9 |     # Select the device: use CUDA if it is available, otherwise use the CPU
10 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |     # Build the network: single-channel images, one output class
12 |     net = UNet(n_channels=1, n_classes=1)
13 |     # Move the network to the device
14 |     net.to(device=device)
15 |     # Load the trained model parameters
16 |     net.load_state_dict(torch.load('best_model.pth', map_location=device))
17 |     # Switch to evaluation mode
18 |     net.eval()
19 |     # Collect the paths of all test images
20 |     tests_path = glob.glob('data/test/*.png')
21 |     # Iterate over all test images
22 |     for test_path in tests_path:
23 |         # Path where the result will be saved
24 |         save_res_path = test_path.split('.')[0] + '_res.png'
25 |         # Read the image
26 |         img = cv2.imread(test_path)
27 |         # Convert to grayscale (cv2.imread returns BGR, so use COLOR_BGR2GRAY)
28 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
29 |         # Reshape to an array with batch size 1, 1 channel and size 512*512
30 |         img = img.reshape(1, 1, img.shape[0], img.shape[1])
31 |         # Convert to a tensor
32 |         img_tensor = torch.from_numpy(img)
33 |         # Move the tensor to the same device as the network (CPU or CUDA)
34 |         img_tensor = img_tensor.to(device=device, dtype=torch.float32)
35 |         # Run the prediction
36 |         pred = net(img_tensor)
37 |         # Extract the result
38 |         pred = np.array(pred.data.cpu()[0])[0]
39 |         # Binarize the result
40 |         pred[pred >= 0.5] = 255
41 |         pred[pred < 0.5] = 0
42 |         # Save the result image
43 |         cv2.imwrite(save_res_path, pred)
--------------------------------------------------------------------------------
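Because the network is trained with BCEWithLogitsLoss (see train.py, shown next), net(img_tensor) returns raw logits rather than probabilities, so the 0.5 cut-off above acts on logits, which corresponds to a probability threshold of roughly 0.62. A hedged alternative, shown only as a sketch and not as the repository's method, is to apply a sigmoid first so that the threshold is interpreted as a probability; the helper name below is hypothetical.

import torch

def binarize(pred_logits, threshold=0.5):
    # Hypothetical helper: map logits to probabilities, then threshold to a 0/255 mask
    probs = torch.sigmoid(pred_logits)
    mask = (probs >= threshold).float() * 255
    return mask.cpu().numpy()[0][0]

With such a helper, the extraction and thresholding lines in predict.py would reduce to pred = binarize(pred) followed by cv2.imwrite(save_res_path, pred).
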
/Pytorch-Seg/lesson-2/train.py:
--------------------------------------------------------------------------------
1 | from model.unet_model import UNet
2 | from utils.dataset import ISBI_Loader
3 | from torch import optim
4 | import torch.nn as nn
5 | import torch
6 |
7 | def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
8 |     # Load the training set
9 |     isbi_dataset = ISBI_Loader(data_path)
10 |     train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
11 |                                                batch_size=batch_size,
12 |                                                shuffle=True)
13 |     # Define the RMSprop optimizer
14 |     optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
15 |     # Define the loss function
16 |     criterion = nn.BCEWithLogitsLoss()
17 |     # Track the best loss, initialized to positive infinity
18 |     best_loss = float('inf')
19 |     # Train for the given number of epochs
20 |     for epoch in range(epochs):
21 |         # Training mode
22 |         net.train()
23 |         # Train batch by batch
24 |         for image, label in train_loader:
25 |             optimizer.zero_grad()
26 |             # Move the data to the device
27 |             image = image.to(device=device, dtype=torch.float32)
28 |             label = label.to(device=device, dtype=torch.float32)
29 |             # Forward pass through the network
30 |             pred = net(image)
31 |             # Compute the loss
32 |             loss = criterion(pred, label)
33 |             print('Loss/train', loss.item())
34 |             # Save the parameters of the network with the smallest loss so far
35 |             if loss < best_loss:
36 |                 best_loss = loss
37 |                 torch.save(net.state_dict(), 'best_model.pth')
38 |             # Update the parameters
39 |             loss.backward()
40 |             optimizer.step()
41 |
42 | if __name__ == "__main__":
43 |     # Select the device: use CUDA if it is available, otherwise use the CPU
44 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45 |     # Build the network: single-channel images, one output class
46 |     net = UNet(n_channels=1, n_classes=1)
47 |     # Move the network to the device
48 |     net.to(device=device)
49 |     # Set the training data path and start training
50 |     data_path = "data/train/"
51 |     train_net(net, device, data_path)
--------------------------------------------------------------------------------
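The loss curve plotted by lesson-3's show_loss.py is read from train_loss.txt, one value per line, while the loop above only prints the per-batch loss. A small sketch of how a per-epoch average could be appended in that format (an illustration only, not code from the repository; the helper name is made up):

def log_epoch_loss(epoch_loss, path='train_loss.txt'):
    # Append one loss value per line, the format show_loss.py expects
    with open(path, 'a') as f:
        f.write('{:.6f}\n'.format(epoch_loss))

# Example: average the batch losses collected during one epoch, then log the average
batch_losses = [0.693, 0.512, 0.430]
log_epoch_loss(sum(batch_losses) / len(batch_losses))
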
/Pytorch-Seg/lesson-2/utils/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/5fd254b61ad45367fbae28c49976e82b14ff7110/Pytorch-Seg/lesson-2/utils/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-2/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import os
4 | import glob
5 | from torch.utils.data import Dataset
6 | import random
7 |
8 | class ISBI_Loader(Dataset):
9 | def __init__(self, data_path):
10 |         # Initialization: collect all image paths under data_path
11 | self.data_path = data_path
12 | self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))
13 |
14 | def augment(self, image, flipCode):
15 |         # Augment with cv2.flip: flipCode 1 = horizontal, 0 = vertical, -1 = horizontal + vertical
16 | flip = cv2.flip(image, flipCode)
17 | return flip
18 |
19 | def __getitem__(self, index):
20 |         # Read the image at the given index
21 |         image_path = self.imgs_path[index]
22 |         # Derive label_path from image_path
23 |         label_path = image_path.replace('image', 'label')
24 |         # Read the training image and the label image
25 | image = cv2.imread(image_path)
26 | label = cv2.imread(label_path)
27 |         # Convert both to single-channel (grayscale) images
28 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
29 | label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
30 | image = image.reshape(1, image.shape[0], image.shape[1])
31 | label = label.reshape(1, label.shape[0], label.shape[1])
32 |         # Process the label: turn pixel values of 255 into 1
33 | if label.max() > 1:
34 | label = label / 255
35 |         # Randomly apply augmentation; a value of 2 means no augmentation
36 | flipCode = random.choice([-1, 0, 1, 2])
37 | if flipCode != 2:
38 | image = self.augment(image, flipCode)
39 | label = self.augment(label, flipCode)
40 | return image, label
41 |
42 | def __len__(self):
43 |         # Return the size of the training set
44 | return len(self.imgs_path)
45 |
46 |
47 | if __name__ == "__main__":
48 | isbi_dataset = ISBI_Loader("data/train/")
49 |     print("Number of samples:", len(isbi_dataset))
50 | train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
51 | batch_size=2,
52 | shuffle=True)
53 | for image, label in train_loader:
54 | print(image.shape)
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-3/log.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | class Logger():
4 | def __init__(self, filename="log.txt"):
5 | self.terminal = sys.stdout
6 | self.log = open(filename, "w")
7 |
8 | def write(self, message):
9 | self.terminal.write(message)
10 | self.log.write(message)
11 |
12 | def flush(self):
13 | pass
14 |
15 | sys.stdout = Logger()
16 |
17 | print("Jack Cui")
18 | print("https://cuijiahua.com")
19 | print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")
20 |
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-3/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | def get_logger(LEVEL, log_file = None):
4 | head = '[%(asctime)-15s] [%(levelname)s] %(message)s'
5 | if LEVEL == 'info':
6 | logging.basicConfig(level=logging.INFO, format=head)
7 | elif LEVEL == 'debug':
8 | logging.basicConfig(level=logging.DEBUG, format=head)
9 | logger = logging.getLogger()
10 |     if log_file is not None:
11 | fh = logging.FileHandler(log_file)
12 | logger.addHandler(fh)
13 | return logger
14 |
15 | logger = get_logger('info')
16 |
17 | logger.info('Jack Cui')
18 | logger.info('https://cuijiahua.com')
19 | logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-3/show_loss.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | # Enable this in a Jupyter notebook
3 | # %matplotlib inline
4 | with open('train_loss.txt', 'r') as f:
5 | train_loss = f.readlines()
6 | train_loss = list(map(lambda x:float(x.strip()), train_loss))
7 | x = range(len(train_loss))
8 | y = train_loss
9 | plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)
10 | plt.xlabel('Epoch')
11 | plt.ylabel('Loss Value')
12 | plt.legend()
13 | plt.show()
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-3/tensorboardX_test.py:
--------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | from urllib.request import urlretrieve
3 | import cv2
4 |
5 | # Choose which example to run
6 | choose_example = 1
7 |
8 | if choose_example == 1:
9 |     """
10 |     Example 1: creating SummaryWriter objects
11 |     """
12 |     # Create the writer1 object
13 |     # Logs are saved to the runs/exp folder
14 |     writer1 = SummaryWriter('runs/exp')
15 |
16 |     # Create the writer2 object with default arguments
17 |     # Logs are saved to a folder named runs/<date>_<username>
18 |     writer2 = SummaryWriter()
19 |
20 |     # Create the writer3 object with the comment argument
21 |     # Logs are saved to a folder named runs/<date>_<username>_resnet
22 |     writer3 = SummaryWriter(comment='_resnet')
23 |
24 | if choose_example == 2:
25 |     """
26 |     Example 2: writing scalar values
27 |     """
28 | writer = SummaryWriter('runs/scalar_example')
29 | for i in range(10):
30 | writer.add_scalar('quadratic', i**2, global_step=i)
31 | writer.add_scalar('exponential', 2**i, global_step=i)
32 | writer.close()
33 |
34 | if choose_example == 3:
35 |     """
36 |     Example 3: writing images
37 |     """
38 | urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg')
39 | urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg')
40 | urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg')
41 |
42 | writer = SummaryWriter('runs/image_example')
43 | for i in range(1, 4):
44 | writer.add_image('UNet_Seg',
45 | cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB),
46 | global_step=i,
47 | dataformats='HWC')
48 | writer.close()
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-3/train_loss.txt:
--------------------------------------------------------------------------------
1 | 32.464943
2 | 17.410048
3 | 16.051996
4 | 15.255561
5 | 14.547606
6 | 13.847259
7 | 13.449913
8 | 12.980925
9 | 12.948830
10 | 12.398992
11 | 12.485570
12 | 12.089432
13 | 12.118484
14 | 12.021638
15 | 11.817263
16 | 11.644099
17 | 11.417233
18 | 11.549695
19 | 11.223548
20 | 11.172435
21 | 11.027787
22 | 10.939758
23 | 10.666803
24 | 10.993714
25 | 10.574224
26 | 10.658235
27 | 10.631421
28 | 10.498351
29 | 10.557507
30 | 10.502128
31 | 10.543790
32 | 10.523225
33 | 10.231854
34 | 10.398646
35 | 10.406532
36 | 10.283625
37 | 10.105809
38 | 9.987217
39 | 9.936296
40 | 9.876533
41 | 9.953513
42 | 9.899665
43 | 9.926085
44 | 9.877600
45 | 9.829120
46 | 9.865887
47 | 9.770892
48 | 9.576312
49 | 9.615096
50 | 9.722373
51 | 9.715674
52 | 9.644127
53 | 9.581133
54 | 9.565999
55 | 9.459929
56 | 9.518677
57 | 9.321252
58 | 9.382160
59 | 9.545680
60 | 8.467113
61 | 8.369641
62 | 8.301431
63 | 8.306873
64 | 8.244370
65 | 8.223052
66 | 8.215305
67 | 8.191195
68 | 8.174629
69 | 8.184194
70 | 8.139848
71 | 8.143331
72 | 8.107319
73 | 8.110783
74 | 8.083336
75 | 8.056860
76 | 8.053325
77 | 8.038514
78 | 8.047304
79 | 8.027021
80 | 7.909974
81 | 7.896411
82 | 7.891089
83 | 7.892738
84 | 7.902834
85 | 7.896441
86 | 7.903152
87 | 7.878296
88 | 7.888803
89 | 7.879333
90 | 7.881098
91 | 7.868000
92 | 7.871295
93 | 7.887029
94 | 7.880289
95 | 7.863110
96 | 7.889467
97 | 7.876264
98 | 7.871953
99 | 7.869154
100 | 7.860284
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-4/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import os
4 | import glob
5 | from torch.utils.data import Dataset
6 | import random
7 | import torchvision.transforms as transforms
8 | from PIL import ImageFile
9 | ImageFile.LOAD_TRUNCATED_IMAGES = True
10 |
11 | class Garbage_Loader(Dataset):
12 | def __init__(self, txt_path, train_flag=True):
13 | self.imgs_info = self.get_images(txt_path)
14 | self.train_flag = train_flag
15 |
16 | self.train_tf = transforms.Compose([
17 | transforms.Resize(224),
18 | transforms.RandomHorizontalFlip(),
19 | transforms.RandomVerticalFlip(),
20 | transforms.ToTensor(),
21 |
22 | ])
23 | self.val_tf = transforms.Compose([
24 | transforms.Resize(224),
25 | transforms.ToTensor(),
26 | ])
27 |
28 | def get_images(self, txt_path):
29 | with open(txt_path, 'r', encoding='utf-8') as f:
30 | imgs_info = f.readlines()
31 | imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
32 | return imgs_info
33 |
34 | def padding_black(self, img):
35 |
36 | w, h = img.size
37 |
38 | scale = 224. / max(w, h)
39 | img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
40 |
41 | size_fg = img_fg.size
42 | size_bg = 224
43 |
44 | img_bg = Image.new("RGB", (size_bg, size_bg))
45 |
46 | img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
47 | (size_bg - size_fg[1]) // 2))
48 |
49 | img = img_bg
50 | return img
51 |
52 | def __getitem__(self, index):
53 | img_path, label = self.imgs_info[index]
54 | img = Image.open(img_path)
55 | img = img.convert('RGB')
56 | img = self.padding_black(img)
57 | if self.train_flag:
58 | img = self.train_tf(img)
59 | else:
60 | img = self.val_tf(img)
61 | label = int(label)
62 |
63 | return img, label
64 |
65 | def __len__(self):
66 | return len(self.imgs_info)
67 |
68 |
69 | if __name__ == "__main__":
70 | train_dataset = Garbage_Loader("train.txt", True)
71 |     print("Number of samples:", len(train_dataset))
72 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
73 | batch_size=1,
74 | shuffle=True)
75 | for image, label in train_loader:
76 | print(image.shape)
77 | print(label)
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-4/dir_label.txt:
--------------------------------------------------------------------------------
1 | 其他垃圾_PE塑料袋 0 0
2 | 其他垃圾_U型回形针 1 0
3 | 其他垃圾_一次性杯子 2 0
4 | 其他垃圾_一次性棉签 3 0
5 | 其他垃圾_串串竹签 4 0
6 | 其他垃圾_便利贴 5 0
7 | 其他垃圾_创可贴 6 0
8 | 其他垃圾_厨房手套 7 0
9 | 其他垃圾_口罩 8 0
10 | 其他垃圾_唱片 9 0
11 | 其他垃圾_图钉 10 0
12 | 其他垃圾_大龙虾头 11 0
13 | 其他垃圾_奶茶杯 12 0
14 | 其他垃圾_干果壳 13 0
15 | 其他垃圾_干燥剂 14 0
16 | 其他垃圾_打泡网 15 0
17 | 其他垃圾_打火机 16 0
18 | 其他垃圾_放大镜 17 0
19 | 其他垃圾_毛巾 18 0
20 | 其他垃圾_涂改带 19 0
21 | 其他垃圾_湿纸巾 20 0
22 | 其他垃圾_烟蒂 21 0
23 | 其他垃圾_牙刷 22 0
24 | 其他垃圾_百洁布 23 0
25 | 其他垃圾_眼镜 24 0
26 | 其他垃圾_票据 25 0
27 | 其他垃圾_空调滤芯 26 0
28 | 其他垃圾_笔及笔芯 27 0
29 | 其他垃圾_纸巾 28 0
30 | 其他垃圾_胶带 29 0
31 | 其他垃圾_胶水废包装 30 0
32 | 其他垃圾_苍蝇拍 31 0
33 | 其他垃圾_茶壶碎片 32 0
34 | 其他垃圾_餐盒 33 0
35 | 其他垃圾_验孕棒 34 0
36 | 其他垃圾_鸡毛掸 35 0
37 | 厨余垃圾_八宝粥 36 1
38 | 厨余垃圾_冰糖葫芦 37 1
39 | 厨余垃圾_咖啡渣 38 1
40 | 厨余垃圾_哈密瓜 39 1
41 | 厨余垃圾_圣女果 40 1
42 | 厨余垃圾_巴旦木 41 1
43 | 厨余垃圾_开心果 42 1
44 | 厨余垃圾_普通面包 43 1
45 | 厨余垃圾_板栗 44 1
46 | 厨余垃圾_果冻 45 1
47 | 厨余垃圾_核桃 46 1
48 | 厨余垃圾_梨 47 1
49 | 厨余垃圾_橙子 48 1
50 | 厨余垃圾_残渣剩饭 49 1
51 | 厨余垃圾_汉堡 50 1
52 | 厨余垃圾_火龙果 51 1
53 | 厨余垃圾_炸鸡 52 1
54 | 厨余垃圾_烤鸡烤鸭 53 1
55 | 厨余垃圾_牛肉干 54 1
56 | 厨余垃圾_瓜子 55 1
57 | 厨余垃圾_甘蔗 56 1
58 | 厨余垃圾_生肉 57 1
59 | 厨余垃圾_番茄 58 1
60 | 厨余垃圾_白菜 59 1
61 | 厨余垃圾_白萝卜 60 1
62 | 厨余垃圾_粉条 61 1
63 | 厨余垃圾_糕点 62 1
64 | 厨余垃圾_红豆 63 1
65 | 厨余垃圾_肠(火腿) 64 1
66 | 厨余垃圾_胡萝卜 65 1
67 | 厨余垃圾_花生皮 66 1
68 | 厨余垃圾_苹果 67 1
69 | 厨余垃圾_茶叶 68 1
70 | 厨余垃圾_草莓 69 1
71 | 厨余垃圾_荷包蛋 70 1
72 | 厨余垃圾_菠萝 71 1
73 | 厨余垃圾_菠萝包 72 1
74 | 厨余垃圾_菠萝蜜 73 1
75 | 厨余垃圾_蒜 74 1
76 | 厨余垃圾_薯条 75 1
77 | 厨余垃圾_蘑菇 76 1
78 | 厨余垃圾_蚕豆 77 1
79 | 厨余垃圾_蛋 78 1
80 | 厨余垃圾_蛋挞 79 1
81 | 厨余垃圾_西瓜皮 80 1
82 | 厨余垃圾_贝果 81 1
83 | 厨余垃圾_辣椒 82 1
84 | 厨余垃圾_陈皮 83 1
85 | 厨余垃圾_青菜 84 1
86 | 厨余垃圾_饼干 85 1
87 | 厨余垃圾_香蕉皮 86 1
88 | 厨余垃圾_骨肉相连 87 1
89 | 厨余垃圾_鸡翅 88 1
90 | 可回收物_乒乓球拍 89 2
91 | 可回收物_书 90 2
92 | 可回收物_保温杯 91 2
93 | 可回收物_保鲜盒 92 2
94 | 可回收物_信封 93 2
95 | 可回收物_充电头 94 2
96 | 可回收物_充电宝 95 2
97 | 可回收物_充电线 96 2
98 | 可回收物_八宝粥罐 97 2
99 | 可回收物_刀 98 2
100 | 可回收物_剃须刀片 99 2
101 | 可回收物_剪刀 100 2
102 | 可回收物_勺子 101 2
103 | 可回收物_单肩包手提包 102 2
104 | 可回收物_卡 103 2
105 | 可回收物_叉子 104 2
106 | 可回收物_变形玩具 105 2
107 | 可回收物_台历 106 2
108 | 可回收物_台灯 107 2
109 | 可回收物_吹风机 108 2
110 | 可回收物_呼啦圈 109 2
111 | 可回收物_地球仪 110 2
112 | 可回收物_地铁票 111 2
113 | 可回收物_垫子 112 2
114 | 可回收物_塑料瓶 113 2
115 | 可回收物_塑料盆 114 2
116 | 可回收物_奶盒 115 2
117 | 可回收物_奶粉罐 116 2
118 | 可回收物_奶粉罐铝盖 117 2
119 | 可回收物_尺子 118 2
120 | 可回收物_帽子 119 2
121 | 可回收物_废弃扩声器 120 2
122 | 可回收物_手提包 121 2
123 | 可回收物_手机 122 2
124 | 可回收物_手电筒 123 2
125 | 可回收物_手链 124 2
126 | 可回收物_打印机墨盒 125 2
127 | 可回收物_打气筒 126 2
128 | 可回收物_护肤品空瓶 127 2
129 | 可回收物_报纸 128 2
130 | 可回收物_拖鞋 129 2
131 | 可回收物_插线板 130 2
132 | 可回收物_搓衣板 131 2
133 | 可回收物_收音机 132 2
134 | 可回收物_放大镜 133 2
135 | 可回收物_易拉罐 134 2
136 | 可回收物_暖宝宝 135 2
137 | 可回收物_望远镜 136 2
138 | 可回收物_木制切菜板 137 2
139 | 可回收物_木制玩具 138 2
140 | 可回收物_木质梳子 139 2
141 | 可回收物_木质锅铲 140 2
142 | 可回收物_枕头 141 2
143 | 可回收物_档案袋 142 2
144 | 可回收物_水杯 143 2
145 | 可回收物_泡沫盒子 144 2
146 | 可回收物_灯罩 145 2
147 | 可回收物_烟灰缸 146 2
148 | 可回收物_烧水壶 147 2
149 | 可回收物_热水瓶 148 2
150 | 可回收物_玩偶 149 2
151 | 可回收物_玻璃器皿 150 2
152 | 可回收物_玻璃壶 151 2
153 | 可回收物_玻璃球 152 2
154 | 可回收物_电动剃须刀 153 2
155 | 可回收物_电动卷发棒 154 2
156 | 可回收物_电动牙刷 155 2
157 | 可回收物_电熨斗 156 2
158 | 可回收物_电视遥控器 157 2
159 | 可回收物_电路板 158 2
160 | 可回收物_登机牌 159 2
161 | 可回收物_盘子 160 2
162 | 可回收物_碗 161 2
163 | 可回收物_空气加湿器 162 2
164 | 可回收物_空调遥控器 163 2
165 | 可回收物_纸牌 164 2
166 | 可回收物_纸箱 165 2
167 | 可回收物_罐头瓶 166 2
168 | 可回收物_网卡 167 2
169 | 可回收物_耳套 168 2
170 | 可回收物_耳机 169 2
171 | 可回收物_耳钉耳环 170 2
172 | 可回收物_芭比娃娃 171 2
173 | 可回收物_茶叶罐 172 2
174 | 可回收物_蛋糕盒 173 2
175 | 可回收物_螺丝刀 174 2
176 | 可回收物_衣架 175 2
177 | 可回收物_袜子 176 2
178 | 可回收物_裤子 177 2
179 | 可回收物_计算器 178 2
180 | 可回收物_订书机 179 2
181 | 可回收物_话筒 180 2
182 | 可回收物_购物纸袋 181 2
183 | 可回收物_路由器 182 2
184 | 可回收物_车钥匙 183 2
185 | 可回收物_量杯 184 2
186 | 可回收物_钉子 185 2
187 | 可回收物_钟表 186 2
188 | 可回收物_钢丝球 187 2
189 | 可回收物_锅 188 2
190 | 可回收物_锅盖 189 2
191 | 可回收物_键盘 190 2
192 | 可回收物_镊子 191 2
193 | 可回收物_鞋 192 2
194 | 可回收物_餐垫 193 2
195 | 可回收物_鼠标 194 2
196 | 有害垃圾_LED灯泡 195 3
197 | 有害垃圾_保健品瓶 196 3
198 | 有害垃圾_口服液瓶 197 3
199 | 有害垃圾_指甲油 198 3
200 | 有害垃圾_杀虫剂 199 3
201 | 有害垃圾_温度计 200 3
202 | 有害垃圾_滴眼液瓶 201 3
203 | 有害垃圾_玻璃灯管 202 3
204 | 有害垃圾_电池 203 3
205 | 有害垃圾_电池板 204 3
206 | 有害垃圾_碘伏空瓶 205 3
207 | 有害垃圾_红花油 206 3
208 | 有害垃圾_纽扣电池 207 3
209 | 有害垃圾_胶水 208 3
210 | 有害垃圾_药品包装 209 3
211 | 有害垃圾_药片 210 3
212 | 有害垃圾_药膏 211 3
213 | 有害垃圾_蓄电池 212 3
214 | 有害垃圾_血压计 213 3
215 |
--------------------------------------------------------------------------------
/Pytorch-Seg/lesson-4/infer.py:
--------------------------------------------------------------------------------
1 | from dataset import Garbage_Loader
2 | from torch.utils.data import DataLoader
3 | import torchvision.transforms as transforms
4 | from torchvision import models
5 | import torch.nn as nn
6 | import torch
7 | import os
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | #%matplotlib inline
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
12 |
13 | def softmax(x):
14 | exp_x = np.exp(x)
15 | softmax_x = exp_x / np.sum(exp_x, 0)
16 | return softmax_x
17 |
18 | with open('dir_label.txt', 'r', encoding='utf-8') as f:
19 | labels = f.readlines()
20 | labels = list(map(lambda x:x.strip().split('\t'), labels))
21 |
22 | if __name__ == "__main__":
23 | test_list = 'test.txt'
24 | test_data = Garbage_Loader(test_list, train_flag=False)
25 | test_loader = DataLoader(dataset=test_data, num_workers=1, pin_memory=True, batch_size=1)
26 | model = models.resnet50(pretrained=False)
27 | fc_inputs = model.fc.in_features
28 | model.fc = nn.Linear(fc_inputs, 214)
29 | model = model.cuda()
30 |     # Load the trained model
31 | checkpoint = torch.load('model_best_checkpoint_resnet50.pth.tar')
32 | model.load_state_dict(checkpoint['state_dict'])
33 | model.eval()
34 | for i, (image, label) in enumerate(test_loader):
35 | src = image.numpy()
36 | src = src.reshape(3, 224, 224)
37 | src = np.transpose(src, (1, 2, 0))
38 | image = image.cuda()
39 | label = label.cuda()
40 | pred = model(image)
41 | pred = pred.data.cpu().numpy()[0]
42 | score = softmax(pred)
43 | pred_id = np.argmax(score)
44 | plt.imshow(src)
45 |         print('Prediction:', labels[pred_id][0])
46 | plt.show()
47 |
--------------------------------------------------------------------------------
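The softmax above exponentiates the raw logits directly, which can overflow for large values. A standard, numerically stable variant subtracts the maximum logit first; this is only a sketch of that common trick, not a change made in the repository.

import numpy as np

def stable_softmax(x):
    # Subtracting the max leaves the result unchanged but keeps np.exp in range
    shifted = x - np.max(x)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, 0)
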
/Pytorch-Seg/lesson-4/train.py:
--------------------------------------------------------------------------------
1 | from dataset import Garbage_Loader
2 | from torch.utils.data import DataLoader
3 | from torchvision import models
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch
7 | import time
8 | import os
9 | import shutil
10 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
11 |
12 | """
13 | Author : Jack Cui
14 | Wechat : https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA
15 | """
16 |
17 | from tensorboardX import SummaryWriter
18 |
19 | def accuracy(output, target, topk=(1,)):
20 |     """
21 |     Compute the top-k accuracy.
22 |     """
23 | with torch.no_grad():
24 | maxk = max(topk)
25 | batch_size = target.size(0)
26 |
27 | _, pred = output.topk(maxk, 1, True, True)
28 | pred = pred.t()
29 | correct = pred.eq(target.view(1, -1).expand_as(pred))
30 |
31 | class_to = pred[0].cpu().numpy()
32 |
33 | res = []
34 | for k in topk:
35 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
36 | res.append(correct_k.mul_(100.0 / batch_size))
37 | return res, class_to
38 |
39 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
40 |     """
41 |     Save the checkpoint according to is_best; usually the model with the best validation accuracy is kept.
42 |     """
43 | torch.save(state, filename)
44 | if is_best:
45 | shutil.copyfile(filename, 'model_best_' + filename)
46 |
47 | def train(train_loader, model, criterion, optimizer, epoch, writer):
48 |     """
49 |     Training loop.
50 |     Arguments:
51 |         train_loader - DataLoader for the training set
52 |         model - the model
53 |         criterion - the loss function
54 |         optimizer - the optimizer
55 |         epoch - index of the current epoch
56 |         writer - SummaryWriter used for tensorboardX logging
57 |     """
58 | batch_time = AverageMeter()
59 | data_time = AverageMeter()
60 | losses = AverageMeter()
61 | top1 = AverageMeter()
62 | top5 = AverageMeter()
63 |
64 | # switch to train mode
65 | model.train()
66 |
67 | end = time.time()
68 | for i, (input, target) in enumerate(train_loader):
69 | # measure data loading time
70 | data_time.update(time.time() - end)
71 |
72 | input = input.cuda()
73 | target = target.cuda()
74 |
75 | # compute output
76 | output = model(input)
77 | loss = criterion(output, target)
78 |
79 | # measure accuracy and record loss
80 | [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5))
81 | losses.update(loss.item(), input.size(0))
82 | top1.update(prec1[0], input.size(0))
83 | top5.update(prec5[0], input.size(0))
84 |
85 | # compute gradient and do SGD step
86 | optimizer.zero_grad()
87 | loss.backward()
88 | optimizer.step()
89 |
90 | # measure elapsed time
91 | batch_time.update(time.time() - end)
92 | end = time.time()
93 |
94 | if i % 10 == 0:
95 | print('Epoch: [{0}][{1}/{2}]\t'
96 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
97 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
98 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
99 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
100 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
101 | epoch, i, len(train_loader), batch_time=batch_time,
102 | data_time=data_time, loss=losses, top1=top1, top5=top5))
103 | writer.add_scalar('loss/train_loss', losses.val, global_step=epoch)
104 |
105 | def validate(val_loader, model, criterion, epoch, writer, phase="VAL"):
106 |     """
107 |     Validation loop.
108 |     Arguments:
109 |         val_loader - DataLoader for the validation set
110 |         model - the model
111 |         criterion - the loss function
112 |         epoch - index of the current epoch
113 |         writer - SummaryWriter used for tensorboardX logging
114 |     """
115 | batch_time = AverageMeter()
116 | losses = AverageMeter()
117 | top1 = AverageMeter()
118 | top5 = AverageMeter()
119 |
120 | # switch to evaluate mode
121 | model.eval()
122 |
123 | with torch.no_grad():
124 | end = time.time()
125 | for i, (input, target) in enumerate(val_loader):
126 | input = input.cuda()
127 | target = target.cuda()
128 | # compute output
129 | output = model(input)
130 | loss = criterion(output, target)
131 |
132 | # measure accuracy and record loss
133 | [prec1, prec5], class_to = accuracy(output, target, topk=(1, 5))
134 | losses.update(loss.item(), input.size(0))
135 | top1.update(prec1[0], input.size(0))
136 | top5.update(prec5[0], input.size(0))
137 |
138 | # measure elapsed time
139 | batch_time.update(time.time() - end)
140 | end = time.time()
141 |
142 | if i % 10 == 0:
143 | print('Test-{0}: [{1}/{2}]\t'
144 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
145 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
146 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
147 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
148 | phase, i, len(val_loader),
149 | batch_time=batch_time,
150 | loss=losses,
151 | top1=top1, top5=top5))
152 |
153 | print(' * {} Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
154 | .format(phase, top1=top1, top5=top5))
155 | writer.add_scalar('loss/valid_loss', losses.val, global_step=epoch)
156 | return top1.avg, top5.avg
157 |
158 | class AverageMeter(object):
159 | """Computes and stores the average and current value"""
160 | def __init__(self):
161 | self.reset()
162 |
163 | def reset(self):
164 | self.val = 0
165 | self.avg = 0
166 | self.sum = 0
167 | self.count = 0
168 |
169 | def update(self, val, n=1):
170 | self.val = val
171 | self.sum += val * n
172 | self.count += n
173 | self.avg = self.sum / self.count
174 |
175 | if __name__ == "__main__":
176 |     # -------------------------------------------- step 1/4 : load the data ---------------------------
177 | train_dir_list = 'train.txt'
178 | valid_dir_list = 'val.txt'
179 | batch_size = 64
180 | epochs = 80
181 | num_classes = 214
182 | train_data = Garbage_Loader(train_dir_list, train_flag=True)
183 | valid_data = Garbage_Loader(valid_dir_list, train_flag=False)
184 | train_loader = DataLoader(dataset=train_data, num_workers=8, pin_memory=True, batch_size=batch_size, shuffle=True)
185 | valid_loader = DataLoader(dataset=valid_data, num_workers=8, pin_memory=True, batch_size=batch_size)
186 | train_data_size = len(train_data)
187 |     print('Training set size: %d' % train_data_size)
188 |     valid_data_size = len(valid_data)
189 |     print('Validation set size: %d' % valid_data_size)
190 |     # ------------------------------------ step 2/4 : define the network ------------------------------------
191 | model = models.resnet50(pretrained=True)
192 | fc_inputs = model.fc.in_features
193 | model.fc = nn.Linear(fc_inputs, num_classes)
194 | model = model.cuda()
195 |     # ------------------------------------ step 3/4 : define the loss function, optimizer, etc. -------------------------
196 | lr_init = 0.0001
197 | lr_stepsize = 20
198 | weight_decay = 0.001
199 | criterion = nn.CrossEntropyLoss().cuda()
200 | optimizer = optim.Adam(model.parameters(), lr=lr_init, weight_decay=weight_decay)
201 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_stepsize, gamma=0.1)
202 |
203 | writer = SummaryWriter('runs/resnet50')
204 |     # ------------------------------------ step 4/4 : train -----------------------------------------
205 | best_prec1 = 0
206 | for epoch in range(epochs):
207 | scheduler.step()
208 | train(train_loader, model, criterion, optimizer, epoch, writer)
209 |         # Evaluate on the validation set
210 | valid_prec1, valid_prec5 = validate(valid_loader, model, criterion, epoch, writer, phase="VAL")
211 | is_best = valid_prec1 > best_prec1
212 | best_prec1 = max(valid_prec1, best_prec1)
213 | save_checkpoint({
214 | 'epoch': epoch + 1,
215 | 'arch': 'resnet50',
216 | 'state_dict': model.state_dict(),
217 | 'best_prec1': best_prec1,
218 | 'optimizer' : optimizer.state_dict(),
219 | }, is_best,
220 | filename='checkpoint_resnet50.pth.tar')
221 | writer.close()
222 |
--------------------------------------------------------------------------------
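The accuracy helper in the training script follows the familiar ImageNet-example pattern: take the k highest-scoring classes per sample, compare them with the target, and count the rows that contain a hit. A minimal, self-contained sketch of the same computation on a toy batch (independent of the script, shown only for illustration):

import torch

# Toy batch: 4 samples, 6 classes
output = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0, 0.0],
                       [0.8, 0.1, 0.0, 0.0, 0.0, 0.1],
                       [0.0, 0.0, 0.2, 0.7, 0.1, 0.0],
                       [0.0, 0.0, 0.0, 0.1, 0.3, 0.6]])
target = torch.tensor([1, 0, 2, 4])

maxk = 5
_, pred = output.topk(maxk, 1, True, True)   # indices of the 5 highest scores per row
pred = pred.t()                              # shape (maxk, batch)
correct = pred.eq(target.view(1, -1).expand_as(pred))

for k in (1, 5):
    correct_k = correct[:k].reshape(-1).float().sum()
    print('top-{} accuracy: {:.1f}%'.format(k, correct_k.item() * 100.0 / target.size(0)))

# Expected output: top-1 accuracy: 50.0%, top-5 accuracy: 100.0%
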
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Learning
2 |
3 | * What counts is perseverance, so there is no need to stay up past midnight and rise before dawn; nothing is more harmful than one day of effort followed by ten days of idleness.
4 |
5 | At least two original articles are published every week. **The latest articles** appear first on my [WeChat official account](https://cuijiahua.com/wp-content/uploads/2020/05/gzh-w.jpg), and videos premiere on [Bilibili](https://space.bilibili.com/331507846). Feel free to add me on [WeChat](https://cuijiahua.com/wp-content/uploads/2020/05/gzh-w.jpg) to join the **discussion group**, whether for technical exchange or suggestions. **Stars** are welcome!
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |