├── Data_pre-processing.ipynb ├── Feature_extractor.ipynb ├── Lable_information_vector_definition.ipynb ├── Model.ipynb └── README.md /Data_pre-processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "import pywt\n", 12 | "import numpy as np\n", 13 | "import scipy.signal.spectral as sss\n", 14 | "import tensorflow as tf\n", 15 | "from PIL import Image\n", 16 | "# import seaborn as sns\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "\n", 19 | "\n", 20 | "def batch2image(signal_batch, sampling_rate, pooling_size, process_type='wavelet'):\n", 21 | " \"\"\"\n", 22 | " 批数据转换成图片\n", 23 | " :signal_batch:\n", 24 | " :sampling_rate:\n", 25 | " :pooling_size:\n", 26 | " :process_type:\n", 27 | " :return:\n", 28 | " \"\"\"\n", 29 | " num, dim = np.shape(signal_batch)\n", 30 | " batch_set = []\n", 31 | " if process_type == 'wavelet':\n", 32 | " for i in range(num):\n", 33 | " image, _ = wavelet2image(signal_batch[i, :], sampling_rate)\n", 34 | " batch_set.append(image)\n", 35 | " elif process_type == 'stft':\n", 36 | " for i in range(num):\n", 37 | " image, _ = stft2image(signal_batch[i, :], sampling_rate)\n", 38 | " batch_set.append(image)\n", 39 | " else:\n", 40 | " raise KeyError(\"process_type must be wavelet of stft!\")\n", 41 | " batch_set = np.array(batch_set)\n", 42 | " batch_set = image_downsampling(batch_set, pooling_size, form='avg_pooling')\n", 43 | "\n", 44 | " return batch_set\n", 45 | "\n", 46 | "def wavelet2image(signal, sampling_rate, freq_dim_scale=256, wavelet_name='morl'):\n", 47 | "\n", 48 | " \"\"\"\n", 49 | " 小波图像\n", 50 | " :param signal: 1D temporal sequence\n", 51 | " :param sampling_rate: sampling rate for the sequence 定义了每秒从连续信号中提取并组成离散信号的采样个数\n", 52 | " :param freq_dim_scale: frequency resolution 目的是避免信号混淆保证高频信号不被歪曲成低频信号\n", 53 | " :param wavelet_name: wavelet name for CWT, here we have 'morl', 'gaus', 'cmor',...\n", 54 | " :return: time-freq image and its reciprocal frequencies 时频图像及其倒数频率\n", 55 | " \"\"\"\n", 56 | "\n", 57 | " freq_centre = pywt.central_frequency(wavelet_name) # 所选小波的中心频率\n", 58 | " cparam = 2 * freq_centre * freq_dim_scale\n", 59 | " scales = cparam / np.arange(1, freq_dim_scale + 1, 1) # 获取小波基函数的尺度参数 a 的倒数\n", 60 | " [cwt_matrix, frequencies] = pywt.cwt(signal, scales, wavelet_name, 1.0 / sampling_rate)\n", 61 | "\n", 62 | " return abs(cwt_matrix), frequencies\n", 63 | "\n", 64 | "def stft2image(signal, sampling_rate, freq_dim_scale=256, window_name=('gaussian', 3.0)):\n", 65 | "\n", 66 | " \"\"\"\n", 67 | " :param signal: signal input for stft\n", 68 | " :param sampling_rate:\n", 69 | " :param window_name: (gaussian,3), hann, hamming, etc.\n", 70 | "\n", 71 | " Notes\n", 72 | " -----\n", 73 | " Window types:\n", 74 | "\n", 75 | " `boxcar`, `triang`, `blackman`, `hamming`, `hann`, `bartlett`,\n", 76 | " `flattop`, `parzen`, `bohman`, `blackmanharris`, `nuttall`,\n", 77 | " `barthann`, `kaiser` (needs beta), `gaussian` (needs standard\n", 78 | " deviation), `general_gaussian` (needs power, width), `slepian`\n", 79 | " (needs width), `dpss` (needs normalized half-bandwidth),\n", 80 | " `chebwin` (needs attenuation), `exponential` (needs decay scale),\n", 81 | " `tukey` (needs taper fraction)\n", 82 | "\n", 83 | " :return: time-freq image and its frequencies\n", 84 | " \"\"\"\n", 85 | "\n", 86 | " f, t, Zxx = sss.stft(signal, 
fs=sampling_rate, window=window_name, nperseg=freq_dim_scale)\n", 87 | "\n", 88 | " return Zxx, f\n", 89 | "\n", 90 | "def image_downsampling(image_set, pooling_size=2, form='max_pooling', axis=None):\n", 91 | "\n", 92 | " \"\"\"\n", 93 | " :param image_set: input image with large size\n", 94 | " :param pooling_size: down-sampling rate\n", 95 | " :param form: 'max_pooling' or 'avg_pooling'\n", 96 | " :param axis: if axis is not None, it means that the image will be down-sampled\n", 97 | " just within it row(axis=0) or column(axis=1).\n", 98 | " :return: image has been down-sampled\n", 99 | " \"\"\"\n", 100 | "\n", 101 | " num, time_dim, freq_dim = np.shape(image_set)[0], np.shape(image_set)[1], np.shape(image_set)[2]\n", 102 | " image_set = image_set.reshape(num, time_dim, freq_dim, 1)\n", 103 | " im_input = tf.placeholder(dtype=tf.float32, shape=[num, time_dim, freq_dim, 1])\n", 104 | " kernel_size = [pooling_size, 2*pooling_size]\n", 105 | " if axis == 0:\n", 106 | " kernel_size = [pooling_size, 1]\n", 107 | " elif axis == 1:\n", 108 | " kernel_size = [1, pooling_size]\n", 109 | "\n", 110 | " with tf.device('/cpu:0'):\n", 111 | " pooling_max = tf.contrib.slim.max_pool2d(im_input, kernel_size=kernel_size, stride=kernel_size)\n", 112 | " pooling_avg = tf.contrib.slim.avg_pool2d(im_input, kernel_size=kernel_size, stride=kernel_size)\n", 113 | "\n", 114 | " with tf.Session() as sess:\n", 115 | " down_sampling_im = sess.run(fetches=pooling_max, feed_dict={im_input: image_set})\n", 116 | " if form == 'avg_pooling':\n", 117 | " down_sampling_im = sess.run(fetches=pooling_avg, feed_dict={im_input: image_set})\n", 118 | "\n", 119 | " return down_sampling_im\n", 120 | "\n", 121 | "def get_batch(filename, window_size=512, batch_size=1000, stride=180):\n", 122 | " data = np.loadtxt(filename)\n", 123 | " print(data.shape)\n", 124 | " start = 0\n", 125 | " cnt = 0\n", 126 | " batch_data = []\n", 127 | " while start + window_size < data.shape[0] and cnt < batch_size:\n", 128 | " batch_data.append(data[start: start + window_size])\n", 129 | " start = start + stride + 1\n", 130 | " cnt += 1\n", 131 | " batch_data = np.array(batch_data)\n", 132 | " return batch_data\n", 133 | "\n", 134 | "for root, dirs, files in os.walk(\"./datas/origin\"):\n", 135 | " print(files)\n", 136 | "\n", 137 | "output_path = './datas/image'\n", 138 | "if not os.path.exists(output_path):\n", 139 | " os.makedirs(output_path)\n", 140 | "\n", 141 | " \n", 142 | "\n", 143 | "for i, file in enumerate(files):\n", 144 | " #print(1)\n", 145 | " print(\"processing %s\" % file)\n", 146 | " c = file.split('_')[0] if \"normal\" in file else file.split('_')[3]\n", 147 | " #label = \"{}_{}\".format(c, i) # 师兄写的\n", 148 | " label = c.split('.')[0] # 我写的\n", 149 | " print(label)\n", 150 | " file_path = os.path.join(root, file)\n", 151 | " print(file_path)\n", 152 | " signal = get_batch(filename=file_path, batch_size=2000)\n", 153 | " batch_image = batch2image(signal, sampling_rate=1, pooling_size=4)\n", 154 | " print(\"saving %s/%s.npy, shape %s\" % (output_path, label, batch_image.shape))\n", 155 | " np.save(\"%s/%s.npy\" % (output_path, label), batch_image)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "import numpy as np\n", 165 | "import keras\n", 166 | "import matplotlib.pyplot as plt\n", 167 | "%matplotlib inline\n", 168 | "from keras.utils import plot_model\n", 169 | "from keras.utils import np_utils" 170 | ] 171 | }, 172 | { 173 | 
"cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "def creat_dataset1(select):\n", 179 | " x_all = []\n", 180 | " y_all = []\n", 181 | " i=0\n", 182 | " for elem in select:\n", 183 | " print(elem[0])\n", 184 | "# sig = df[elem[0]]\n", 185 | " # print(len(sig))\n", 186 | " if elem[1] == 0:\n", 187 | " sig =np.load(\"./datas/image/IF.npy\")\n", 188 | " elif elem[1] == 1:\n", 189 | " sig =np.load(\"./datas/image/OF.npy\")\n", 190 | " else:\n", 191 | " sig =np.load(\"./datas/image/BF.npy\")\n", 192 | " # i = i+1 \n", 193 | " label = elem[1]\n", 194 | " x = sig\n", 195 | " print(x.shape)\n", 196 | " # print(x.shape[0])\n", 197 | " x_all.append(x)\n", 198 | " y = [[label] for _ in range(x.shape[0])]\n", 199 | " y_all.append(y)\n", 200 | " x_merge = np.vstack(x_all) # 在竖直方向上堆叠\n", 201 | " y_merge = np.vstack(y_all)\n", 202 | " return x_merge, y_merge\n", 203 | "def creat_dataset2(select):\n", 204 | " x_all = []\n", 205 | " y_all = []\n", 206 | " i=0\n", 207 | " for elem in select:\n", 208 | " print(elem[0])\n", 209 | "# sig = df[elem[0]]\n", 210 | " # print(len(sig))\n", 211 | " if elem[1] == 0:\n", 212 | " sig =np.load(\"./datas/image/IO.npy\")\n", 213 | " elif elem[1] == 1:\n", 214 | " sig =np.load(\"./datas/image/IB.npy\")\n", 215 | " \n", 216 | " elif elem[1] == 2:\n", 217 | " sig =np.load(\"./datas/image/OB.npy\")\n", 218 | " else:\n", 219 | " sig =np.load(\"./datas/image/IOB.npy\")\n", 220 | " # i = i+1 \n", 221 | " label = elem[1]\n", 222 | " x = sig\n", 223 | " print(x.shape)\n", 224 | " # print(x.shape[0])\n", 225 | " x_all.append(x)\n", 226 | " y = [[label] for _ in range(x.shape[0])]\n", 227 | " y_all.append(y)\n", 228 | " x_merge = np.vstack(x_all)\n", 229 | " y_merge = np.vstack(y_all)\n", 230 | " return x_merge, y_merge" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "data_train = [('Inner',0), ('Outter', 1), ('Ball', 2)]\n", 240 | "data_test = [('IO',0), ('IB', 1), ('OB', 2),('IOB', 3)]\n", 241 | "\n", 242 | "x_train, y_train = creat_dataset1(select=data_train)\n", 243 | "x_test, y_test = creat_dataset2(select=data_test)\n", 244 | "\n", 245 | "np.save('./data/train/x_trian', x_train)\n", 246 | "np.save('./data/train/y_train', y_train)\n", 247 | "np.save('./data/test/x_test', x_test)\n", 248 | "np.save('./data/test/y_test', y_test)" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "idx1 = np.random.randint(0, 2000,1)\n", 258 | "idx2 = np.random.randint(2000, 4000,1)\n", 259 | "idx3 = np.random.randint(4000, 6000,1)\n", 260 | "idx4 = np.random.randint(6000, 8000,1)\n", 261 | "\n", 262 | "x_train =np.load(\"./datas/train/x_train.npy\")\n", 263 | "x_test =np.load(\"./datas/test/x_test.npy\")\n", 264 | "\n", 265 | "IF = x_train[idx1]\n", 266 | "OF = x_train[idx2]\n", 267 | "BF = x_train[idx3]\n", 268 | "\n", 269 | "IO = x_test[idx1]\n", 270 | "IB = x_test[idx2]\n", 271 | "OB = x_test[idx3]\n", 272 | "IOB = x_test[idx4]\n", 273 | "#print(IO.shape)\n", 274 | "IF = IF.reshape(64,64) \n", 275 | "OF = OF.reshape(64,64) \n", 276 | "BF = BF.reshape(64,64) \n", 277 | "\n", 278 | "IO = IO.reshape(64,64) \n", 279 | "IB = IB.reshape(64,64) \n", 280 | "OB = OB.reshape(64,64) \n", 281 | "IOB = IOB.reshape(64,64) " 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | 
"outputs": [], 289 | "source": [ 290 | "import matplotlib.pyplot as pyplot \n", 291 | "#import cv2 as cv\n", 292 | "# cv.imshow('GrayImage', IO)\n", 293 | "\n", 294 | "# GDATA = rgb2gray(IO)\n", 295 | "# pyplot.imshow(GDATA)\n", 296 | "plt.imshow(IF, cmap = plt.get_cmap('gray'))" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "import matplotlib.pyplot as pyplot \n", 306 | "#import cv2 as cv\n", 307 | "# cv.imshow('GrayImage', IO)\n", 308 | "\n", 309 | "# GDATA = rgb2gray(IO)\n", 310 | "# pyplot.imshow(GDATA)\n", 311 | "plt.imshow(OF, cmap = plt.get_cmap('gray'))" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "plt.imshow(BF, cmap = plt.get_cmap('gray'))" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "plt.imshow(IO, cmap = plt.get_cmap('gray'))" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "plt.imshow(IB, cmap = plt.get_cmap('gray'))" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "plt.imshow(OB, cmap = plt.get_cmap('gray'))" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "plt.imshow(IOB, cmap = plt.get_cmap('gray'))" 357 | ] 358 | } 359 | ], 360 | "metadata": { 361 | "kernelspec": { 362 | "display_name": "Python 3", 363 | "language": "python", 364 | "name": "python3" 365 | }, 366 | "language_info": { 367 | "codemirror_mode": { 368 | "name": "ipython", 369 | "version": 3 370 | }, 371 | "file_extension": ".py", 372 | "mimetype": "text/x-python", 373 | "name": "python", 374 | "nbconvert_exporter": "python", 375 | "pygments_lexer": "ipython3", 376 | "version": "3.6.13" 377 | } 378 | }, 379 | "nbformat": 4, 380 | "nbformat_minor": 4 381 | } 382 | -------------------------------------------------------------------------------- /Feature_extractor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "senior-spring", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import tensorflow as tf\n", 11 | "import numpy as np, h5py\n", 12 | "#import scipy.io as sio\n", 13 | "import sys\n", 14 | "import random\n", 15 | "#import kNN_cosine\n", 16 | "import re\n", 17 | "from numpy import * \n", 18 | "import numpy as np\n", 19 | "from keras.datasets import mnist\n", 20 | "from keras.utils import np_utils\n", 21 | "from keras.models import Sequential\n", 22 | "from keras.layers import Dense,Dropout,Convolution2D,MaxPooling2D,Flatten\n", 23 | "from keras.optimizers import Adam\n", 24 | "from keras.models import load_model,Model\n", 25 | "\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from sklearn import manifold\n", 28 | "from sklearn.cluster import KMeans\n", 29 | "from numpy import *\n", 30 | "import operator" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "printable-wisdom", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# 1. 
加载数据\n", 41 | "# x_train_feature = np.load('./datas/train/x_train.npy',allow_pickle=True)\n", 42 | "x_train = np.load('./datas/train/x_train.npy').astype('float32')/255.\n", 43 | "y_train = np.load('./datas/train/y_train.npy')\n", 44 | "\n", 45 | "x_test = np.load('./datas/test/x_test.npy').astype('float32')/255.\n", 46 | "y_test = np.load('./datas/test/y_test.npy')\n", 47 | "\n", 48 | "y_tra = np_utils.to_categorical(y_train,num_classes=3)\n", 49 | "y_tes = np_utils.to_categorical(y_test,num_classes=4)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "jewish-banner", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# 定义顺序模型\n", 60 | "model = Sequential()\n", 61 | "\n", 62 | "# 第一个卷积层\n", 63 | "# input_shape 输入平面\n", 64 | "# filters 卷积核/滤波器个数\n", 65 | "# kernel_size 卷积窗口大小\n", 66 | "# strides 步长\n", 67 | "# padding padding方式 same/valid\n", 68 | "# activation 激活函数\n", 69 | "model.add(Convolution2D(\n", 70 | " input_shape = (64,64,1),\n", 71 | " filters = 32,\n", 72 | " kernel_size = 5,\n", 73 | " strides = 1,\n", 74 | " padding = 'same',\n", 75 | " activation = 'relu',\n", 76 | "))\n", 77 | "# 第一个池化层\n", 78 | "model.add(MaxPooling2D(\n", 79 | " pool_size = 2,\n", 80 | " strides = 2,\n", 81 | " padding = 'same',\n", 82 | "))\n", 83 | "# 第二个卷积层\n", 84 | "model.add(Convolution2D(64,5,strides=1,padding='same',activation='relu'))\n", 85 | "# 第二个池化层\n", 86 | "model.add(MaxPooling2D(2,2,'same'))\n", 87 | "# 把第二个池化层的输出扁平化为1维\n", 88 | "model.add(Flatten())\n", 89 | "# 第一个全连接层\n", 90 | "model.add(Dense(4096,activation='relu'))\n", 91 | "model.add(Dense(2048,activation='relu'))\n", 92 | "model.add(Dense(2048,activation='relu'))\n", 93 | "\n", 94 | "\n", 95 | "\n", 96 | "# Dropout\n", 97 | "model.add(Dropout(0.25))\n", 98 | "# 第二个全连接层\n", 99 | "model.add(Dense(3,activation='softmax'))\n", 100 | "\n", 101 | "# 定义优化器\n", 102 | "adam = Adam(lr=1e-4)\n", 103 | "\n", 104 | "# 定义优化器,loss function,训练过程中计算准确率\n", 105 | "model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])\n", 106 | "\n", 107 | "# 训练模型\n", 108 | "model.fit(x_train,y_tra,batch_size=64,epochs=25)\n", 109 | "\n", 110 | "# 评估模型\n", 111 | "loss,accuracy = model.evaluate(x_train,y_tra)\n", 112 | "\n", 113 | "\n", 114 | "#保存模型\n", 115 | "model.save('cnn_extractor.h5')\n", 116 | "print(model.summary())\n", 117 | "print('train loss',loss)\n", 118 | "print('train accuracy',accuracy)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "reflected-halloween", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "#x_train = np.load('./data/train/x_train.npy')\n", 129 | "# x_train = x_train.reshape(-1,64,64,1)/255.0\n", 130 | "\n", 131 | "#载入模型\n", 132 | "model = load_model('cnn_extractor.h5')\n", 133 | "resnet_model = Model(inputs=model.input, outputs=model.get_layer('dense_3').output)\n", 134 | "intermediate_output1 = resnet_model.predict(x_train)\n", 135 | "print(intermediate_output1.shape)\n", 136 | "#print(intermediate_output1)\n", 137 | "np.save('./datas/features/feature_train', intermediate_output1 )\n", 138 | "\n", 139 | "\n", 140 | "# x_test = np.load('./data/test/x_test.npy')\n", 141 | "# x_test = x_test.reshape(-1,64,64,1)/255.0\n", 142 | "#载入模型\n", 143 | "model = load_model('cnn_extractor.h5')\n", 144 | "resnet_model = Model(inputs=model.input, outputs=model.get_layer('dense_3').output)\n", 145 | "intermediate_output2 = resnet_model.predict(x_test)\n", 146 | "print(intermediate_output2.shape)\n", 147 | 
"np.save('./datas/features/feature_test', intermediate_output2 )" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "transparent-cooperative", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [] 157 | } 158 | ], 159 | "metadata": { 160 | "kernelspec": { 161 | "display_name": "Python 3", 162 | "language": "python", 163 | "name": "python3" 164 | }, 165 | "language_info": { 166 | "codemirror_mode": { 167 | "name": "ipython", 168 | "version": 3 169 | }, 170 | "file_extension": ".py", 171 | "mimetype": "text/x-python", 172 | "name": "python", 173 | "nbconvert_exporter": "python", 174 | "pygments_lexer": "ipython3", 175 | "version": "3.6.13" 176 | } 177 | }, 178 | "nbformat": 4, 179 | "nbformat_minor": 5 180 | } 181 | -------------------------------------------------------------------------------- /Lable_information_vector_definition.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "violent-charleston", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "printable-binding", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "data_Normal = np.loadtxt(\"./datas/origin/fault_data_0HP_IF.txt\") \n", 22 | "sig1 = data_Normal[1000:1500]\n", 23 | "plt.figure(figsize=(10, 3))\n", 24 | "plt.plot(sig1, color='blue')\n", 25 | "\n", 26 | "data_Norma2 = np.loadtxt(\"./datas/origin/fault_data_0HP_OF.txt\") \n", 27 | "sig2 = data_Norma2[820:1320]\n", 28 | "plt.figure(figsize=(10, 3))\n", 29 | "plt.plot(sig2, color='blue')\n", 30 | "\n", 31 | "data_Norma3 = np.loadtxt(\"./datas/origin/fault_data_0HP_BF.txt\") \n", 32 | "sig3 = data_Norma3[700:1200]\n", 33 | "plt.figure(figsize=(10, 3))\n", 34 | "plt.plot(sig3, color='blue')\n", 35 | "\n", 36 | "data_Norma4 = np.loadtxt(\"./datas/origin/fault_data_0HP_IO.txt\") \n", 37 | "sig4 = data_Norma4[800:1300]\n", 38 | "plt.figure(figsize=(10, 3))\n", 39 | "plt.plot(sig4, color='lightblue')\n", 40 | "\n", 41 | "data_Norma5 = np.loadtxt(\"./datas/origin/fault_data_0HP_IB.txt\") \n", 42 | "sig5 = data_Norma5[600:1100]\n", 43 | "plt.figure(figsize=(10, 3))\n", 44 | "plt.plot(sig5, color='lightblue')\n", 45 | "\n", 46 | "data_Norma6 = np.loadtxt(\"./datas/origin/fault_data_0HP_OB.txt\") \n", 47 | "sig6 = data_Norma6[700:1200]\n", 48 | "plt.figure(figsize=(10, 3))\n", 49 | "plt.plot(sig6, color='lightblue')\n", 50 | "\n", 51 | "data_Norma7 = np.loadtxt(\"./datas/origin/fault_data_0HP_IOB.txt\") \n", 52 | "sig7 = data_Norma7[700:1200]\n", 53 | "plt.figure(figsize=(10, 3))\n", 54 | "plt.plot(sig7, color='lightblue')" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "aquatic-equilibrium", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "#构建故障类别的语义向量\n", 65 | "t =5\n", 66 | "s1=[]\n", 67 | "s2=[]\n", 68 | "s3=[]\n", 69 | "s4=[]\n", 70 | "s5=[]\n", 71 | "s6=[]\n", 72 | "s7=[]\n", 73 | "threshold1 = np.max(sig1)/t\n", 74 | "threshold2 = np.max(sig2)/t\n", 75 | "threshold3 = np.max(sig3)/t\n", 76 | "threshold = 0\n", 77 | "if threshold1 > threshold2:\n", 78 | " threshold = threshold1\n", 79 | " if threshold1 > threshold3:\n", 80 | " threshold = threshold1\n", 81 | " else:\n", 82 | " threshold = threshold3\n", 83 | "else:\n", 84 | " threshold = threshold2\n", 85 | " if 
threshold2>threshold3:\n", 86 | " threshold = threshold2\n", 87 | " else:\n", 88 | " threshold = threshold3\n", 89 | "for i in range(500):\n", 90 | " #内圈\n", 91 | " if sig1[i]>=threshold and sig1[i]<2*threshold:\n", 92 | " s1.append(1)\n", 93 | " elif sig1[i]>=2*threshold and sig1[i]<3*threshold:\n", 94 | " s1.append(2)\n", 95 | " elif sig1[i]>=3*threshold and sig1[i]<4*threshold:\n", 96 | " s1.append(3)\n", 97 | " elif sig1[i]>=4*threshold and sig1[i]<=5*threshold:\n", 98 | " s1.append(4)\n", 99 | " else:\n", 100 | " s1.append(0) \n", 101 | " #外圈\n", 102 | "\n", 103 | " if sig2[i]>=threshold and sig2[i]<2*threshold:\n", 104 | " s2.append(1)\n", 105 | " elif sig2[i]>=2*threshold and sig2[i]<3*threshold:\n", 106 | " s2.append(2)\n", 107 | " elif sig2[i]>=3*threshold and sig2[i]<4*threshold:\n", 108 | " s2.append(3)\n", 109 | " elif sig2[i]>=4*threshold and sig2[i]<=5*threshold:\n", 110 | " s2.append(4)\n", 111 | " else:\n", 112 | " s2.append(0) \n", 113 | " #滚子 \n", 114 | " if sig3[i]>=threshold and sig3[i]<2*threshold:\n", 115 | " s3.append(1)\n", 116 | " elif sig3[i]>=2*threshold and sig3[i]<3*threshold:\n", 117 | " s3.append(2)\n", 118 | " elif sig3[i]>=3*threshold and sig3[i]<4*threshold:\n", 119 | " s3.append(3)\n", 120 | " elif sig3[i]>=4*threshold and sig3[i]<5*threshold:\n", 121 | " s3.append(4)\n", 122 | " else:\n", 123 | " s3.append(0)\n", 124 | "\n", 125 | "\n", 126 | "#内外圈 \n", 127 | "for i in range(500):\n", 128 | " if s1[i] == 4 or s2[i] == 4:\n", 129 | " s4.append(4)\n", 130 | " elif s1[i] == 3 or s2[i] == 3:\n", 131 | " s4.append(3)\n", 132 | " elif s1[i] == 2 or s2[i] == 2:\n", 133 | " s4.append(2)\n", 134 | " elif s1[i] == 1 or s2[i] == 1:\n", 135 | " s4.append(1)\n", 136 | " \n", 137 | " else:\n", 138 | " s4.append(0)\n", 139 | "\n", 140 | "\n", 141 | " \n", 142 | " \n", 143 | "#内圈滚子 \n", 144 | "for i in range(500):\n", 145 | " if s1[i] == 4 or s3[i] == 4:\n", 146 | " s5.append(4)\n", 147 | " elif s1[i] == 3 or s3[i] == 3:\n", 148 | " s5.append(3)\n", 149 | " elif s1[i] == 2 or s3[i] == 2:\n", 150 | " s5.append(2)\n", 151 | " elif s1[i] == 1 or s3[i] == 1:\n", 152 | " s5.append(1)\n", 153 | " else:\n", 154 | " s5.append(0)\n", 155 | "\n", 156 | "\n", 157 | "\n", 158 | " \n", 159 | "#外圈滚子 \n", 160 | "for i in range(500):\n", 161 | " if s2[i] == 4 or s3[i] == 4:\n", 162 | " s6.append(4)\n", 163 | " elif s2[i] == 3 or s3[i] == 3:\n", 164 | " s6.append(3)\n", 165 | " elif s2[i] == 2 or s3[i] == 2:\n", 166 | " s6.append(2)\n", 167 | " elif s2[i] == 1 or s3[i] == 1:\n", 168 | " s6.append(1)\n", 169 | " else :\n", 170 | " s6.append(0)\n", 171 | "\n", 172 | "\n", 173 | " \n", 174 | " \n", 175 | "#内外圈滚子 \n", 176 | "for i in range(500):\n", 177 | " if s2[i] == 4 or s3[i] == 4 or s1[i] == 4:\n", 178 | " s7.append(4)\n", 179 | " elif s2[i] == 3 or s3[i] == 3 or s1[i] == 3 :\n", 180 | " s7.append(3)\n", 181 | " elif s2[i] == 2 or s3[i] == 2 or s1[i] == 2 :\n", 182 | " s7.append(2)\n", 183 | " elif s2[i] == 1 or s3[i] == 1 or s1[i] == 1 :\n", 184 | " s7.append(1)\n", 185 | " else:\n", 186 | " s7.append(0)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "polished-unemployment", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "semantic_all = [s1,s2,s3,s4,s5,s6,s7]\n", 197 | "semantic_train = []\n", 198 | "semantic_test = []\n", 199 | "\n", 200 | "for i in range(2000):\n", 201 | " semantic_train.append(s1)\n", 202 | "for i in range(2000):\n", 203 | " semantic_train.append(s2)\n", 204 | "for i in range(2000):\n", 205 | " 
semantic_train.append(s3)\n", 206 | "semantic_train = np.vstack(semantic_train)\n", 207 | "\n", 208 | "for i in range(2000):\n", 209 | " semantic_test.append(s4)\n", 210 | "for i in range(2000):\n", 211 | " semantic_test.append(s5)\n", 212 | "for i in range(2000):\n", 213 | " semantic_test.append(s6)\n", 214 | "for i in range(2000):\n", 215 | " semantic_test.append(s7)\n", 216 | "semantic_test = np.vstack(semantic_test)\n", 217 | "\n", 218 | "# print(np.array(semantic_all).shape)\n", 219 | "# print(np.array(semantic_train).shape)\n", 220 | "# print(np.array(semantic_test).shape)\n", 221 | "\n", 222 | "np.save('./datas/semantic_test/semantic_all', semantic_all)\n", 223 | "np.save('./datas/semantic_test/semantic_train', semantic_train)\n", 224 | "np.save('./datas/semantic_test/semantic_test', semantic_test)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "pretty-statistics", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | "display_name": "Python 3", 239 | "language": "python", 240 | "name": "python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.6.13" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 5 257 | } 258 | -------------------------------------------------------------------------------- /Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "veterinary-hello", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from __future__ import print_function, division\n", 11 | "\n", 12 | "from keras.datasets import mnist\n", 13 | "from keras.layers.merge import _Merge\n", 14 | "from keras.layers import Input, Dense, Reshape, Flatten, Dropout\n", 15 | "from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Concatenate\n", 16 | "from keras.layers.advanced_activations import LeakyReLU\n", 17 | "from keras.layers.convolutional import UpSampling2D, Conv2D\n", 18 | "from keras.models import Sequential, Model\n", 19 | "from keras.optimizers import RMSprop\n", 20 | "from functools import partial\n", 21 | "\n", 22 | "import keras.backend as K\n", 23 | "\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "\n", 26 | "import sys\n", 27 | "\n", 28 | "import numpy as np" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "found-melbourne", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "img_dim = 2048\n", 39 | "semantic_dim = 500\n", 40 | "# latent_dim = 120\n", 41 | "latent_dim = 50\n", 42 | "n_critic = 5\n", 43 | "\n", 44 | "epochs=50000\n", 45 | "batch_size=32\n", 46 | "sample_interval=50" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "indian-investment", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# Generator network\n", 57 | "z = Input(shape=(latent_dim,),name='z_input')\n", 58 | "semantics = Input(shape=(semantic_dim,),name='semantics_input')\n", 59 | "\n", 60 | "merged_layer = Concatenate()([z,semantics])\n", 61 | "generator = Dense(2048, activation=\"relu\")(merged_layer)\n", 62 | "\n", 63 | "generator = Dense(2048, 
activation=\"relu\")(generator)\n", 64 | "# generator = Activation(\"tanh\")(generator)\n", 65 | "\n", 66 | "generator = Model(inputs=[z, semantics], outputs=generator, name='generator')\n", 67 | "\n", 68 | "generator.summary()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "angry-motion", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# Discriminator\n", 79 | "z = Input(shape=(latent_dim,),name='z_input')\n", 80 | "img = Input(shape=(img_dim,),name='img_input')\n", 81 | "semantics = Input(shape=(semantic_dim,),name='semantics_input')\n", 82 | "# d_in = concatenate([z, img, semantics])\n", 83 | "merged_layer = Concatenate()([z, img, semantics])\n", 84 | "\n", 85 | "discriminator = Dense(4096)(merged_layer)\n", 86 | "discriminator = LeakyReLU(alpha=0.2)(discriminator)\n", 87 | "discriminator = Dropout(0.25)(discriminator)\n", 88 | "\n", 89 | "\n", 90 | "discriminator = Dense(2048)(discriminator)\n", 91 | "discriminator = LeakyReLU(alpha=0.2)(discriminator)\n", 92 | "discriminator = Dropout(0.25)(discriminator)\n", 93 | "\n", 94 | "discriminator = Dense(1)(discriminator)\n", 95 | "\n", 96 | "discriminator = Model(inputs=[z, img, semantics], outputs=discriminator, name='discriminator')\n", 97 | "discriminator.summary()" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "white-juvenile", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "class RandomWeightedAverage(_Merge):\n", 108 | " \"\"\"Provides a (random) weighted average between real and generated image samples\"\"\"\n", 109 | " def _merge_function(self, inputs):\n", 110 | "# alpha = K.random_uniform((32, 1, 1, 1))\n", 111 | " alpha = K.random_uniform((32, 1))\n", 112 | " return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "blond-brush", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "# 惩罚函数\n", 123 | "def gradient_penalty_loss(y_true, y_pred, averaged_samples):\n", 124 | " \"\"\"\n", 125 | " Computes gradient penalty based on prediction and weighted real / fake samples\n", 126 | " \"\"\"\n", 127 | " gradients = K.gradients(y_pred, averaged_samples)[0]\n", 128 | " # compute the euclidean norm by squaring ...\n", 129 | " gradients_sqr = K.square(gradients)\n", 130 | " # ... summing over the rows ...\n", 131 | " gradients_sqr_sum = K.sum(gradients_sqr,\n", 132 | " axis=np.arange(1, len(gradients_sqr.shape)))\n", 133 | " # ... 
and sqrt\n", 134 | " gradient_l2_norm = K.sqrt(gradients_sqr_sum)\n", 135 | " # compute lambda * (1 - ||grad||)^2 still for each single sample\n", 136 | " gradient_penalty = K.square(1 - gradient_l2_norm)\n", 137 | " # return the mean as loss over all the batch samples\n", 138 | " return K.mean(gradient_penalty)\n", 139 | "\n", 140 | "\n", 141 | "#它取的是两个图像差异的均值。这种损失函数可以改善生成对抗网络的收敛性。\n", 142 | "def wasserstein_loss(y_true, y_pred):\n", 143 | " return K.mean(y_true * y_pred)\n", 144 | "\n", 145 | "# 均方差\n", 146 | "def mean_squared_error(y_true, y_pred):\n", 147 | " return K.mean(K.square(y_pred - y_true), axis=-1)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "restricted-particular", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "#-------------------------------\n", 158 | "# Construct Computational Graph\n", 159 | "# for the discriminator\n", 160 | "#-------------------------------\n", 161 | "# Freeze generator's layers while training discriminator\n", 162 | "generator.trainable = False\n", 163 | "real_img = Input(shape=(img_dim,),name='real_img')\n", 164 | "z = Input(shape=(latent_dim,),name='z_input')\n", 165 | "semantics = Input(shape=(semantic_dim,),name='semantics_input')\n", 166 | "\n", 167 | "# Generate image based of noise (fake sample)\n", 168 | "fake_img = generator([z, semantics])\n", 169 | "\n", 170 | "# Discriminator determines validity of the real and fake images\n", 171 | "fake = discriminator([z,fake_img, semantics])\n", 172 | "valid = discriminator([z,real_img, semantics])\n", 173 | "\n", 174 | "# Construct weighted average between real and fake images\n", 175 | "interpolated_img = RandomWeightedAverage()([real_img, fake_img])\n", 176 | "# Determine validity of weighted sample\n", 177 | "validity_interpolated = discriminator([z,interpolated_img, semantics])\n", 178 | "\n", 179 | "# Use Python partial to provide loss function with additional\n", 180 | "# 'averaged_samples' argument\n", 181 | "partial_gp_loss = partial(gradient_penalty_loss,\n", 182 | " averaged_samples=interpolated_img)\n", 183 | "partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names\n", 184 | "\n", 185 | "discriminator_model = Model(inputs=[real_img, z, semantics],\n", 186 | " outputs=[valid, fake, validity_interpolated])\n", 187 | "\n", 188 | "\n", 189 | "optimizer = RMSprop(lr=0.000001)\n", 190 | "\n", 191 | "\n", 192 | "discriminator_model.compile(loss=[wasserstein_loss,\n", 193 | " wasserstein_loss,\n", 194 | " partial_gp_loss],\n", 195 | " optimizer=optimizer,\n", 196 | " loss_weights=[1, 1, 10])\n", 197 | "\n", 198 | "discriminator_model.summary()" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "present-japanese", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "#-------------------------------\n", 209 | "# Construct Computational Graph\n", 210 | "# for Generator\n", 211 | "#-------------------------------\n", 212 | "\n", 213 | "# For the generator we freeze the discriminator's layers\n", 214 | "discriminator.trainable = False\n", 215 | "generator.trainable = True\n", 216 | "\n", 217 | "# Sampled noise for input to generator\n", 218 | "z = Input(shape=(latent_dim,),name='z_input')\n", 219 | "semantics = Input(shape=(semantic_dim,),name='semantics_input')\n", 220 | "# Generate images based of noise\n", 221 | "img = generator([z, semantics])\n", 222 | "# Discriminator determines validity\n", 223 | "valid = discriminator([z, img, semantics])\n", 
224 | "# Defines generator model\n", 225 | "generator_model = Model([z, semantics], valid)\n", 226 | "generator_model.compile(loss=wasserstein_loss, optimizer=optimizer)\n", 227 | "generator_model.summary()" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "dying-holiday", 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "from numpy import * \n", 238 | "import operator\n", 239 | "def classify(inX,dataSet,labels,k):\n", 240 | " # 获取维度\n", 241 | " dataSetSize=dataSet.shape[0] # 训练数据集数量\n", 242 | "\n", 243 | " diffMat=tile(inX,(dataSetSize,1))-dataSet # 测试样本的各维度的差值\n", 244 | "\n", 245 | " sqDiffMat=diffMat**2 # 平方计算\n", 246 | "\n", 247 | " sqDistance=sqDiffMat.sum(axis=1) # 输出每行的值\n", 248 | "\n", 249 | " distances=sqDistance**0.5 # 开方计算\n", 250 | "\n", 251 | " sortedDistances=distances.argsort() # 排序 按距离从小到大 输出索引\n", 252 | "\n", 253 | " classCount={}\n", 254 | " for i in range(k):\n", 255 | "# print(sortedDistances[i])\n", 256 | " voteIlabel=labels[sortedDistances[i]]\n", 257 | " classCount[voteIlabel]=classCount.get(voteIlabel,0)+1.0\n", 258 | " sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)\n", 259 | "\n", 260 | " return sortedClassCount[0][0]\n", 261 | "\n", 262 | "def accuracy_train(x,y,z):\n", 263 | " group = z\n", 264 | " labels= [0,1,2]\n", 265 | "\n", 266 | " num=0\n", 267 | " y_pred =[]\n", 268 | "\n", 269 | " for i in range(6000):\n", 270 | " res=classify(x[i],group,labels,1)\n", 271 | " y_pred.append(res)\n", 272 | " \n", 273 | " if res == y[i]:\n", 274 | " num = num+1\n", 275 | "\n", 276 | " accuracy = num/6000\n", 277 | " \n", 278 | " return accuracy\n", 279 | "\n", 280 | "def accuracy_test1(x,y,z):\n", 281 | " group = z\n", 282 | " labels= [0,1,2]\n", 283 | "\n", 284 | " num=0\n", 285 | " y_pred =[]\n", 286 | "\n", 287 | " for i in range(6000):\n", 288 | " res=classify(x[i],group,labels,1)\n", 289 | " y_pred.append(res)\n", 290 | " \n", 291 | " if res == y[i]:\n", 292 | " num = num+1\n", 293 | "\n", 294 | " accuracy = num/6000\n", 295 | " \n", 296 | " return accuracy\n", 297 | "\n", 298 | "def accuracy_test2(x,y,z):\n", 299 | " group = z\n", 300 | " labels= [0,1,2,3]\n", 301 | "\n", 302 | " num=0\n", 303 | " y_pred =[]\n", 304 | "\n", 305 | " for i in range(8000):\n", 306 | " res=classify(x[i],group,labels,1)\n", 307 | " y_pred.append(res)\n", 308 | " \n", 309 | " if res == y[i]:\n", 310 | " num = num+1\n", 311 | "\n", 312 | " accuracy = num/8000\n", 313 | " \n", 314 | " return accuracy" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "id": "economic-jimmy", 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "x_train = np.load('./datas/features/feature_train.npy')\n", 325 | "y_train = np.load('./datas/train/y_train.npy')\n", 326 | "\n", 327 | "x_test = np.load('./datas/features/feature_test.npy')\n", 328 | "y_test = np.load('./datas/test/y_test.npy')\n", 329 | "\n", 330 | "semantic_train = np.load('./datas/semantic_test/semantic_train.npy')\n", 331 | "semantic_test = np.load('./datas/semantic_test/semantic_test.npy')\n", 332 | "print(x_train.shape,y_train.shape)\n", 333 | "print(x_test.shape,y_test.shape)\n", 334 | "print(semantic_train.shape,semantic_test.shape)" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "instructional-compression", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "flag_acc_test1 = 0\n", 345 | "flag_acc_test2 = 0\n", 346 | "\n", 347 | 
"valid = -np.ones((batch_size, 1))\n", 348 | "fake = np.ones((batch_size, 1))\n", 349 | "\n", 350 | "dummy = np.zeros((batch_size, 1)) # Dummy gt for gradient penalty\n", 351 | "for epoch in range(epochs):\n", 352 | "\n", 353 | " for _ in range(n_critic):\n", 354 | "\n", 355 | " # ---------------------\n", 356 | " # Train Discriminator\n", 357 | " # ---------------------\n", 358 | "\n", 359 | " # Select a random batch of images\n", 360 | " idx = np.random.randint(0, x_train.shape[0], batch_size)\n", 361 | " imgs = x_train[idx]\n", 362 | " semantics = semantic_train[idx]\n", 363 | " # Sample generator input\n", 364 | " noise = np.random.normal(0, 1, (batch_size, latent_dim))\n", 365 | " # Train the discriminator\n", 366 | " d_loss = discriminator_model.train_on_batch([imgs, noise, semantics],\n", 367 | " [valid, fake, dummy])\n", 368 | "\n", 369 | " # ---------------------\n", 370 | " # Train Generatorhttp://localhost:8888/notebooks/lk_model/WGANGP.ipynb#\n", 371 | " # ---------------------\n", 372 | "\n", 373 | " g_loss = generator_model.train_on_batch([noise,semantics], valid)\n", 374 | " \n", 375 | " mean_squared_error_loss = mean_squared_error(imgs,generator.predict([noise,semantics]))\n", 376 | "\n", 377 | " # If at save interval => save generated image samples\n", 378 | " if epoch % sample_interval == 0:\n", 379 | " samples = 7\n", 380 | " noise = np.random.normal(0, 1, (samples, latent_dim))\n", 381 | " semantics_all = np.load('./datas/semantic_test/semantic_all.npy')\n", 382 | " fake_img = generator.predict([noise, semantics_all])\n", 383 | " z1 = fake_img[0:3]\n", 384 | " z2 = fake_img[3:6]\n", 385 | " z3 = fake_img[3:7]\n", 386 | " \n", 387 | " acc_train= accuracy_train(x_train,y_train,z1)\n", 388 | " acc_test1 = accuracy_test1(x_test,y_test,z2)\n", 389 | " acc_test2 = accuracy_test2(x_test,y_test,z3)\n", 390 | " \n", 391 | " if(flag_acc_test1 < acc_test1):\n", 392 | " flag_acc_test1 = acc_test1\n", 393 | " \n", 394 | " if(flag_acc_test2 < acc_test2):\n", 395 | " flag_acc_test2 = acc_test2\n", 396 | " \n", 397 | " \n", 398 | " print (\"[epoch: %d] [acc_train: %f] [acc_test1: %f] [acc_test2: %f]\" % (epoch,acc_train, flag_acc_test1, flag_acc_test2))\n" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "id": "vulnerable-internet", 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [] 408 | } 409 | ], 410 | "metadata": { 411 | "kernelspec": { 412 | "display_name": "tensorflow_cpu", 413 | "language": "python", 414 | "name": "tensorflow_cpu" 415 | }, 416 | "language_info": { 417 | "codemirror_mode": { 418 | "name": "ipython", 419 | "version": 3 420 | }, 421 | "file_extension": ".py", 422 | "mimetype": "text/x-python", 423 | "name": "python", 424 | "nbconvert_exporter": "python", 425 | "pygments_lexer": "ipython3", 426 | "version": "3.6.13" 427 | } 428 | }, 429 | "nbformat": 4, 430 | "nbformat_minor": 5 431 | } 432 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Label-Information-Vector-Generative-Zero-shot-Model 2 | This is the code of paper "A Label Information Vector Generative Zero-shot Model for the Diagnosis of Compound Faults", Expert Systems With Applications, June 2023 3 | 4 | Four Python programs are included. Each program implements a function block of our method desribed in our manuscript "A Label Information Vector Generative Zero-shot Model for the Diagnosis of Compound Faults". 
Please cite our paper if you download and use the code for any purpose. 5 | 6 | Juan Xu, Kang Li, Yuqi Fan, Xiaohui Yuan, A Label Information Vector Generative Zero-shot Model for the Diagnosis of Compound Faults, Expert Systems with Applications, accepted June 15, 2023
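7 | 
8 | ## Suggested run order
9 | 
10 | There is no driver script; the order below is inferred from the files each notebook reads and writes (all paths are hard-coded relative to the repository root, under `./datas/`).
11 | 
12 | 1. `Data_pre-processing.ipynb` — windows the raw signals in `./datas/origin/`, turns each window into a 64×64 time-frequency image (CWT by default, STFT optional) and assembles the single-fault training set (IF/OF/BF) and the compound-fault test set (IO/IB/OB/IOB).
13 | 2. `Lable_information_vector_definition.ipynb` — builds the 500-point label information (semantic) vectors of the single faults from short raw-signal segments, derives the compound-fault vectors from them and saves them under `./datas/semantic_test/`.
14 | 3. `Feature_extractor.ipynb` — trains a CNN on the single-fault images and exports 2048-dimensional features for both the training and test images to `./datas/features/`.
15 | 4. `Model.ipynb` — trains the WGAN-GP-style conditional generator on the training features and label information vectors, then performs zero-shot diagnosis of the compound faults by generating class prototypes from the compound-fault vectors and classifying the test features with a 1-nearest-neighbour search.
16 | 
17 | Notebooks 2 and 3 are independent of each other: notebook 2 only needs the raw signals, while notebook 3 needs the arrays produced by notebook 1.
18 | 
19 | For reference, the step that defines a compound-fault label information vector in `Lable_information_vector_definition.ipynb` amounts to an element-wise maximum over the quantization levels of the single-fault vectors. A minimal sketch of that relation (the variable names here are illustrative; `./datas/semantic_test/semantic_all.npy` is the file saved by the notebook):
20 | 
21 | ```python
22 | import numpy as np
23 | 
24 | # Rows 0-2 of semantic_all are the quantized (levels 0-4), 500-point label
25 | # information vectors of the inner-race, outer-race and ball faults.
26 | s_inner, s_outer, s_ball = np.load('./datas/semantic_test/semantic_all.npy')[:3]
27 | 
28 | # A compound-fault vector is the element-wise maximum of its single-fault
29 | # vectors, which is what the if/elif ladders in the notebook compute.
30 | s_inner_outer = np.maximum(s_inner, s_outer)
31 | s_inner_outer_ball = np.maximum(np.maximum(s_inner, s_outer), s_ball)
32 | ```
33 | 
--------------------------------------------------------------------------------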