├── Keras_TransforLearning
│   ├── README.md
│   └── keras_classifiter.ipynb
├── MXnet_TransforLearning
│   ├── README.md
│   └── mxnet_classifiter.ipynb
├── README.md
├── TransferLearning_reload.py
├── TransforLearning.py
└── flower_photos
    ├── daisy
    │   └── 43474673_7bb4465a86.jpg
    └── roses
        └── 12240303_80d87f77a3_n.jpg

/Keras_TransforLearning/README.md:
--------------------------------------------------------------------------------
## A simple classification task in Keras
For basic classification tasks, Keras makes it easy either to use a pretrained network as a fixed feature extractor or to fine-tune the whole network. With the Keras model interface, you only have to choose whether the pretrained feature layers are frozen, which lets the same code adapt to datasets of different sizes.
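The notebook freezes InceptionV3 by setting `layer.trainable = False` and then trains only the small classifier on top. Keras picks up changes to `trainable` when the model is compiled, which is why the flag is set before `model.compile()`. The NumPy sketch below shows the same "frozen extractor + trainable head" idea outside Keras. It rests on assumptions: a fixed random projection stands in for the pretrained backbone, the data are synthetic, and `extract_features` / `train_head` are hypothetical helpers, not Keras API.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pretrained backbone: a fixed random projection + ReLU.
# Its weights are never updated, which is what "frozen" means here.
W_backbone = rng.normal(size=(20, 64))


def extract_features(x):
    """Frozen feature extractor: none of its parameters are trained."""
    return np.maximum(x @ W_backbone, 0.0)


def train_head(features, y, lr=0.1, steps=500):
    """Train only a logistic-regression head on top of fixed features."""
    # standardize the features so one learning rate suits every column
    mean, std = features.mean(axis=0), features.std(axis=0) + 1e-8
    f = (features - mean) / std
    w = np.zeros(f.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(f @ w + b)))  # sigmoid probabilities
        grad = p - y  # derivative of binary cross-entropy w.r.t. the logit
        w -= lr * f.T @ grad / len(y)
        b -= lr * grad.mean()
    accuracy = float(((f @ w + b > 0) == (y == 1)).mean())
    return w, b, accuracy


# synthetic two-class data, for illustration only
x = rng.normal(size=(200, 20))
y = (x[:, 0] + x[:, 1] > 0).astype(float)

W_before = W_backbone.copy()
w_head, b_head, acc = train_head(extract_features(x), y)
print(f"head-only training accuracy: {acc:.2f}")
```

Only `w_head` and `b_head` change during training. Fine-tuning the whole network would mean also updating `W_backbone` with its own gradient, which is why it needs more data.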
First, prepare the image folder in the same layout as the TensorFlow version: one top-level folder containing one sub-folder per class, with each class's images inside its sub-folder.
Next, in the data-generator cell of `keras_classifiter.ipynb`, point the image directory in the following code to your own image folder:
```
train_flow = train_datagen.flow_from_directory(
    r'.\猫狗数据',  # the target directory (here a folder named 猫狗数据, "cats and dogs data")
    target_size=(224, 224),  # all images will be resized to 224x224
    batch_size=32,
    class_mode='categorical')
```
Note that the data-generator section defines two ways of producing batches: a hand-written generator and the official API. Use only one of them. The official API is recommended, because the hand-written generator only handles binary classification and must be modified for multi-class problems. Once this is done, run the cells in order, running only one of the two data-generator options.

--------------------------------------------------------------------------------
/Keras_TransforLearning/keras_classifiter.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# I. Data generators\n",
    "## 1. Hand-written data generator\n",
    "The hand-written generator only supports binary classification; the API data generator is recommended. For a binary problem, just make the two glob patterns below match the images of your two classes:\n",
    "```\n",
    "real_list = glob.glob(r\".\\Hotdog\\train\\hotdog/*.png\")\n",
    "glitch_list = glob.glob(r\".\\Hotdog\\train\\not-hotdog/*.png\")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000 1000\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "\n",
    "# real_list = glob.glob(r\".\\train\\1/*.png\")\n",
    "# glitch_list = glob.glob(r\".\\train\\0/*.png\")\n",
    "\n",
    "real_list = glob.glob(r\".\\Hotdog\\train\\hotdog/*.png\")\n",
    "glitch_list = glob.glob(r\".\\Hotdog\\train\\not-hotdog/*.png\")\n",
    "\n",
    "list_img = real_list + glitch_list\n",
    "\n",
    "print(len(real_list), len(glitch_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('.\\\\Hotdog\\\\train\\\\not-hotdog\\\\331.png', array([0, 1])),\n",
      " ('.\\\\Hotdog\\\\train\\\\hotdog\\\\282.png', array([1, 0])),\n",
      " ('.\\\\Hotdog\\\\train\\\\not-hotdog\\\\8.png', array([0, 1])),\n",
      " ('.\\\\Hotdog\\\\train\\\\hotdog\\\\651.png', array([1, 0])),\n",
      " ('.\\\\Hotdog\\\\train\\\\hotdog\\\\79.png', array([1, 0])),\n",
      " ('.\\\\Hotdog\\\\train\\\\hotdog\\\\697.png', array([1, 0])),\n",
      " ('.\\\\Hotdog\\\\train\\\\not-hotdog\\\\599.png', array([0, 1])),\n",
      " ('.\\\\Hotdog\\\\train\\\\not-hotdog\\\\404.png', array([0, 1]))]\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import pprint as pp\n",
    "\n",
    "def get_batch(img_list, label_list, batch_size=8, show=False):\n",
    "    # Keras does not call this function again for each step; it keeps pulling from the\n",
    "    # same generator, so the generator must loop forever with `while True`\n",
    "    while True:\n",
    "        index = np.random.choice(range(len(img_list)), batch_size)\n",
    "        batch_list = [img_list[i] for i in index]\n",
    "        batch_raw = [cv2.imread(img) for img in batch_list]\n",
    "        batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in batch_raw]  # OpenCV loads images as BGR\n",
    "        batch_matrix = [cv2.resize(i, (224, 224), interpolation=cv2.INTER_CUBIC) for i in batch_rgb]\n",
    "        batch_label = [label_list[i] for i in index]\n",
    "        if show:\n",
    "            pp.pprint(list(zip(batch_list, batch_label)))\n",
    "        batch_matrix = np.stack(batch_matrix).astype(float) / 255.  # (batch, 224, 224, 3), values in [0, 1]\n",
    "        batch_label = np.array(batch_label)\n",
    "        yield batch_matrix, batch_label  # , batch_list\n",
    "\n",
    "\n",
    "# one-hot labels: the first len(real_list) images are class 0, the rest class 1\n",
    "list_lab = np.zeros([len(list_img), 2], dtype=int)\n",
    "list_lab[:len(real_list), 0] = 1  # real: [1, 0]\n",
    "list_lab[len(real_list):, 1] = 1  # fake: [0, 1]\n",
    "b = next(get_batch(list_img, list_lab, show=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "num = np.random.randint(0, len(b[0]))  # b = (images, labels): pick a random image from the batch\n",
    "img = (b[0][num] * 255).astype(np.uint8)\n",
    "plt.imshow(img)\n",
    "plt.show()\n",
    "# print(b[2][num])\n",
    "# pp.pprint(list(zip(b[1][:], b[2][:])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. API data generator\n",
    "The official data loader: point the path at the parent image folder, which contains one sub-directory per class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 3850 images belonging to 2 classes.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "# train_datagen = ImageDataGenerator(rescale=1./255)\n",
    "# train_flow = train_datagen.flow_from_directory(r'./hotdog\\train', target_size=(128, 128), batch_size=32, class_mode='categorical')\n",
    "\n",
    "# rescale=1 keeps pixel values in [0, 255]; the test generator must use the same scaling.\n",
    "# (InceptionV3's ImageNet weights expect inputs scaled to [-1, 1]; see keras.applications.inception_v3.preprocess_input.)\n",
    "train_datagen = ImageDataGenerator(\n",
    "    rescale=1,\n",
    "    shear_range=0.2,\n",
    "    zoom_range=0.2,\n",
    "    horizontal_flip=False)\n",
    "\n",
    "train_flow = train_datagen.flow_from_directory(\n",
    "    r'.\\train',  # the target directory, with one sub-folder per class\n",
    "    target_size=(224, 224),  # all images will be resized to 224x224\n",
    "    batch_size=32,\n",
    "    class_mode='categorical')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# II. Network definition\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "inception_v3 (Model)         (None, 5, 5, 2048)        21802784  \n",
      "_________________________________________________________________\n",
      "sequential_2 (Sequential)    (None, 2)                 13107970  \n",
      "=================================================================\n",
      "Total params: 34,910,754\n",
      "Trainable params: 13,107,970\n",
      "Non-trainable params: 21,802,784\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from keras import backend as K\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers import Activation, Dense, Dropout, Flatten, Input\n",
    "from keras.layers.convolutional import Conv2D, MaxPooling2D\n",
    "from keras.applications import InceptionV3\n",
    "\n",
    "# use GPU 0 and let TensorFlow allocate GPU memory on demand\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "gpu_options = tf.GPUOptions(allow_growth=True)\n",
    "sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))\n",
    "K.set_session(sess)  # make Keras use this session\n",
    "\n",
    "TransforLearning = True  # True: frozen InceptionV3 + new classifier; False: small CNN trained from scratch\n",
    "model = Sequential()\n",
    "if not TransforLearning:\n",
    "    # Keras 2 signature: Conv2D(filters, kernel_size); Conv2D(32, 3, 3) would mean strides=3\n",
    "    model.add(Conv2D(32, (3, 3), input_shape=(224, 224, 3)))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Conv2D(32, (3, 3)))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Conv2D(64, (3, 3)))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "\n",
    "    model.add(Flatten())  # this converts our 3D feature maps to 1D feature vectors\n",
    "    model.add(Dense(64))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(2))\n",
    "    model.add(Activation('sigmoid'))\n",
    "\n",
    "else:\n",
    "    inception = InceptionV3(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n",
    "    model.add(inception)\n",
    "    # build a classifier model to put on top of the convolutional model\n",
    "    top_model = Sequential()\n",
    "    top_model.add(Flatten(input_shape=model.output_shape[1:]))\n",
    "    top_model.add(Dense(64, activation='relu'))\n",
    "    top_model.add(Dropout(0.5))\n",
    "    top_model.add(Dense(2, activation='sigmoid'))\n",
    "\n",
    "    # the dataset is small, so freeze the Inception weights and train only the new classifier\n",
    "    model.add(top_model)\n",
    "    for layer in inception.layers:\n",
    "        layer.trainable = False\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# III. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "625/625 [==============================] - 1048s 2s/step - loss: 0.6029 - acc: 0.6771\n",
      "Epoch 2/3\n",
      "625/625 [==============================] - 1057s 2s/step - loss: 0.5119 - acc: 0.7497\n",
      "Epoch 3/3\n",
      "625/625 [==============================] - 1074s 2s/step - loss: 0.4856 - acc: 0.7697\n"
     ]
    }
   ],
   "source": [
    "from keras import losses\n",
    "import keras.backend as K\n",
    "from keras import optimizers\n",
    "from keras.callbacks import TensorBoard\n",
    "\n",
    "samples_per_epoch = 10000  # len(list_img)\n",
    "batch_size = 32\n",
    "\n",
    "# K.set_value(sgd.lr, 0.5 * K.get_value(sgd.lr))\n",
    "# sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)\n",
    "# model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])\n",
    "# With binary_crossentropy, Keras reports binary accuracy averaged over both output units;\n",
    "# for one-hot two-class labels, softmax + 'categorical_crossentropy' is the more usual pairing.\n",
    "model.compile(optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),\n",
    "              loss='binary_crossentropy',  # 'categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "\n",
    "model.fit_generator(train_flow,\n",
    "                    # get_batch(list_img, list_lab),\n",
    "                    steps_per_epoch=2*samples_per_epoch//batch_size,  # samples_per_epoch//batch_size,\n",
    "                    epochs=3,\n",
    "                    shuffle=True,\n",
    "                    callbacks=[TensorBoard(log_dir='./tmp/log', write_graph=True)])\n",
    "# model.fit(img, label, batch_size=8, epochs=10)\n",
    "model.save('model.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IV. Testing\n",
    "## 1. Hand-written data pipeline\n",
    "If the training data above came from the `get_batch` function, use the hand-written loading code here as well.\n",
    "Replace the two paths below with the paths of your test images:\n",
    "```\n",
    "test_real_list = glob.glob(\"./test/0/*.png\")\n",
    "test_glitch_list = glob.glob(\"./test/1/*.png\")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "test_real_list = glob.glob(\"./test/0/*.png\")\n",
    "test_glitch_list = glob.glob(\"./test/1/*.png\")\n",
    "test_list_img = test_real_list + test_glitch_list\n",
    "\n",
    "img_test = [cv2.imread(i) for i in test_list_img]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cv2.imread returns None for unreadable files; drop those together with their labels\n",
    "labels = [[1, 0]] * len(test_real_list) + [[0, 1]] * len(test_glitch_list)\n",
    "pairs = [(img, lab) for img, lab in zip(img_test, labels) if img is not None]\n",
    "\n",
    "# same preprocessing as get_batch: BGR -> RGB, resize to 224x224, scale to [0, 1]\n",
    "img_test = [cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (224, 224), interpolation=cv2.INTER_CUBIC)\n",
    "            for img, _ in pairs]\n",
    "img_test = np.stack(img_test) / 255.\n",
    "label_test = np.array([lab for _, lab in pairs])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. API data generator\n",
    "Otherwise, just use the directory iterator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 800 images belonging to 2 classes.\n"
     ]
    }
   ],
   "source": [
    "# no augmentation at test time; keep the same rescale factor as in training\n",
    "test_datagen = ImageDataGenerator(rescale=1)\n",
    "\n",
    "test_flow = test_datagen.flow_from_directory(\n",
    "    r'.\\Hotdog\\test',  # the target directory\n",
    "    target_size=(224, 224),  # all images will be resized to 224x224\n",
    "    batch_size=32,\n",
    "    class_mode='categorical')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "inception_v3 (Model)         (None, 5, 5, 2048)        21802784  \n",
      "_________________________________________________________________\n",
      "sequential_2 (Sequential)    (None, 2)                 13107970  \n",
      "=================================================================\n",
      "Total params: 34,910,754\n",
      "Trainable params: 13,107,970\n",
      "Non-trainable params: 21,802,784\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "import keras.models as KM\n",
    "\n",
    "model = KM.load_model('model.h5')\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('loss', 6.2614383697509766), ('acc', 0.46875)]\n"
     ]
    }
   ],
   "source": [
    "import pprint as pp\n",
    "# result = model.evaluate(img_test, label_test)\n",
    "# steps=1 evaluates a single batch of 32 images; pass len(test_flow) to cover the whole test set\n",
    "result = model.evaluate_generator(test_flow, 1)\n",
    "pp.pprint(list(zip(model.metrics_names, result)))\n",
    "# [('loss', 4.7238349914550781), ('acc', 0.59375)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/MXnet_TransforLearning/README.md:
--------------------------------------------------------------------------------
## A simple classification task in MXNet
[『MXNet』Part 9: Classifier and Transfer-Learning Demo](https://www.cnblogs.com/hellcat/p/9098168.html)
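First, prepare the image folder in the same layout as the TensorFlow version: one top-level folder containing one sub-folder per class, with each class's images inside its sub-folder.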
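Both Keras's `flow_from_directory` and Gluon's `ImageFolderDataset` turn this layout into labels the same way: they sort the class sub-folder names alphabetically and number them from 0. The stdlib sketch below mimics the Gluon behavior. It keeps only files directly inside each class folder and accepts only `.jpg`, `.jpeg` and `.png`, the extensions the notebook's warnings list. `index_image_folder` is a hypothetical helper, not part of either library. The sketch also shows why pointing `root` at a folder that holds `train/` and `test/` produces the "classes" `test` and `train` with no images at all.

```python
import os
import tempfile

# Assumption: the extensions MXNet's ImageFolderDataset reports accepting.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def index_image_folder(root):
    """Return (class_names, samples); samples are (path, label_index) pairs.

    Sub-folders of `root` are the classes, sorted by name. Only files directly
    inside a class folder count as samples; nested folders are ignored.
    """
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    samples = []
    for label, name in enumerate(class_names):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, fname)
            ext = os.path.splitext(fname)[1].lower()
            if os.path.isfile(path) and ext in IMAGE_EXTS:
                samples.append((path, label))
    return class_names, samples


# usage: a throwaway folder shaped like Hotdog/train/<class>/*.png
with tempfile.TemporaryDirectory() as demo_root:
    for cls in ("not-hotdog", "hotdog"):
        os.makedirs(os.path.join(demo_root, "train", cls))
        open(os.path.join(demo_root, "train", cls, "1.png"), "wb").close()
    print(index_image_folder(os.path.join(demo_root, "train"))[0])  # ['hotdog', 'not-hotdog']
    print(index_image_folder(demo_root))  # (['train'], []): one level too high
```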
Then, in the data-preprocessing cell of `mxnet_classifiter.ipynb`, point the image directory in the following code to your own image folder:
```
train_ds = gdata.vision.ImageFolderDataset(root=r'.\Hotdog\train',  # the folder whose sub-folders are the classes
                                           flag=1)  # 0: load as grayscale, 1: load as color
```
Finally, run the cells in order.

--------------------------------------------------------------------------------
/MXnet_TransforLearning/mxnet_classifiter.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loading and preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\software\\Anaconda\\lib\\site-packages\\mxnet\\gluon\\data\\vision\\datasets.py:317: UserWarning: Ignoring .\\Hotdog\\test\\hotdog of type . Only support .jpg, .jpeg, .png\n",
      "  filename, ext, ', '.join(self._exts)))\n",
      "C:\\software\\Anaconda\\lib\\site-packages\\mxnet\\gluon\\data\\vision\\datasets.py:317: UserWarning: Ignoring .\\Hotdog\\test\\not-hotdog of type . Only support .jpg, .jpeg, .png\n",
      "  filename, ext, ', '.join(self._exts)))\n",
      "C:\\software\\Anaconda\\lib\\site-packages\\mxnet\\gluon\\data\\vision\\datasets.py:317: UserWarning: Ignoring .\\Hotdog\\train\\hotdog of type . Only support .jpg, .jpeg, .png\n",
      "  filename, ext, ', '.join(self._exts)))\n",
      "C:\\software\\Anaconda\\lib\\site-packages\\mxnet\\gluon\\data\\vision\\datasets.py:317: UserWarning: Ignoring .\\Hotdog\\train\\not-hotdog of type . Only support .jpg, .jpeg, .png\n",
      "  filename, ext, ', '.join(self._exts)))\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import mxnet as mx\n",
    "from mxnet import gluon, autograd, init\n",
    "from mxnet.gluon import data as gdata, loss as gloss, nn\n",
    "from mxnet.gluon.data.vision import transforms\n",
    "\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.Resize(32),\n",
    "    transforms.ToTensor(),])\n",
    "#    gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n",
    "#                                      [0.2023, 0.1994, 0.2010])])\n",
    "\n",
    "# root must be the folder whose sub-folders are the classes; pointing it at .\\Hotdog\n",
    "# (which holds train\\ and test\\) makes every image be ignored, as in the warnings above\n",
    "train_ds = gdata.vision.ImageFolderDataset(root=r'.\\Hotdog\\train',\n",
    "                                           flag=1)  # 0: load as grayscale, 1: load as color\n",
    "\n",
    "loader = gluon.data.DataLoader  # an iterable class, not an iterator class\n",
    "\n",
    "train_data = loader(dataset=train_ds.transform_first(transform_train),\n",
    "                    batch_size=128,\n",
    "                    shuffle=True,\n",
    "                    last_batch='keep')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Network definition and initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from mxnet import nd\n",
    "from mxnet.gluon import nn\n",
    "\n",
    "\n",
    "class Residual(nn.HybridBlock):\n",
    "    def __init__(self, channels, same_shape=True, **kwargs):\n",
    "        super(Residual, self).__init__(**kwargs)\n",
    "        self.same_shape = same_shape\n",
    "        with self.name_scope():\n",
    "            strides = 1 if same_shape else 2\n",
    "            self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1,\n",
    "                                   strides=strides)\n",
    "            self.bn1 = nn.BatchNorm()\n",
    "            self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)\n",
    "            self.bn2 = nn.BatchNorm()\n",
    "            if not same_shape:\n",
    "                self.conv3 = nn.Conv2D(channels, kernel_size=1,\n",
    "                                       strides=strides)\n",
    "\n",
    "    def hybrid_forward(self, F, x):\n",
    "        \"\"\"\n",
    "        main path: conv(3x3 kernel, stride 1 or 2) -> bn -> relu -> conv(3x3 kernel, stride 1) -> bn\n",
    "        shortcut when the shape changes: conv(1x1 kernel, same stride)\n",
    "        \"\"\"\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        if not self.same_shape:\n",
    "            x = self.conv3(x)\n",
    "        return F.relu(out + x)\n",
    "\n",
    "\n",
    "class ResNet(nn.HybridBlock):\n",
    "    def __init__(self, num_classes, verbose=False, **kwargs):\n",
    "        super(ResNet, self).__init__(**kwargs)\n",
    "        self.verbose = verbose\n",
    "        with self.name_scope():\n",
    "            net = self.net = nn.HybridSequential()\n",
    "            # block 1\n",
    "            net.add(nn.Conv2D(channels=32, kernel_size=3, strides=1, padding=1))\n",
    "            net.add(nn.BatchNorm())\n",
    "            net.add(nn.Activation(activation='relu'))\n",
    "            # block 2\n",
    "            for _ in range(3):\n",
    "                net.add(Residual(channels=32))\n",
    "            # block 3\n",
    "            net.add(Residual(channels=64, same_shape=False))\n",
    "            for _ in range(2):\n",
    "                net.add(Residual(channels=64))\n",
    "            # block 4\n",
    "            net.add(Residual(channels=128, same_shape=False))\n",
    "            for _ in range(2):\n",
    "                net.add(Residual(channels=128))\n",
    "            # block 5: for 32x32 inputs the feature map is 8x8 here (two stride-2 blocks)\n",
    "            net.add(nn.AvgPool2D(pool_size=8))\n",
    "            net.add(nn.Flatten())\n",
    "            net.add(nn.Dense(num_classes))\n",
    "\n",
    "    def hybrid_forward(self, F, x):\n",
    "        out = x\n",
    "        for i, b in enumerate(self.net):\n",
    "            out = b(out)\n",
    "            if self.verbose:\n",
    "                print('Block %d output: %s' % (i+1, out.shape))\n",
    "        return out\n",
    "\n",
    "\n",
    "def get_net(ctx):\n",
    "    num_outputs = len(train_ds.synsets)  # one output per class sub-folder\n",
    "    net = ResNet(num_outputs)\n",
    "    net.initialize(ctx=ctx, init=mx.init.Xavier())\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting the pretrained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from mxnet.gluon.model_zoo import vision as models\n",
    "\n",
    "pretrained_net = models.resnet18_v2(pretrained=True)\n",
    "finetune_net = models.resnet18_v2(classes=len(train_ds.synsets))  # new output layer sized to our classes\n",
    "finetune_net.features = pretrained_net.features\n",
    "pretrained_net.features.collect_params().setattr('grad_req', 'null')  # freeze the feature-layer parameters\n",
    "finetune_net.output.initialize(init.Xavier())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on method setattr in module mxnet.gluon.parameter:\n",
      "\n",
      "setattr(name, value) method of mxnet.gluon.parameter.ParameterDict instance\n",
      "    Set an attribute to a new value for all Parameters.\n",
      "    \n",
      "    For example, set grad_req to null if you don't need gradient w.r.t a\n",
      "    model's Parameters::\n",
      "    \n",
      "        model.collect_params().setattr('grad_req', 'null')\n",
      "    \n",
      "    or change the learning rate multiplier::\n",
      "    \n",
      "        model.collect_params().setattr('lr_mult', 0.5)\n",
      "    \n",
      "    Parameters\n",
      "    ----------\n",
      "    name : str\n",
      "        Name of the attribute.\n",
      "    value : valid type for attribute name\n",
      "        The new value for the attribute.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "help(pretrained_net.features.collect_params().setattr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training procedure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import datetime\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "import gluonbook as gb\n",
    "\n",
    "def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):\n",
    "    trainer = gluon.Trainer(\n",
    "        net.collect_params(), 'sgd', {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})\n",
    "    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()\n",
    "    os.makedirs('./model', exist_ok=True)  # save_parameters below needs this folder\n",
    "\n",
    "    prev_time = datetime.datetime.now()\n",
    "    for epoch in range(num_epochs):\n",
    "        train_loss = 0.0\n",
    "        train_acc = 0.0\n",
    "        if epoch > 0 and epoch % lr_period == 0:\n",
    "            trainer.set_learning_rate(trainer.learning_rate * lr_decay)\n",
    "        for data, label in train_data:\n",
    "            label = label.astype('float32').as_in_context(ctx)\n",
    "            with autograd.record():\n",
    "                output = net(data.as_in_context(ctx))\n",
    "                loss = softmax_cross_entropy(output, label)\n",
    "            loss.backward()\n",
    "            # normalize by the actual batch size (the DataLoader uses 128 and the last batch may be smaller)\n",
    "            trainer.step(data.shape[0])\n",
    "            train_loss += nd.mean(loss).asscalar()\n",
    "            train_acc += gb.accuracy(output, label)\n",
    "            # print(gb.accuracy(output, label))\n",
    "        cur_time = datetime.datetime.now()\n",
    "        h, remainder = divmod((cur_time - prev_time).seconds, 3600)\n",
    "        m, s = divmod(remainder, 60)\n",
    "        time_str = \"Time %02d:%02d:%02d\" % (h, m, s)\n",
    "        if valid_data is not None:\n",
    "            valid_acc = gb.evaluate_accuracy(valid_data, net, ctx)\n",
    "            epoch_str = (\"Epoch %d. Loss: %f, Train acc %f, Valid acc %f, \"\n",
    "                         % (epoch, train_loss / len(train_data),\n",
    "                            train_acc / len(train_data), valid_acc))\n",
    "        else:\n",
    "            epoch_str = (\"Epoch %d. Loss: %f, Train acc %f, \"\n",
    "                         % (epoch, train_loss / len(train_data),\n",
    "                            train_acc / len(train_data)))\n",
    "        prev_time = cur_time\n",
    "        # net.export('./model/astro')\n",
    "        net.save_parameters('./model/astro_{}'.format(epoch))\n",
    "        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Loss: 0.817778, Train acc 0.618649, Time 00:02:06, lr 0.01\n",
      "Epoch 1. Loss: 0.460649, Train acc 0.778780, Time 00:02:07, lr 0.01\n",
      "Epoch 2. Loss: 0.468455, Train acc 0.784274, Time 00:02:06, lr 0.01\n",
      "Epoch 3. Loss: 0.344460, Train acc 0.848286, Time 00:02:05, lr 0.01\n",
      "Epoch 4. Loss: 0.279991, Train acc 0.886139, Time 00:02:05, lr 0.01\n"
     ]
    }
   ],
   "source": [
    "ctx = gb.try_gpu()\n",
    "num_epochs = 5\n",
    "learning_rate = 0.01\n",
    "weight_decay = 5e-4\n",
    "lr_period = 80\n",
    "lr_decay = 0.1\n",
    "\n",
    "\n",
    "finetune = False  # <----- whether to use the pretrained network\n",
    "if finetune:\n",
    "    finetune_net.collect_params().reset_ctx(ctx)\n",
    "    finetune_net.hybridize()\n",
    "    net = finetune_net\n",
    "else:\n",
    "    net = get_net(ctx)\n",
    "    net.hybridize()\n",
    "    # net.load_parameters('./model/astro_0')\n",
    "\n",
    "train(net, train_data, None, num_epochs, learning_rate,\n",
    "      weight_decay, ctx, lr_period, lr_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Transfer Learning (TransforLearning)
======================
[『TensorFlow』Transfer Learning](http://www.cnblogs.com/hellcat/p/6909269.html "My blog")
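## 1. Downloads
Download the data and the pretrained model:
```Shell
curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip
```
## 2. Project overview
This project takes the Inception V3 network pretrained on ImageNet, discards its final fully connected layer, and trains a new classifier on the flower data. Transfer learning is especially well suited to a small-to-medium dataset like this one.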
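Because the Inception V3 layers are not trained, each image's 2048-value bottleneck vector never changes. `TransforLearning.py` therefore computes it once per image and caches it under `./bottleneck` as comma-separated text (see `get_or_create_bottleneck`). The stdlib sketch below shows that get-or-create pattern. `get_or_create_features` and `fake_extractor` are hypothetical names, the extractor stands in for the Inception forward pass, and the cache file name is illustrative rather than the script's own naming scheme.

```python
import os
import tempfile


def get_or_create_features(cache_path, image_path, extract_features):
    """Return cached feature values for image_path, computing and caching them on a miss.

    extract_features is any callable mapping an image path to a sequence of floats.
    """
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            return [float(v) for v in f.read().split(",")]
    values = [float(v) for v in extract_features(image_path)]
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    with open(cache_path, "w") as f:
        f.write(",".join(str(v) for v in values))  # comma-separated text, as in the script
    return values


calls = []


def fake_extractor(path):  # hypothetical stand-in for the Inception V3 forward pass
    calls.append(path)
    return [0.5, 1.25, 3.0]


with tempfile.TemporaryDirectory() as cache_dir:
    cache = os.path.join(cache_dir, "daisy", "img1.jpg.txt")
    first = get_or_create_features(cache, "img1.jpg", fake_extractor)
    second = get_or_create_features(cache, "img1.jpg", fake_extractor)
    print(first == second, len(calls))  # True 1
```

The second call reads the file instead of running the extractor. This is why training on cached bottlenecks is much faster than pushing every image through Inception V3 at every step.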
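
### Project files
![](https://images2018.cnblogs.com/blog/1161096/201804/1161096-20180424094519006-1238870240.png "Project files")
- `inception_dec_2015`: the model folder. It is created when you extract the downloaded model archive.
- `flower_photos`: the data folder, containing one sub-folder per class. To use your own data, put each class's images in its own folder named after the class, and place those folders inside `flower_photos`.
- `TransforLearning.py`: the main training script. Training images must be JPEG files; the extensions `jpg`, `jpeg`, `JPG` and `JPEG` are accepted.
- `TransferLearning_reload.py`: the prediction script. It classifies images one at a time; edit the `image_path` list near the top of the file to point at your own images.
- All other folders are generated by the scripts and do not need to be created in advance.
- The folders `Keras_TransforLearning` and `MXnet_TransforLearning` contain demos of running a classification task quickly with Keras and MXNet, respectively. They use high-level APIs, so they are very quick to pick up and worth studying.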
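`TransforLearning.py` turns the `flower_photos` layout above into per-class file lists with a random validation/testing/training split (`creat_image_lists`). Here is a stdlib-only sketch of that step under stated assumptions. `split_image_lists` is a hypothetical name. It reads one directory level, sorts folder and file names, matches `.jpg`/`.jpeg` case-insensitively, and uses a seedable `random.Random` instead of `np.random.randint`. Each file independently goes to validation with probability `validation_percentage`%, to testing with `testing_percentage`%, and otherwise to training, so the split sizes only approximate those percentages.

```python
import os
import random

# Assumption: JPEG extensions, compared case-insensitively.
JPEG_EXTS = {".jpg", ".jpeg"}


def split_image_lists(root, validation_percentage, testing_percentage, seed=None):
    """Build {label: {'dir', 'training', 'testing', 'validation'}} from class folders in root."""
    rng = random.Random(seed)
    result = {}
    for dir_name in sorted(os.listdir(root)):
        sub_dir = os.path.join(root, dir_name)
        if not os.path.isdir(sub_dir):
            continue
        # a single os.listdir per folder, so each file is seen exactly once
        files = sorted(f for f in os.listdir(sub_dir)
                       if os.path.splitext(f)[1].lower() in JPEG_EXTS)
        if not files:
            continue  # folders without JPEGs are skipped
        entry = {"dir": dir_name, "training": [], "testing": [], "validation": []}
        for base_name in files:
            chance = rng.randrange(100)  # 0..99, like np.random.randint(100)
            if chance < validation_percentage:
                entry["validation"].append(base_name)
            elif chance < validation_percentage + testing_percentage:
                entry["testing"].append(base_name)
            else:
                entry["training"].append(base_name)
        result[dir_name.lower()] = entry  # label = lower-cased folder name
    return result
```

For example, `split_image_lists('./flower_photos', 10, 10, seed=0)` gives a reproducible split. Passing `0, 0` puts every image in `training`.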

### Running
First, train:
```Shell
python TransforLearning.py
```
When training has finished, predict one of your own images. You do not have to wait for training to end, as long as at least one intermediate checkpoint has already been saved to `./model/`.
```Shell
python TransferLearning_reload.py
```
The script then prints the prediction in the following format:
![](https://images2018.cnblogs.com/blog/1161096/201804/1161096-20180424094042927-662872256.png "Classification output")
33 | 第一行表示分类的类别,这里是根据图片文件夹的名字来的,可以看到和之前的项目文件示意图中`flower_photos`的子文件夹名称一一对应;第二行为分类结果,每一个值和第一行对应位置的类别相对应,比如这个结果就是分类为daisy的概率为0.23 34 | -------------------------------------------------------------------------------- /TransferLearning_reload.py: -------------------------------------------------------------------------------- 1 | # Author : hellcat 2 | # Time : 18-4-23 3 | 4 | """ 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"]="-1" 7 | 8 | import numpy as np 9 | np.set_printoptions(threshold=np.inf) 10 | """ 11 | 12 | import numpy as np 13 | import pprint as pp 14 | import tensorflow as tf 15 | from TransforLearning import creat_image_lists 16 | 17 | tf.logging.set_verbosity(tf.logging.INFO) 18 | 19 | config = tf.ConfigProto() 20 | config.gpu_options.allow_growth = True 21 | sess = tf.Session(config=config) 22 | 23 | image_path = ['/home/hellcat/PycharmProjects/CV&DL/迁移学习/flower_photos/daisy/5673551_01d1ea993e_n.jpg', 24 | '/home/hellcat/PycharmProjects/CV&DL/迁移学习/flower_photos/roses/99383371_37a5ac12a3_n.jpg'] 25 | 26 | # with open(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f: # 阅读器上下文 27 | # graph_def = tf.GraphDef() # 生成图 28 | # graph_def.ParseFromString(f.read()) # 图加载模型 29 | # # 加载图上节点张量(按照句柄理解) 30 | # bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def( # 从图上读取张量,同时导入默认图 31 | # graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME]) 32 | # 33 | # image_data = open(image_path, 'rb').read() 34 | # bottleneck = sess.run(bottleneck_tensor, feed_dict={jpeg_data_tensor: image_data}) 35 | 36 | ckpt = tf.train.get_checkpoint_state('./model/') 37 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta') 38 | saver.restore(sess, ckpt.model_checkpoint_path) 39 | # pp.pprint(tf.get_default_graph().get_operations()) 40 | g = tf.get_default_graph() 41 | for img in image_path: 42 | image_data = open(img, 'rb').read() 43 | bottleneck = sess.run(g.get_tensor_by_name('import/pool_3/_reshape:0'), 44 | 
feed_dict={g.get_tensor_by_name('import/DecodeJpeg/contents:0'): image_data}) 45 | 46 | class_result = sess.run(g.get_tensor_by_name('final_train_ops/Softmax:0'), 47 | feed_dict={g.get_tensor_by_name('BottleneckInputPlaceholder:0'): bottleneck}) 48 | 49 | images_lists = creat_image_lists(10, 10) 50 | tf.logging.info(images_lists.keys()) 51 | print(np.squeeze(class_result)) 52 | -------------------------------------------------------------------------------- /TransforLearning.py: -------------------------------------------------------------------------------- 1 | # Author : hellcat 2 | # Time : 18-4-23 3 | 4 | """ 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"]="-1" 7 | 8 | import numpy as np 9 | np.set_printoptions(threshold=np.inf) 10 | 11 | import tensorflow as tf 12 | config = tf.ConfigProto() 13 | config.gpu_options.allow_growth = True 14 | sess = tf.Session(config=config) 15 | """ 16 | 17 | import glob 18 | import os.path 19 | import random 20 | import numpy as np 21 | import tensorflow as tf 22 | config = tf.ConfigProto() 23 | config.gpu_options.allow_growth = True 24 | 25 | '''模型及样本路径设置''' 26 | 27 | BOTTLENECK_TENSOR_SIZE = 2048 # 瓶颈层节点个数 28 | BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0' # 瓶颈层输出张量名称 29 | JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0' # 输入层张量名称 30 | 31 | MODEL_DIR = './inception_dec_2015' # 模型存放文件夹 32 | MODEL_FILE = 'tensorflow_inception_graph.pb' # 模型名 33 | 34 | CACHE_DIR = './bottleneck' # 瓶颈输出中转文件夹 35 | INPUT_DATA = './flower_photos' # 数据文件夹 36 | 37 | VALIDATION_PERCENTAGE = 10 # 验证用数据百分比 38 | TEST_PERCENTAGE = 10 # 测试用数据百分比 39 | 40 | '''新添加神经网络部参数设置''' 41 | 42 | LEARNING_RATE = 0.01 43 | STEP = 5000 44 | BATCH = 100 45 | 46 | 47 | def creat_image_lists(validation_percentage, testing_percentage): 48 | ''' 49 | 将图片(无路径文件名)信息保存在字典中 50 | :param validation_percentage: 验证数据百分比 51 | :param testing_percentage: 测试数据百分比 52 | :return: 字典{标签:{文件夹:str,训练:[],验证:[],测试:[]},...} 53 | ''' 54 | result = {} 55 | sub_dirs = [x[0] for x in 
                os.walk(INPUT_DATA)]
    # os.walk() yields the root directory (INPUT_DATA) itself first, so skip it
    is_root_dir = True  # <-----
    # iterate over the label folders
    for sub_dir in sub_dirs:
        if is_root_dir:  # <-----
            is_root_dir = False
            continue

        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        file_list = []
        dir_name = os.path.basename(sub_dir)
        # iterate over the possible file extensions
        for extension in extensions:
            # file_glob = os.path.join(INPUT_DATA,dir_name,'*.'+extension)
            file_glob = os.path.join(sub_dir, '*.' + extension)
            file_list.extend(glob.glob(file_glob))  # match and collect paths & file names
            # print(file_glob,'\n',glob.glob(file_glob))
        # On Windows, glob matching is case-insensitive, so 'jpg' and 'JPG' match the
        # same files; drop the duplicates
        file_list = list(dict.fromkeys(file_list))
        if not file_list: continue

        label_name = dir_name.lower()  # the label is simply the lower-cased folder name

        # initialise the per-split file lists
        training_images = []
        testing_images = []
        validation_images = []

        # strip the directory, keep only the file name
        for file_name in file_list:
            base_name = os.path.basename(file_name)

            # randomly assign each image to validation, testing or training
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(base_name)
            elif chance < (validation_percentage + testing_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)
        # dictionary entry for this label
        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images
        }
    return result


def get_random_cached_bottlenecks(sess, n_class, image_lists, batch, category, jpeg_data_tensor, bottleneck_tensor):
    """
    Randomly sample one batch of images and return their bottleneck features as training data.
    :param sess: TensorFlow session
    :param n_class: number of classes
    :param image_lists: dictionary returned by creat_image_lists
    :param batch: number of samples to draw
    :param category: 'training' or 'validation'
    :param jpeg_data_tensor: input tensor of the Inception graph (raw JPEG bytes)
    :param bottleneck_tensor: bottleneck output tensor of the Inception graph
    :return: bottleneck outputs & one-hot labels
    """
    bottlenecks = []
    ground_truths = []
    for i in range(batch):
        label_index = random.randrange(n_class)  # pick a random label index
        label_name = list(image_lists.keys())[label_index]  # look up the label name
        image_index = random.randrange(65536)  # random image index, reduced modulo the list length in get_image_path
        # bottleneck features of a random image with this label
        bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, category,
            jpeg_data_tensor, bottleneck_tensor)
        ground_truth = np.zeros(n_class, dtype=np.float32)
        ground_truth[label_index] = 1.0  # one-hot target [0, 0, 1, 0, ...]
        # collect the bottleneck and the label
        bottlenecks.append(bottleneck)
        ground_truths.append(ground_truth)
    return bottlenecks, ground_truths


def get_or_create_bottleneck(
        sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor):
    """
    Look up the cached feature vector of an image; if it is not cached yet, compute it and save it to a file.
    :param sess: TensorFlow session
    :param image_lists: dictionary of all images
    :param label_name: current label
    :param index: image index
    :param category: 'training', 'validation' or 'testing'
    :param jpeg_data_tensor: input tensor of the Inception graph (raw JPEG bytes)
    :param bottleneck_tensor: bottleneck output tensor of the Inception graph
    :return: bottleneck feature vector
    """
    label_lists = image_lists[label_name]  # this label's entry: {'dir': str, 'training': [], 'testing': [], 'validation': []}
    sub_dir = label_lists['dir']  # folder name of this label
    sub_dir_path = os.path.join(CACHE_DIR, sub_dir)  # cache folder for this label
    if not os.path.exists(sub_dir_path):
        os.makedirs(sub_dir_path)
    bottleneck_path = get_bottleneck_path(image_lists, label_name, index, category)
    if not os.path.exists(bottleneck_path):
        image_path = get_image_path(image_lists, INPUT_DATA, label_name, index, category)
        # image_data = gfile.FastGFile(image_path,'rb').read()
        image_data = open(image_path, 'rb').read()
        # print(gfile.FastGFile(image_path,'rb').read()==open(image_path,'rb').read())
        # forward pass to get the bottleneck values
        bottleneck_values = run_bottleneck_on_images(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
        # convert to a comma-separated string so it can be written to a text file
        bottleneck_string = ','.join(str(x) for x in bottleneck_values)
        # print(bottleneck_values)
        # print(bottleneck_string)
        with open(bottleneck_path, 'w') as bottleneck_file:
            bottleneck_file.write(bottleneck_string)
    else:
        with open(bottleneck_path, 'r') as bottleneck_file:
            bottleneck_string = bottleneck_file.read()
        bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
    # Note: a NumPy array when freshly computed, a list when read from the cache
    return bottleneck_values


def run_bottleneck_on_images(sess, image_data, jpeg_data_tensor, bottleneck_tensor):
    """
    Run one image through the loaded, pre-trained Inception-v3 model to get its feature vector.
    :param sess: session handle
    :param image_data: raw bytes of the JPEG file
    :param jpeg_data_tensor: input tensor handle
    :param bottleneck_tensor: bottleneck tensor handle
    :return: bottleneck values
    """
    # print('input:',len(image_data))
    bottleneck_values = sess.run(bottleneck_tensor, feed_dict={jpeg_data_tensor: image_data})
    bottleneck_values = np.squeeze(bottleneck_values)  # drop the batch dimension of size 1
    # print('bottle:',len(bottleneck_values))
    return bottleneck_values


def get_bottleneck_path(image_lists, label_name, index, category):
    """
    Get the cache (feature map) path of an image, with a '.txt' suffix.
    :param image_lists: dictionary of all images
    :param label_name: label name
    :param index: random index
    :param category: 'training', 'validation' or 'testing'
    :return: cache (feature map) path, with a '.txt' suffix
    """
    return get_image_path(image_lists, CACHE_DIR, label_name, index, category) + '.txt'


def get_image_path(image_lists, image_dir, label_name, index, category):
    """
    Get the path of an image from its label name, dataset category and index. With
    image_dir=INPUT_DATA this is the original image; with CACHE_DIR it is the cache path without '.txt'.
    :param image_lists: dictionary of all images
    :param image_dir: outer folder (containing the label folders)
    :param label_name: label name
    :param index: random index
    :param category: 'training', 'validation' or 'testing'
    :return: path of the image (or of its cached features)
    """
    label_lists = image_lists[label_name]
    category_list = label_lists[category]  # image list of the requested category
    # reduce the random index to a valid position (ZeroDivisionError if the list is empty)
    mod_index = index % len(category_list)
    base_name = category_list[mod_index]  # file name at that position
    return os.path.join(image_dir, label_lists['dir'], base_name)


def get_test_bottlenecks(sess, image_lists, n_class, jpeg_data_tensor, bottleneck_tensor):
    """
    Compute the bottleneck outputs for all test images.
    :param sess: TensorFlow session
    :param image_lists: dictionary returned by creat_image_lists
    :param n_class: number of classes
    :param jpeg_data_tensor: input tensor of the Inception graph (raw JPEG bytes)
    :param bottleneck_tensor: bottleneck output tensor of the Inception graph
    :return: bottleneck outputs & one-hot labels
    """
    bottlenecks = []
    ground_truths = []
    label_name_list = list(image_lists.keys())
    for label_index, label_name in enumerate(label_name_list):
        category = 'testing'
        for index, unused_base_name in enumerate(image_lists[label_name][category]):  # index, file name
            bottleneck = get_or_create_bottleneck(
                sess, image_lists, label_name, index,
                category, jpeg_data_tensor, bottleneck_tensor)
            ground_truth = np.zeros(n_class, dtype=np.float32)
            ground_truth[label_index] = 1.0
            bottlenecks.append(bottleneck)
            ground_truths.append(ground_truth)
    return bottlenecks, ground_truths


def main():
    # build the file dictionary
    images_lists = creat_image_lists(VALIDATION_PERCENTAGE, TEST_PERCENTAGE)
    # number of labels (entries in the dictionary)
    n_class = len(images_lists.keys())

    # load the pre-trained model
    # with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),'rb') as f:  # file reader context
    with open(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:  # file reader context
        graph_def = tf.GraphDef()  # create an empty graph definition
        graph_def.ParseFromString(f.read())  # load the model into it
    # fetch tensors from the graph (think of them as handles); importing into the
    # default graph gives them the default 'import/' name prefix
    bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(
        graph_def,
        return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

    '''New layers on top of the bottleneck'''
    # input layer, fed with the output of the pre-trained model
    bottleneck_input = tf.placeholder(tf.float32, [None, BOTTLENECK_TENSOR_SIZE], name='BottleneckInputPlaceholder')
    ground_truth_input = tf.placeholder(tf.float32, [None, n_class], name='GroundTruthInput')
    # fully connected layer
    with tf.name_scope('final_train_ops'):
        Weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_class], stddev=0.001))
        biases = tf.Variable(tf.zeros([n_class]))
        logits = tf.matmul(bottleneck_input, Weights) + biases
        final_tensor = tf.nn.softmax(logits)  # op 'final_train_ops/Softmax', used by TransferLearning_reload.py
    # cross-entropy loss
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=ground_truth_input))
    # optimizer
    train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)

    # accuracy
    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(final_tensor, 1), tf.argmax(ground_truth_input, 1))
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    if not os.path.exists("./model/"):
        os.makedirs("./model/")
    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        sess.run(init)
        for i in range(STEP):
            # bottleneck outputs & labels of a random training batch
            train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks(
                sess, n_class, images_lists, BATCH, 'training', jpeg_data_tensor, bottleneck_tensor)
            sess.run(train_step,
                     feed_dict={bottleneck_input: train_bottlenecks, ground_truth_input: train_ground_truth})

            # evaluate on a random validation batch every 100 steps (and at the last step)
            if i % 100 == 0 or i + 1 == STEP:
                validation_bottlenecks, validation_ground_truth = get_random_cached_bottlenecks(
                    sess, n_class, images_lists, BATCH, 'validation', jpeg_data_tensor, bottleneck_tensor)
                validation_accuracy = sess.run(evaluation_step, feed_dict={
                    bottleneck_input: validation_bottlenecks, ground_truth_input: validation_ground_truth})
                print('Step %d: Validation accuracy on random sampled %d examples = %.1f%%' %
                      (i, BATCH, validation_accuracy * 100))
            # save a checkpoint every 500 steps (and at the last step)
            if i % 500 == 0 or i + 1 == STEP:
                saver.save(sess, "./model/TransferLearning.model", global_step=i)

        test_bottlenecks, test_ground_truth = get_test_bottlenecks(
            sess, images_lists, n_class, jpeg_data_tensor, bottleneck_tensor)
        test_accuracy = sess.run(evaluation_step, feed_dict={
            bottleneck_input: test_bottlenecks, ground_truth_input: test_ground_truth})
        print('Final test accuracy = %.1f%%' % (test_accuracy * 100))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/flower_photos/daisy/43474673_7bb4465a86.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hellcatzm/TransforLearning_TensorFlow/9f943af71efa4f6a0007eefe287bbf690d8a7b57/flower_photos/daisy/43474673_7bb4465a86.jpg
--------------------------------------------------------------------------------
/flower_photos/roses/12240303_80d87f77a3_n.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hellcatzm/TransforLearning_TensorFlow/9f943af71efa4f6a0007eefe287bbf690d8a7b57/flower_photos/roses/12240303_80d87f77a3_n.jpg
--------------------------------------------------------------------------------
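The root README explains that `TransferLearning_reload.py` prints the class names on one line and the softmax probabilities on the next, matched by position. Below is a minimal sketch of reading that output programmatically. It makes three assumptions:

- The class names are `list(images_lists.keys())`.
- The probabilities are `np.squeeze(class_result)` for a single image.
- `creat_image_lists` returns the labels in the same order as during training, because the training script assigns label indices from that order.

The helper `decode_prediction` and the sample values are hypothetical and not part of the repository.

```python
import numpy as np


def decode_prediction(label_names, probabilities):
    """Pair class names with softmax probabilities, sorted from most to least likely.

    label_names:   class names in training order (assumed: list(images_lists.keys())).
    probabilities: softmax output for one image (assumed: np.squeeze(class_result)).
    """
    probs = np.asarray(probabilities, dtype=np.float64).ravel()
    if len(label_names) != probs.size:
        raise ValueError("got %d labels but %d probabilities" % (len(label_names), probs.size))
    order = np.argsort(probs)[::-1]  # indices of the probabilities, largest first
    return [(label_names[i], float(probs[i])) for i in order]


# Hypothetical labels and probabilities, for illustration only.
labels = ['daisy', 'roses', 'tulips']
probs = np.array([0.23, 0.70, 0.07])
ranking = decode_prediction(labels, probs)
print(ranking[0])  # ('roses', 0.7)
```

Sorting is only for readability. `np.argmax(probs)` alone gives the index of the predicted class in the label list.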