├── 微信客户端 ├── pages │ ├── index │ │ ├── index.json │ │ ├── index.wxml │ │ ├── index.wxss │ │ └── index.js │ ├── predict │ │ ├── predict.json │ │ ├── predict.wxml │ │ ├── predict.wxss │ │ └── predict.js │ ├── select │ │ ├── select.json │ │ ├── select.wxml │ │ ├── select.wxss │ │ └── select.js │ └── upload │ │ ├── upload.json │ │ ├── upload.wxml │ │ ├── upload.wxss │ │ └── upload.js ├── .DS_Store ├── app.wxss ├── app.json ├── utils │ └── util.js ├── project.config.json └── app.js ├── 演示视屏.mp4 ├── 服务端 ├── .DS_Store ├── datas │ ├── .DS_Store │ └── test_data.txt ├── __pycache__ │ ├── CONFIG.cpython-36.pyc │ └── data.cpython-36.pyc ├── .idea │ ├── misc.xml │ ├── modules.xml │ ├── 服务端.iml │ └── workspace.xml ├── client.py ├── CONFIG.py ├── models.py ├── model.py ├── server.py ├── data.py ├── test.py ├── ensemble.py ├── train.py ├── predict.py └── utils.py ├── Imgs ├── Picture1.png ├── Picture2.png ├── Picture3.png └── Picture4.png └── README.md /微信客户端/pages/index/index.json: -------------------------------------------------------------------------------- 1 | { 2 | "usingComponents": {} 3 | } -------------------------------------------------------------------------------- /微信客户端/pages/predict/predict.json: -------------------------------------------------------------------------------- 1 | { 2 | "usingComponents": {} 3 | } -------------------------------------------------------------------------------- /微信客户端/pages/select/select.json: -------------------------------------------------------------------------------- 1 | { 2 | "usingComponents": {} 3 | } -------------------------------------------------------------------------------- /微信客户端/pages/upload/upload.json: -------------------------------------------------------------------------------- 1 | { 2 | "usingComponents": {} 3 | } -------------------------------------------------------------------------------- /演示视屏.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/演示视屏.mp4 -------------------------------------------------------------------------------- /服务端/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/服务端/.DS_Store -------------------------------------------------------------------------------- /Imgs/Picture1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/Imgs/Picture1.png -------------------------------------------------------------------------------- /Imgs/Picture2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/Imgs/Picture2.png -------------------------------------------------------------------------------- /Imgs/Picture3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/Imgs/Picture3.png -------------------------------------------------------------------------------- /Imgs/Picture4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/Imgs/Picture4.png -------------------------------------------------------------------------------- /微信客户端/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/微信客户端/.DS_Store -------------------------------------------------------------------------------- /服务端/datas/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/服务端/datas/.DS_Store -------------------------------------------------------------------------------- /微信客户端/pages/upload/upload.wxml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /微信客户端/pages/predict/predict.wxml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 预测结果 4 | {{result}} 5 | 6 | -------------------------------------------------------------------------------- /服务端/__pycache__/CONFIG.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/服务端/__pycache__/CONFIG.cpython-36.pyc -------------------------------------------------------------------------------- /服务端/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daimuuc/Pigmented-skin-disease-automatic-recognition-and-classification-system/HEAD/服务端/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /服务端/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /微信客户端/pages/select/select.wxml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /微信客户端/app.wxss: -------------------------------------------------------------------------------- 1 | /**app.wxss**/ 2 | .container { 3 | height: 100%; 4 | display: flex; 5 | flex-direction: column; 6 | align-items: center; 7 | justify-content: space-between; 8 | padding: 200rpx 0; 9 | box-sizing: border-box; 10 | } 11 | -------------------------------------------------------------------------------- /服务端/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /微信客户端/pages/select/select.wxss: -------------------------------------------------------------------------------- 1 | /* pages/select/select.wxss */ 2 | page { 3 | background:#FFDAB9; 4 | } 5 | 6 | .button { 7 | width: 360rpx; 8 | height: 90rpx; 9 | margin: 40rpx; 10 | background-color: #FFA07A; 11 | color: white; 12 | border-radius: 98rpx; 13 | background: bg_red; 14 | } -------------------------------------------------------------------------------- /微信客户端/pages/upload/upload.wxss: -------------------------------------------------------------------------------- 1 | /* pages/upload/upload.wxss */ 2 | page { 3 | background:#FFDAB9; 4 | } 5 | 6 | .button { 7 | width: 360rpx; 8 | height: 90rpx; 9 | margin: 40rpx; 10 | background-color: #FFA07A; 11 | color: white; 12 | border-radius: 98rpx; 13 | background: bg_red; 14 | } -------------------------------------------------------------------------------- /微信客户端/app.json: -------------------------------------------------------------------------------- 1 | { 2 | "pages":[ 3 | "pages/index/index", 4 | "pages/select/select", 5 | "pages/upload/upload", 6 | "pages/predict/predict" 7 | ], 8 | "window":{ 9 | "backgroundTextStyle":"light", 10 | "navigationBarBackgroundColor": "#FFDAB9", 11 | "navigationBarTitleText": "色素性皮肤病七分类系统", 12 | "navigationBarTextStyle":"white" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /微信客户端/pages/predict/predict.wxss: -------------------------------------------------------------------------------- 1 | /* pages/predict/predict.wxss */ 2 | page { 3 | background:#FFDAB9; 4 | } 5 | 6 | .text { 7 | font-size: 50rpx; 8 | color: #87CEFF; 9 | margin-top: 10rpx; 10 | display: flex; 11 | align-items: center; 12 | justify-content: center; 13 | } 14 | 15 | .title { 16 | font-size: 80rpx; 17 | color: #87CEFF; 18 | margin-top: 40%; 19 | display: flex; 20 | align-items: center; 21 | justify-content: center; 22 | } -------------------------------------------------------------------------------- /服务端/.idea/服务端.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /微信客户端/utils/util.js: -------------------------------------------------------------------------------- 1 | const formatTime = date => { 2 | const year = date.getFullYear() 3 | const month = date.getMonth() + 1 4 | const day = date.getDate() 5 | const hour = date.getHours() 6 | const minute = date.getMinutes() 7 | const second = date.getSeconds() 8 | 9 | return [year, month, day].map(formatNumber).join('/') + ' ' + [hour, minute, second].map(formatNumber).join(':') 10 | } 11 | 12 | const formatNumber = n => { 13 | n = n.toString() 14 | return n[1] ? n : '0' + n 15 | } 16 | 17 | module.exports = { 18 | formatTime: formatTime 19 | } 20 | -------------------------------------------------------------------------------- /微信客户端/pages/index/index.wxml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | {{userInfo.nickName}} 8 | 9 | 10 | 11 | {{motto}} 12 | 13 | 14 | -------------------------------------------------------------------------------- /微信客户端/pages/index/index.wxss: -------------------------------------------------------------------------------- 1 | /**index.wxss**/ 2 | page { 3 | background:#FFDAB9; 4 | } 5 | 6 | .userinfo { 7 | display: flex; 8 | flex-direction: column; 9 | align-items: center; 10 | } 11 | 12 | .userinfo-avatar { 13 | width: 128rpx; 14 | height: 128rpx; 15 | margin: 20rpx; 16 | border-radius: 50%; 17 | } 18 | 19 | .userinfo-nickname { 20 | color: #aaa; 21 | } 22 | 23 | .usermotto { 24 | margin-top: 200px; 25 | color: #FAFFF0 26 | } 27 | 28 | .button { 29 | width: 360rpx; 30 | height: 90rpx; 31 | margin: 40rpx; 32 | background-color: #FFA07A; 33 | color: white; 34 | border-radius: 98rpx; 35 | background: bg_red; 36 | } -------------------------------------------------------------------------------- /微信客户端/project.config.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "项目配置文件", 3 | "packOptions": { 4 | "ignore": [] 5 | }, 6 | "setting": { 7 | "urlCheck": true, 8 | "es6": true, 9 | "postcss": true, 10 | "minified": true, 11 | "newFeature": true, 12 | "autoAudits": false 13 | }, 14 | "compileType": "miniprogram", 15 | "libVersion": "2.6.6", 16 | "appid": "wxfa929dd31067ae5c", 17 | "projectname": "SevenClassification", 18 | "debugOptions": { 19 | "hidedInDevtools": [] 20 | }, 21 | "isGameTourist": false, 22 | "condition": { 23 | "search": { 24 | "current": -1, 25 | "list": [] 26 | }, 27 | "conversation": { 28 | "current": -1, 29 | "list": [] 30 | }, 31 | "game": { 32 | "currentL": -1, 33 | "list": [] 34 | }, 35 | "miniprogram": { 36 | "current": -1, 37 | "list": [] 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /微信客户端/pages/predict/predict.js: -------------------------------------------------------------------------------- 1 | // pages/predict/predict.js 2 | const app = getApp() 3 | Page({ 4 | 5 | /** 6 | * 页面的初始数据 7 | */ 8 | data: { 9 | 10 | }, 11 | 12 | /** 13 | * 生命周期函数--监听页面加载 14 | */ 15 | onLoad: function (options) { 16 | console.log(typeof(app.globalData.result)) 17 | let json = JSON.parse(app.globalData.result); 18 | this.setData({ result: json["result"]}) 19 | }, 20 | 21 | /** 22 | * 生命周期函数--监听页面初次渲染完成 23 | */ 24 | onReady: function () { 25 | 26 | }, 27 | 28 | /** 29 | * 生命周期函数--监听页面显示 30 | */ 31 | onShow: function () { 32 | 33 | }, 34 | 35 | /** 36 | * 生命周期函数--监听页面隐藏 37 | */ 38 | onHide: function () { 39 | 40 | }, 41 | 42 | /** 43 | * 生命周期函数--监听页面卸载 44 | */ 45 | onUnload: function () { 46 | 47 | }, 48 | 49 | /** 50 | * 页面相关事件处理函数--监听用户下拉动作 51 | */ 52 | onPullDownRefresh: function () { 53 | 54 | }, 55 | 56 | /** 57 | * 页面上拉触底事件的处理函数 58 | */ 59 | onReachBottom: function () { 60 | 61 | }, 62 | 63 | /** 64 | * 用户点击右上角分享 65 | */ 66 | onShareAppMessage: function () { 67 | 68 | } 69 | }) -------------------------------------------------------------------------------- /微信客户端/app.js: -------------------------------------------------------------------------------- 1 | //app.js 2 | App({ 3 | onLaunch: function () { 4 | // 展示本地存储能力 5 | var logs = wx.getStorageSync('logs') || [] 6 | logs.unshift(Date.now()) 7 | wx.setStorageSync('logs', logs) 8 | 9 | // 登录 10 | wx.login({ 11 | success: res => { 12 | // 发送 res.code 到后台换取 openId, sessionKey, unionId 13 | } 14 | }) 15 | // 获取用户信息 16 | wx.getSetting({ 17 | success: res => { 18 | if (res.authSetting['scope.userInfo']) { 19 | // 已经授权,可以直接调用 getUserInfo 获取头像昵称,不会弹框 20 | wx.getUserInfo({ 21 | success: res => { 22 | // 可以将 res 发送给后台解码出 unionId 23 | this.globalData.userInfo = res.userInfo 24 | 25 | // 由于 getUserInfo 是网络请求,可能会在 Page.onLoad 之后才返回 26 | // 所以此处加入 callback 以防止这种情况 27 | if (this.userInfoReadyCallback) { 28 | this.userInfoReadyCallback(res) 29 | } 30 | } 31 | }) 32 | } 33 | } 34 | }) 35 | }, 36 | globalData: { 37 | userInfo: null, 38 | tempFilePaths: null, 39 | result: {} 40 | } 41 | }) -------------------------------------------------------------------------------- /服务端/client.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | """ 3 | 模拟客户端请求. 4 | client: 5 | 模拟客户端上传图像进行色素性皮肤病七分类预测请求 6 | """ 7 | from __future__ import print_function, division 8 | import requests 9 | 10 | ################################################################################ 11 | #模拟客户端上传图像进行色素性皮肤病七分类预测请求 12 | ################################################################################ 13 | def client(path): 14 | """ 15 | 模拟客户端请求,并展示预测结果. 16 | 17 | param: 18 | path -- 上传图像地址 19 | """ 20 | # 上传文件信息 21 | files = { 22 | 'file': open(path, 'rb') 23 | } 24 | 25 | # 模拟客户端请求 26 | results = requests.post('https://www.ponma.cn:8086/process', 27 | files = files) 28 | 29 | # 显示预测结果 30 | results = results.json() 31 | for k, v in results.items(): 32 | print(k, ' ', v) 33 | 34 | ################################################################################ 35 | #函数入口 36 | ################################################################################ 37 | if __name__ == '__main__': 38 | ## 模拟客户端请求 ## 39 | # 图片地址 40 | path = 'datas/images/ISIC_0024306.jpg' 41 | client(path) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 项目名称: 2 | 色素性皮肤病七分类系统(微信版) 3 | 4 | 网站版项目地址: 5 | https://github.com/JunhaoCheng/Pigmented-skin-disease-automatic-recognition-and-classification-system/tree/website 6 | 7 | 项目功能: 8 | 基于深度学习、集成学习、迁移学习、GAN等技术的色素性皮肤病自动识别七分类系统。 9 | 本系统主要由服务端和客户端两个模块组成。服务端使用DenseNet161和SENet154 10 | 两个模型构成集成模型,从而实现了对色素性皮肤病自动识别七分类。客户端使 11 | 用微信小程序和网站(SSM、Springboot)开发。用户通过微信小程序或网站上传图像到服务端,服务端返回所属类别。 12 | 13 | 项目组织结构: 14 | 服务端: 15 | ![Image text](https://github.com/JunhaoCheng/-Pigmented-skin-disease-automatic-recognition-and-classification-system-/blob/master/Imgs/Picture4.png) 16 | 17 | 18 | 项目部署: 19 | 1、修改server.py文件并运行 20 | 2、修改client.py文件并运行(可选) 21 | 3、修改微信客户端服务器配置 22 | 23 | 项目数据集: 24 | https://challenge2018.isic-archive.com/task3/ 25 | 26 | 27 | 项目运行效果: 28 | ![Image text](https://github.com/JunhaoCheng/-Pigmented-skin-disease-automatic-recognition-and-classification-system-/blob/master/Imgs/Picture1.png)![Image text](https://github.com/JunhaoCheng/-Pigmented-skin-disease-automatic-recognition-and-classification-system-/blob/master/Imgs/Picture2.png)![Image text](https://github.com/JunhaoCheng/-Pigmented-skin-disease-automatic-recognition-and-classification-system-/blob/master/Imgs/Picture3.png) 29 | 30 | 31 | 项目演示视屏: 32 | https://github.com/JunhaoCheng/Pigmented-skin-disease-automatic-recognition-and-classification-system/blob/master/演示视屏.mp4 33 | -------------------------------------------------------------------------------- /微信客户端/pages/index/index.js: -------------------------------------------------------------------------------- 1 | //index.js 2 | //获取应用实例 3 | const app = getApp() 4 | 5 | Page({ 6 | data: { 7 | motto: 'Hello World', 8 | userInfo: {}, 9 | hasUserInfo: false, 10 | canIUse: wx.canIUse('button.open-type.getUserInfo') 11 | }, 12 | //事件处理函数 13 | bindViewTap: function() { 14 | wx.navigateTo({ 15 | url: '../select/select' 16 | }) 17 | }, 18 | onLoad: function () { 19 | if (app.globalData.userInfo) { 20 | this.setData({ 21 | userInfo: app.globalData.userInfo, 22 | hasUserInfo: true 23 | }) 24 | } else if (this.data.canIUse){ 25 | // 由于 getUserInfo 是网络请求,可能会在 Page.onLoad 之后才返回 26 | // 所以此处加入 callback 以防止这种情况 27 | app.userInfoReadyCallback = res => { 28 | this.setData({ 29 | userInfo: res.userInfo, 30 | hasUserInfo: true 31 | }) 32 | } 33 | } else { 34 | // 在没有 open-type=getUserInfo 版本的兼容处理 35 | wx.getUserInfo({ 36 | success: res => { 37 | app.globalData.userInfo = res.userInfo 38 | this.setData({ 39 | userInfo: res.userInfo, 40 | hasUserInfo: true 41 | }) 42 | } 43 | }) 44 | } 45 | }, 46 | getUserInfo: function(e) { 47 | console.log(e) 48 | app.globalData.userInfo = e.detail.userInfo 49 | this.setData({ 50 | userInfo: e.detail.userInfo, 51 | hasUserInfo: true 52 | }) 53 | } 54 | }) 55 | -------------------------------------------------------------------------------- /服务端/CONFIG.py: -------------------------------------------------------------------------------- 1 | """ 2 | 配置文件 3 | ssl配置 4 | 数据存储地址 5 | 模型配置 6 | 模型存储地址 7 | """ 8 | ################################################################################ 9 | #ssl配置,可以参考https://blog.csdn.net/robin912/article/details/80698896. 10 | ################################################################################ 11 | pem = '你的文件地址' 12 | key = '你的文件地址' 13 | 14 | ################################################################################ 15 | #数据存储地址. 16 | ################################################################################ 17 | # 图像、标签csv文件存储目录地址 18 | data_path = 'datas' 19 | 20 | # 训练集存储地址 21 | traindataset_path = 'datas/train_data.txt' 22 | 23 | # 测试集存储地址 24 | testdataset_path = 'datas/test_data.txt' 25 | 26 | # 训练集loss、验证集loss存储地址 27 | loss_log = 'logs/loss_log.txt' 28 | 29 | # 预测图片目录 30 | predict_dir = 'predict' 31 | 32 | # 服务端存储上传图像的目录地址 33 | upload_dir = './uploads' 34 | 35 | 36 | 37 | ################################################################################ 38 | #模型配置. 39 | ################################################################################ 40 | # 单张图片复制次数 41 | num = 16 42 | # 挑选多少图片进行计算平均值(int)和标准差(std) 43 | cnum = 2000 44 | 45 | ################################################################################ 46 | #模型存储地址. 47 | ################################################################################ 48 | # 最佳模型存储地址 49 | best_model = 'intermediate_models/best_model.pt' 50 | 51 | # 混淆矩阵图存储地址 52 | confusion_matrix_image = 'intermediate_models/confusion_matrix.png' 53 | 54 | # 损失图存储地址 55 | loss_image = 'intermediate_models/loss_plot.png' 56 | 57 | # 中间模型存储目录 58 | intermediate_model = 'intermediate_models' 59 | 60 | # 训练好模型地址 61 | inception_model = 'trained_models/inception_model.pt' 62 | resnet_model = 'trained_models/resnet_model.pt' 63 | densenet_model = 'trained_models/densenet_model.pt' 64 | senet_model = 'trained_models/senet_model.pt' 65 | 66 | -------------------------------------------------------------------------------- /微信客户端/pages/upload/upload.js: -------------------------------------------------------------------------------- 1 | // pages/upload/upload.js 2 | const app = getApp() 3 | Page({ 4 | 5 | /** 6 | * 页面的初始数据 7 | */ 8 | data: { 9 | 10 | }, 11 | 12 | /** 13 | * 生命周期函数--监听页面加载 14 | */ 15 | onLoad: function (options) { 16 | 17 | }, 18 | 19 | /** 20 | * 生命周期函数--监听页面初次渲染完成 21 | */ 22 | onReady: function () { 23 | 24 | }, 25 | 26 | /** 27 | * 生命周期函数--监听页面显示 28 | */ 29 | onShow: function () { 30 | 31 | }, 32 | 33 | /** 34 | * 生命周期函数--监听页面隐藏 35 | */ 36 | onHide: function () { 37 | 38 | }, 39 | 40 | /** 41 | * 生命周期函数--监听页面卸载 42 | */ 43 | onUnload: function () { 44 | 45 | }, 46 | 47 | /** 48 | * 页面相关事件处理函数--监听用户下拉动作 49 | */ 50 | onPullDownRefresh: function () { 51 | 52 | }, 53 | 54 | /** 55 | * 页面上拉触底事件的处理函数 56 | */ 57 | onReachBottom: function () { 58 | 59 | }, 60 | 61 | /** 62 | * 用户点击右上角分享 63 | */ 64 | onShareAppMessage: function () { 65 | 66 | }, 67 | 68 | /** 69 | * 上传图片按钮点击事件 70 | */ 71 | upload: function() { 72 | //获取图片地址 73 | var tempFilePaths = app.globalData.tempFilePaths 74 | //显示进度弹框 75 | wx.showLoading({ 76 | title: '处理中,请耐心等待', 77 | mask: true 78 | }) 79 | wx.uploadFile({ 80 | url: 'https://www.ponma.cn:8086/process', 81 | filePath: tempFilePaths[0], 82 | name: 'file', 83 | success(res) { 84 | //关闭进度弹框 85 | wx.hideLoading() 86 | //保存结果 87 | const result = res.data 88 | console.log(result) 89 | app.globalData.result = result 90 | //界面跳转 91 | wx.navigateTo({ 92 | url: '../predict/predict' 93 | }) 94 | }, 95 | fail() { 96 | //关闭进度弹框 97 | wx.hideLoading() 98 | //显示失败弹框 99 | wx.showToast({ 100 | title: '处理失败', 101 | icon: 'none', 102 | duration: 2000 103 | }) 104 | } 105 | }) 106 | } 107 | }) -------------------------------------------------------------------------------- /微信客户端/pages/select/select.js: -------------------------------------------------------------------------------- 1 | // pages/select/select.js 2 | const app = getApp() 3 | Page({ 4 | 5 | /** 6 | * 页面的初始数据 7 | */ 8 | data: { 9 | 10 | }, 11 | 12 | /** 13 | * 生命周期函数--监听页面加载 14 | */ 15 | onLoad: function (options) { 16 | 17 | }, 18 | 19 | /** 20 | * 生命周期函数--监听页面初次渲染完成 21 | */ 22 | onReady: function () { 23 | 24 | }, 25 | 26 | /** 27 | * 生命周期函数--监听页面显示 28 | */ 29 | onShow: function () { 30 | 31 | }, 32 | 33 | /** 34 | * 生命周期函数--监听页面隐藏 35 | */ 36 | onHide: function () { 37 | 38 | }, 39 | 40 | /** 41 | * 生命周期函数--监听页面卸载 42 | */ 43 | onUnload: function () { 44 | 45 | }, 46 | 47 | /** 48 | * 页面相关事件处理函数--监听用户下拉动作 49 | */ 50 | onPullDownRefresh: function () { 51 | 52 | }, 53 | 54 | /** 55 | * 页面上拉触底事件的处理函数 56 | */ 57 | onReachBottom: function () { 58 | 59 | }, 60 | 61 | /** 62 | * 用户点击右上角分享 63 | */ 64 | onShareAppMessage: function () { 65 | 66 | }, 67 | 68 | /** 69 | *本地图片按钮点击事件 70 | */ 71 | select_album: function() { 72 | wx.chooseImage({ 73 | count: 1, 74 | sizeType: ['original', 'compressed'], 75 | sourceType: ['album'], 76 | success(res) { 77 | // tempFilePath可以作为img标签的src属性显示图片 78 | const tempFilePaths = res.tempFilePaths 79 | console.log(tempFilePaths) 80 | app.globalData.tempFilePaths = tempFilePaths 81 | wx.navigateTo({ 82 | url: '../upload/upload' 83 | }) 84 | } 85 | }) 86 | }, 87 | 88 | /** 89 | * 拍摄图片按钮点击事件 90 | */ 91 | select_camera: function() { 92 | wx.chooseImage({ 93 | count: 1, 94 | sizeType: ['original', 'compressed'], 95 | sourceType: ['camera'], 96 | success(res) { 97 | // tempFilePath可以作为img标签的src属性显示图片 98 | const tempFilePaths = res.tempFilePaths 99 | console.log(tempFilePaths) 100 | app.globalData.tempFilePaths = tempFilePaths 101 | wx.navigateTo({ 102 | url: '../upload/upload' 103 | }) 104 | } 105 | }) 106 | } 107 | }) -------------------------------------------------------------------------------- /服务端/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 加载模型. 4 | """ 5 | from __future__ import print_function, division 6 | from torchvision.models import resnet152, densenet161, inception_v3 7 | import torch.nn as nn 8 | import pretrainedmodels 9 | 10 | ################################################################################ 11 | # 加载Resnet152模型 12 | ################################################################################ 13 | def resnet(pretrained = True): 14 | """ 15 | 加载Resnet152模型 16 | 17 | :param 18 | pretrained(bool) -- 是否预训练 19 | :return 20 | 输出为7的Resnet152模型 21 | """ 22 | # 加载模型 23 | model = resnet152(pretrained = pretrained) 24 | 25 | # 修改模型输出为7 26 | fc_features = model.fc.in_features 27 | model.fc = nn.Linear(fc_features, 7) 28 | 29 | return model 30 | 31 | ################################################################################ 32 | # 加载Densenet161模型 33 | ################################################################################ 34 | def densenet(pretrained = True): 35 | """ 36 | 加载Densenet161模型 37 | 38 | :param 39 | pretrained(bool) -- 是否预训练 40 | :return 41 | 输出为7的Densenet161模型 42 | """ 43 | # 加载模型 44 | model = densenet161(pretrained = pretrained) 45 | 46 | # 修改模型输出为7 47 | fc_features = model.classifier.in_features 48 | model.classifier = nn.Linear(fc_features, 7) 49 | return model 50 | 51 | ################################################################################ 52 | # 加载InceptionV3模型,输入为299X299 53 | ################################################################################ 54 | def inception(pretrained = True): 55 | """ 56 | 加载InceptionV3模型 57 | 58 | :param 59 | pretrained(bool) -- 是否预训练 60 | :return 61 | 输出为7的InceptionV3模型 62 | """ 63 | # 加载模型 64 | model = inception_v3(pretrained = pretrained) 65 | 66 | # 修改模型输出为7 67 | fc_features = model.fc.in_features 68 | model.fc = nn.Linear(fc_features, 7) 69 | fc_features = model.AuxLogits.fc.in_features 70 | model.AuxLogits.fc = nn.Linear(fc_features, 7) 71 | 72 | return model 73 | 74 | ################################################################################ 75 | # 加载SeNet154模型 76 | ################################################################################ 77 | def senet(pretrained = True): 78 | """ 79 | 加载SeNet154模型 80 | 81 | :param 82 | pretrained(bool) -- 是否预训练 83 | :return 84 | 输出为7的SeNet154模型 85 | """ 86 | # 加载模型 87 | if pretrained: 88 | model = pretrainedmodels.__dict__['senet154'](num_classes=1000, pretrained='imagenet') 89 | else: 90 | model = pretrainedmodels.__dict__['senet154'](num_classes=1000) 91 | 92 | # 修改模型输出为7 93 | fc_features = model.last_linear.in_features 94 | model.last_linear = nn.Linear(fc_features, 7) 95 | 96 | return model 97 | -------------------------------------------------------------------------------- /服务端/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 自定义模型MFFNet(Multiple Feature Fusion Network) 4 | """ 5 | from __future__ import print_function, division 6 | import torch.nn as nn 7 | import torch.nn.functional as f 8 | import torch 9 | from torchvision.models import vgg16, densenet121 10 | 11 | ################################################################################ 12 | # 自定义模型MFFNet(Multiple Feature Fusion Network) 13 | ################################################################################ 14 | class MFFNet(nn.Module): 15 | """ 16 | 自定义模型MFFNet(Multiple Feature Fusion Network) 17 | """ 18 | def __init__(self, num_classes = 7): 19 | super(MFFNet, self).__init__() 20 | 21 | self.features = nn.Sequential(*(list(vgg16().children())[0][ : -8])) 22 | 23 | self.bb = nn.Sequential( 24 | nn.MaxPool2d(kernel_size = 2, stride = 2), 25 | _BasicBlock(512), 26 | _BasicConv2d(512, 256, kernel_size = 3, stride = 1, padding = 1), 27 | 28 | nn.MaxPool2d(kernel_size = 2, stride = 2), 29 | _BasicBlock(256), 30 | _BasicConv2d(256, 128, kernel_size = 3, stride = 1, padding = 1), 31 | 32 | nn.AdaptiveAvgPool2d((1, 1)) 33 | ) 34 | 35 | self.classifier = nn.Sequential( \ 36 | nn.Linear(128, num_classes) 37 | ) 38 | 39 | # initialize weight of layers 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | nn.init.xavier_normal_(m.weight) 43 | if m.bias is not None: 44 | nn.init.constant_(m.bias, 0) 45 | elif isinstance(m, nn.BatchNorm2d): 46 | nn.init.constant_(m.weight, 1) 47 | nn.init.constant_(m.bias, 0) 48 | elif isinstance(m, nn.Linear): 49 | nn.init.constant_(m.weight, 0) 50 | nn.init.constant_(m.bias, 0) 51 | elif isinstance(m, nn.InstanceNorm2d): 52 | nn.init.constant_(m.weight, 1) 53 | nn.init.constant_(m.bias, 0) 54 | 55 | def forward(self, x): 56 | # N x 3 x 224 x 224 57 | x = self.features(x) 58 | # N x 512 x 28 x 28 59 | x = self.bb(x) 60 | # N x 128 x 1 x 1 61 | x = x.view(x.size(0), -1) 62 | # N x 128 63 | x = self.classifier(x) 64 | # N x num_classes 65 | x = f.softmax(x, dim = 1) 66 | 67 | return x 68 | 69 | ''' 70 | achieve basicblock which consists three layers 71 | ''' 72 | class _BasicBlock(nn.Module): 73 | 74 | def __init__(self, in_channels): 75 | super(_BasicBlock, self).__init__() 76 | 77 | self.bc_1 = _BasicConv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1) 78 | self.bc_2 = _BasicConv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1) 79 | self.bc_3 = _BasicConv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1) 80 | 81 | 82 | def forward(self, x): 83 | bc_1 = self.bc_1(x) 84 | bc_2 = self.bc_2(x) 85 | bc_2 = x + bc_2 86 | bc_3 = self.bc_3(bc_2) 87 | bc_3 = bc_3 + bc_1 88 | return bc_3 89 | 90 | ''' 91 | conv -> bn -> relu 92 | ''' 93 | class _BasicConv2d(nn.Module): 94 | 95 | def __init__(self, in_channels, out_channels, **kwargs): 96 | super(_BasicConv2d, self).__init__() 97 | 98 | self.bn = nn.BatchNorm2d(in_channels) 99 | # self.instance_norm = nn.InstanceNorm2d(out_channels) 100 | self.relu = nn.ReLU(True) 101 | self.conv = nn.Conv2d(in_channels, out_channels, bias = False, **kwargs) 102 | 103 | def forward(self, x): 104 | x = self.bn(x) 105 | # x = self.instance_norm(x) 106 | x = self.relu(x) 107 | x = self.conv(x) 108 | return x -------------------------------------------------------------------------------- /服务端/server.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | """ 3 | 基于Flask的服务端API. 4 | process: 5 | 响应对上传图像进行色素性皮肤病七分类预测请求. 6 | """ 7 | 8 | from __future__ import print_function, division 9 | from flask import Flask, request, jsonify 10 | from werkzeug.utils import secure_filename 11 | import torch 12 | from torchvision import transforms 13 | from PIL import Image 14 | import os 15 | from models import densenet, senet 16 | import CONFIG 17 | 18 | # 创建Flask类实例 19 | app = Flask(__name__) 20 | 21 | ################################################################################ 22 | # 响应对上传图像进行色素性皮肤病七分类预测请求,请求方式为post. 23 | ################################################################################ 24 | @app.route('/process', methods=['post']) 25 | def process(): 26 | """ 27 | 响应对上传图像进行色素性皮肤病七分类预测请求, 28 | 以json格式返回结果 29 | """ 30 | # 单张图片数据增强后的图片个数 31 | NUM = 16 32 | # 模型权重的地址 33 | KWARGS = {'1': CONFIG.densenet_model, 34 | '2': CONFIG.senet_model } 35 | 36 | try: 37 | # 存储上传图像的目录地址 38 | UPLOAD_FOLDER = CONFIG.upload_dir 39 | 40 | ## 读取并保存上传图像 ## 41 | file = request.files['file'] 42 | filename = secure_filename(file.filename) 43 | # 图像存储路径 44 | path = os.path.join(UPLOAD_FOLDER, filename) 45 | file.save(path) 46 | 47 | # 选择设备 48 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | 50 | ## 图像处理 ## 51 | # 224 x 224 52 | normalize = transforms.Normalize( 53 | mean=[0.485, 0.456, 0.406], 54 | std=[0.229, 0.224, 0.225] 55 | ) 56 | transform = transforms.Compose([ 57 | transforms.RandomCrop(224), 58 | transforms.ToTensor(), 59 | normalize] 60 | ) 61 | # 加载图片 62 | img = Image.open(path) 63 | # 裁剪图片 64 | img = img.resize((300, 300)) 65 | # 数据增强后的所有图片 66 | imgs = None 67 | for i in range(NUM): 68 | if imgs is None: 69 | imgs = transform(img).view(1, 3, 224, 224) 70 | else: 71 | imgs = torch.cat((imgs, transform(img).view(1, 3, 224, 224)), 0) 72 | 73 | ## 加载模型 ## 74 | models = [] 75 | for k, v in KWARGS.items(): 76 | if k == '1': 77 | model = densenet(pretrained = False) 78 | elif k == '2': 79 | model = senet(pretrained = False) 80 | model.load_state_dict(torch.load(v, map_location = 'cpu')) 81 | model = model.to(device) 82 | model.eval() 83 | models.append(model) 84 | 85 | ## 预测结果 ## 86 | imgs = imgs.to(device) 87 | # 所有模型预测结果之和 88 | sum = None 89 | # 集成学习平均策略 90 | for model in models: 91 | # 预测结果 92 | output = model(imgs) 93 | 94 | output = output.detach() 95 | # 平均策略 96 | val = None 97 | for i in range(output.size(0)): 98 | if val is None: 99 | val = output[i] 100 | else: 101 | val = val + output[i] 102 | val = val / output.size(0) 103 | 104 | if sum is None: 105 | sum = val 106 | else: 107 | sum += val 108 | val = sum / len(models) 109 | _, a = torch.max(val, 0) 110 | # 预测结果 111 | a = a.item() 112 | 113 | ## 返回结果 ## 114 | # 七类色素性皮肤病名称 115 | classes = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC'] 116 | # 预测结果 117 | predict = classes[a] 118 | results_json = {} 119 | results_json['result'] = predict 120 | results_json['status'] = 'Success' 121 | return jsonify(results_json) 122 | except Exception: 123 | results_json = {} 124 | results_json['status'] = 'Failure' 125 | return jsonify(results_json) 126 | 127 | ################################################################################ 128 | # 函数入口. 129 | ################################################################################ 130 | if __name__ == '__main__': 131 | # ssl配置文件地址,可以参考https://blog.csdn.net/robin912/article/details/80698896 132 | pem = CONFIG.pem 133 | key = CONFIG.key 134 | 135 | # 运行程序并设置外部可访问 136 | app.run(port = 8086, host='0.0.0.0', debug = True, ssl_context = (pem, key)) -------------------------------------------------------------------------------- /服务端/data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 数据预处理. 4 | """ 5 | from __future__ import print_function, division 6 | from PIL import Image 7 | import os 8 | import pandas as pd 9 | from sklearn.model_selection import train_test_split 10 | import numpy as np 11 | import CONFIG 12 | from torch.utils.data import Dataset 13 | import torch 14 | 15 | ################################################################################ 16 | # 划分数据集为训练集和测试集,互不交叉,同时保证划分后的数据集类别分布和原始数据集一致 17 | ################################################################################ 18 | def divide_dataset(path): 19 | """ 20 | 划分数据集 21 | param: 22 | path--存放图片和csv文件的目录地址(str) 23 | return: 24 | 无返回值,以文本形式存储 25 | """ 26 | print("start preprocess") 27 | # 判断目录地址是否有效 28 | if (not os.path.isdir(path)): 29 | print("%s is not a directory path!!!" % path) 30 | return 31 | 32 | # 读取csv文件 33 | label_path = os.path.join(path, "ISIC2018_Task3_Training_GroundTruth.csv") 34 | labels = pd.read_csv(label_path) 35 | names = ('MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC') 36 | # 图片地址 37 | data_x = [] 38 | # 图片类别 39 | data_y = [] 40 | for i in range(len(labels)): 41 | # 存储图片地址 42 | image_path = 'images/' + labels.at[i, 'image'] + '.jpg' 43 | image_path = os.path.join(path, image_path) 44 | if (not os.path.isfile(image_path)): 45 | print("%s is not a file path!!!" % image_path) 46 | data_x.append(image_path) 47 | 48 | # 存储图片类别 49 | tag = [] 50 | for name in names: 51 | tag.append(labels.at[i, name]) 52 | tag = np.array(tag) 53 | # 直接存储图片类别,不使用one-hot方式 54 | data_y.append(np.argmax(tag)) 55 | 56 | # 划分数据集 57 | print('data_x:', len(data_x), ' ', "data_y:", len(data_y)) 58 | x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size = 0.15, stratify = data_y) 59 | print('x_train:', len(x_train), ' ', "y_train:", len(y_train)) 60 | print('x_test:', len(x_test), ' ', "y_test:", len(y_test)) 61 | 62 | # 存储训练集、测试集 63 | with open(CONFIG.traindataset_path, 'wt') as f: 64 | f.write(str(x_train)) 65 | f.write('\n') 66 | f.write(str(y_train)) 67 | with open(CONFIG.testdataset_path, 'wt') as f: 68 | f.write(str(x_test)) 69 | f.write('\n') 70 | f.write(str(y_test)) 71 | print("stop preprocess") 72 | 73 | ################################################################################ 74 | # 显示训练集、测试集各类别分布情况 75 | ################################################################################ 76 | def show_dateset_info(): 77 | """ 78 | 显示训练集、测试集各类别分布情况 79 | 80 | :return 81 | 无返回值,直接打印结果 82 | """ 83 | # 各类别名称 84 | names = ('MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC') 85 | # 路径字典 86 | datas = {'traindataset' : CONFIG.traindataset_path, 'testdataset' : CONFIG.testdataset_path} 87 | 88 | # 打印训练集、测试集各类别分布情况 89 | for k, v in datas.items(): 90 | print(k) 91 | with open(v, 'rt') as f: 92 | # 统计每个类别的个数 93 | pairs = {} 94 | 95 | # 读取数据 96 | data = f.readlines() 97 | data = data[1].strip() 98 | data = data[1: len(data) - 1] 99 | data = data.split(',') 100 | for i in range(len(data)): 101 | data[i] = int(data[i]) 102 | data = np.array(data) 103 | 104 | # 打印各类别分布情况 105 | # 总样本个数 106 | total = data.shape[0] 107 | for val in data: 108 | if val in pairs.keys(): 109 | pairs[val] += 1 110 | else: 111 | pairs[val] = 1 112 | for i in range(7): 113 | key = i 114 | val = pairs[key] 115 | print("%-8s%d/%d%10.2f%%" % (names[key], val, total, (val / total) * 100)) 116 | print("%-8s%d" % ("total", total)) 117 | 118 | ################################################################################ 119 | # 自定义色素性皮肤病数据集 120 | ################################################################################ 121 | class SkinDiseaseDataset(Dataset): 122 | """ 123 | 自定义色素性皮肤病数据集 124 | """ 125 | def __init__(self, path, transforms, agumentation): 126 | """ 127 | 初始化函数 128 | 129 | :param 130 | path -- 训练集或测试集存储地址(str) 131 | transforms -- 数据增强操作(torchvision.transforms) 132 | agumentation -- 是否对单个图片多次复制(boolean) 133 | """ 134 | self.transforms = transforms 135 | self.agumentation = agumentation 136 | # 读取数据集 137 | with open(path, "rt") as f: 138 | data = f.readlines() 139 | 140 | imgs = data[0].strip() 141 | imgs = imgs[1: len(imgs) - 1] 142 | imgs = imgs.split(',') 143 | for i in range(len(imgs)): 144 | imgs[i] = imgs[i].strip() 145 | imgs[i] = imgs[i].strip('\'') 146 | # 图片路径 147 | self.imgs = imgs 148 | 149 | labels = data[1].strip() 150 | labels = labels[1: len(labels) - 1] 151 | labels = labels.split(',') 152 | for i in range(len(labels)): 153 | labels[i] = int(labels[i]) 154 | # 图片标签 155 | self.labels = np.array(labels) 156 | 157 | def __getitem__(self, index): 158 | """ 159 | 获取样本 160 | 161 | :param 162 | index -- 索引 163 | :return 164 | 返回(图片数据,图片标签) 165 | agumentation为True时,返回格式(NXCXHXW, ) 166 | agumentation为False时,返回格式(CXHXW, ) 167 | """ 168 | # 图片路径 169 | image_path = self.imgs[index] 170 | # 读取图片 171 | img = Image.open(image_path) 172 | # 裁剪图片 173 | img = img.resize((300, 300)) 174 | 175 | # 单个图片不复制 176 | if not self.agumentation: 177 | return self.transforms(img), self.labels[index] 178 | 179 | # 单个图片复制 180 | imgs = None 181 | for i in range(CONFIG.num): 182 | if imgs is None: 183 | imgs = self.transforms(img).view(1, 3, 224, 224) 184 | else: 185 | imgs = torch.cat((imgs, self.transforms(img).view(1, 3, 224, 224)), 0) 186 | return imgs, self.labels[index] 187 | 188 | def __len__(self): 189 | """ 190 | 获取数据集大小 191 | :return 192 | 数据集大小 193 | """ 194 | return len(self.imgs) 195 | 196 | ################################################################################ 197 | # 函数入口 198 | ################################################################################ 199 | if __name__ == '__main__': 200 | # # 图像、标签csv文件存储目录地址 201 | # path = CONFIG.data_path 202 | # # 划分数据集 203 | # divide_dataset(path) 204 | 205 | # 显示训练集、测试集各类别分布情况 206 | show_dateset_info() -------------------------------------------------------------------------------- /服务端/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 测试模型性能. 4 | """ 5 | from __future__ import print_function, division 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from data import SkinDiseaseDataset 11 | from sklearn.metrics import confusion_matrix, classification_report 12 | from utils import plot_confusion_matrix, balanced_multiclass_accuracy 13 | from models import inception, resnet, densenet, senet 14 | from model import MFFNet 15 | import argparse 16 | import numpy as np 17 | import CONFIG 18 | 19 | ################################################################################ 20 | # 测试模型 21 | ################################################################################ 22 | def test(test_path, agumentation, **kwargs): 23 | """ 24 | 测试模型性能 25 | :param 26 | test_path(str) -- 测试集地址 27 | agumentation(bool) -- 是否对单个图片多次复制 28 | :kwargs 29 | model(int) -- 模型 30 | """ 31 | # 设置超参数 32 | if agumentation: 33 | BATCH_SIZE = 1 34 | else: 35 | BATCH_SIZE = 32 36 | 37 | # 选择运行的cpu或gpu 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | 40 | # 定义损失函数 41 | # N / n,权重为各类别频率的倒数 42 | weight = torch.Tensor([9., 1.5, 19.48, 30.62, 9.11, 86.86, 71.]) 43 | weight = weight.to(device) 44 | criterion = nn.CrossEntropyLoss(weight = weight) 45 | 46 | # 数据处理 47 | normalize = transforms.Normalize( 48 | mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225] 50 | ) 51 | test_transform = transforms.Compose([ 52 | transforms.RandomCrop(224), 53 | transforms.ToTensor(), 54 | normalize] 55 | ) 56 | 57 | # 加载数据 58 | # 定义test_loader 59 | test_dataset = SkinDiseaseDataset(test_path, transforms = test_transform, agumentation = agumentation) 60 | test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE) 61 | 62 | # 加载模型 63 | if kwargs['model'] == 1: 64 | model = MFFNet() 65 | elif kwargs['model'] == 2: 66 | # 299 x 299 67 | model = inception(pretrained = False) 68 | elif kwargs['model'] == 3: 69 | model = resnet(pretrained = False) 70 | elif kwargs['model'] == 4: 71 | model = densenet(pretrained = False) 72 | elif kwargs['model'] == 5: 73 | model = senet(pretrained = False) 74 | # 加载模型权重 75 | model.load_state_dict(torch.load(CONFIG.best_model)) 76 | model = model.to(device) 77 | 78 | # 测试模式 79 | model.eval() 80 | # 各类别预测正确个数 81 | class_correct = list(0. for i in range(7)) 82 | # 各类别总个数 83 | class_total = list(0. for i in range(7)) 84 | # 损失 85 | sum_loss = 0.0 86 | # 总预测正确个数 87 | correct = 0 88 | # 总个数 89 | total = 0 90 | # 总迭代次数 91 | cnt = 0 92 | # 测试集增强模式 93 | if agumentation: 94 | # 预测标签情况 95 | x = [] 96 | # 真实标签情况 97 | y = [] 98 | for data in test_loader: 99 | cnt += 1 100 | 101 | # 加载数据 102 | image, label = data 103 | image = image.view(-1, 3, 224, 224) 104 | label = label[0] 105 | image, label = image.to(device), label.to(device) 106 | 107 | # 前向传播 108 | output = model(image) 109 | 110 | # 使用平均策略获取预测值 111 | output = output.detach() 112 | # 平均策略 113 | val = None 114 | for i in range(output.size(0)): 115 | if val is None: 116 | val = output[i] 117 | else: 118 | val = val + output[i] 119 | val = val / output.size(0) 120 | _, a = torch.max(val, 0) 121 | 122 | # 统计各个类预测正确的个数 123 | m = label.detach() 124 | class_correct[m] += 1 if a == m else 0 125 | class_total[m] += 1 126 | # 统计预测正确总个数 127 | correct += 1 if a == m else 0 128 | 129 | x.append(a.item()) 130 | y.append(m.item()) 131 | # list转化为numpy 132 | x = np.array(x) 133 | y = np.array(y) 134 | else: 135 | # 预测标签情况 136 | x = None 137 | # 真实标签情况 138 | y = None 139 | for data in test_loader: 140 | cnt += 1 141 | 142 | # 加载数据 143 | image, label = data 144 | image, label = image.to(device), label.to(device) 145 | 146 | # 前向传播 147 | output = model(image) 148 | loss = criterion(output, label) 149 | 150 | # 计算loss和acc 151 | sum_loss += loss.item() 152 | _, a = torch.max(output.detach(), 1) 153 | b = label.detach() 154 | total += label.size(0) 155 | correct += (a == b).sum() 156 | 157 | # 预测和真实标签情况 158 | if x is None: 159 | x = a 160 | y = b 161 | else: 162 | x = torch.cat((x, a)) 163 | y = torch.cat((y, b)) 164 | 165 | # 统计每个类别的正确预测情况 166 | for i in range(label.size(0)): 167 | m = b[i] 168 | class_correct[m] += 1 if a[i] == m else 0 169 | class_total[m] += 1 170 | # tensor转化为numpy 171 | x = x.cpu().detach().numpy() 172 | y = y.cpu().detach().numpy() 173 | 174 | # 打印结果 175 | cm_plot_labels = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC'] 176 | # 判断测试集是否增强 177 | if agumentation: 178 | # 打印acc 179 | print("test_acc:%.2f%%\n" % (100 * correct / cnt)) 180 | else: 181 | # 打印loss和acc 182 | print("test_loss:%.2f test_acc:%.2f%%\n" % (sum_loss / cnt, 100 * correct / total)) 183 | # 打印每个类别的acc 184 | for i in range(7): 185 | if class_total[i] > 0: 186 | print('Test Accuracy of %5s: %.2f%% (%2d/%2d)' % ( 187 | cm_plot_labels[i], 100 * class_correct[i] / class_total[i], 188 | class_correct[i], class_total[i])) 189 | else: 190 | print('Test Accuracy of %5s: N/A (no training examples)' % cm_plot_labels[i]) 191 | print('') 192 | 193 | # 计算混淆矩阵 194 | cm = confusion_matrix(y, x) 195 | print('') 196 | 197 | # 计算BMC 198 | balanced_multiclass_accuracy(cm) 199 | print('') 200 | 201 | # 可视化混淆矩阵 202 | plot_confusion_matrix(cm, cm_plot_labels, title = 'Confusion Matrix') 203 | print('') 204 | 205 | # 打印分类报告 206 | report = classification_report(y, x, target_names=cm_plot_labels) 207 | print(report) 208 | 209 | ################################################################################ 210 | # 函数入口 211 | ################################################################################ 212 | if __name__ == '__main__': 213 | # 测试集地址 214 | test_path = CONFIG.testdataset_path 215 | 216 | # 解析参数 217 | parser = argparse.ArgumentParser(description='test model') 218 | parser.add_argument('-m', '--model', type=int, default=1, 219 | help='choose models, 1)MFFNet 2)Inception 3)ResNet 4)DenseNet 5)SeNet, default 1', 220 | choices=[1, 2, 3, 4, 5]) 221 | args = parser.parse_args() 222 | args = {'model': args.model} 223 | 224 | # 测试模型 225 | test(test_path, agumentation = False, **args) -------------------------------------------------------------------------------- /服务端/ensemble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 测试集成模型性能. 4 | """ 5 | from __future__ import print_function, division 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | from data import SkinDiseaseDataset 9 | from models import inception, resnet, densenet, senet 10 | from model import MFFNet 11 | import torch 12 | import torch.nn as nn 13 | import CONFIG 14 | import numpy as np 15 | from utils import plot_confusion_matrix, balanced_multiclass_accuracy 16 | from sklearn.metrics import confusion_matrix, classification_report 17 | 18 | ################################################################################ 19 | # 测试集成模型 20 | ################################################################################ 21 | def test(test_path, agumentation, **kwargs): 22 | """ 23 | 测试集成模型性能 24 | :param 25 | test_path(str) -- 测试集地址 26 | agumentation(bool) -- 是否对单个图片多次复制 27 | :kwargs 28 | model(int) -- 模型 29 | """ 30 | # 设置超参数 31 | if agumentation: 32 | BATCH_SIZE = 1 33 | else: 34 | BATCH_SIZE = 32 35 | 36 | # 选择运行的cpu或gpu 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | 39 | # 数据处理 40 | normalize = transforms.Normalize( 41 | mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225] 43 | ) 44 | test_transform = transforms.Compose([ 45 | transforms.RandomCrop(224), 46 | transforms.ToTensor(), 47 | normalize] 48 | ) 49 | 50 | # 加载数据 51 | # 定义test_loader 52 | test_dataset = SkinDiseaseDataset(test_path, transforms = test_transform, agumentation = agumentation) 53 | test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE) 54 | 55 | # 加载模型 56 | models = [] 57 | for k, v in kwargs.items(): 58 | if k == '1': 59 | model = MFFNet() 60 | elif k == '2': 61 | # 299 x 299 62 | model = inception(pretrained = False) 63 | elif k == '3': 64 | model = resnet(pretrained = False) 65 | elif k == '4': 66 | model = densenet(pretrained = False) 67 | elif k == '5': 68 | model = senet(pretrained = False) 69 | model.load_state_dict(torch.load(v)) 70 | model = model.to(device) 71 | # 测试模式 72 | model.eval() 73 | models.append(model) 74 | 75 | # 测试模型 76 | # 各类别预测正确个数 77 | class_correct = list(0. for i in range(7)) 78 | # 各类别总个数 79 | class_total = list(0. for i in range(7)) 80 | # 总预测正确个数 81 | correct = 0 82 | # 总个数 83 | total = 0 84 | # 总迭代次数 85 | cnt = 0 86 | # 测试集增强模式 87 | if agumentation: 88 | # 预测标签情况 89 | x = [] 90 | # 真实标签情况 91 | y = [] 92 | for data in test_loader: 93 | cnt += 1 94 | 95 | # 加载数据 96 | image, label = data 97 | image = image.view(-1, 3, 224, 224) 98 | label = label[0] 99 | image, label = image.to(device), label.to(device) 100 | 101 | # 使用平均策略获取预测值,即最终的输出为各模型输出和的平均 102 | sum = None 103 | # 平均策略 104 | for model in models: 105 | output = model(image) 106 | 107 | # 使用平均策略获取模型的输出 108 | output = output.detach() 109 | # 平均策略 110 | val = None 111 | for i in range(output.size(0)): 112 | if val is None: 113 | val = output[i] 114 | else: 115 | val = val + output[i] 116 | val = val / output.size(0) 117 | 118 | if sum is None: 119 | sum = val 120 | else: 121 | sum += val 122 | val = sum / len(models) 123 | _, a = torch.max(val, 0) 124 | 125 | # 统计各个类预测正确的个数 126 | m = label.detach() 127 | class_correct[m] += 1 if a == m else 0 128 | correct += 1 if a == m else 0 129 | class_total[m] += 1 130 | 131 | x.append(a.item()) 132 | y.append(m.item()) 133 | # list转化为numpy 134 | x = np.array(x) 135 | y = np.array(y) 136 | else: 137 | # 预测标签情况 138 | x = None 139 | # 真实标签情况 140 | y = None 141 | for data in test_loader: 142 | cnt += 1 143 | 144 | # 加载数据 145 | image, label = data 146 | image, label = image.to(device), label.to(device) 147 | 148 | # 使用平均策略,获取输出,即最终的输出为各模型输出和的平均 149 | output = None 150 | # 平均策略 151 | for model in models: 152 | if output is None: 153 | output = model(image).detach() 154 | else: 155 | output += model(image).detach() 156 | output = output / len(models) 157 | 158 | # acc 159 | _, a = torch.max(output, 1) 160 | b = label.detach() 161 | total += label.size(0) 162 | correct += (a == b).sum() 163 | 164 | # 预测和真实标签情况 165 | if x is None: 166 | x = a 167 | y = b 168 | else: 169 | x = torch.cat((x, a)) 170 | y = torch.cat((y, b)) 171 | 172 | # 统计每个类别的正确预测情况 173 | for i in range(label.size(0)): 174 | m = b[i] 175 | class_correct[m] += 1 if a[i] == m else 0 176 | class_total[m] += 1 177 | # tensor转化为numpy 178 | x = x.cpu().detach().numpy() 179 | y = y.cpu().detach().numpy() 180 | 181 | # 打印结果 182 | cm_plot_labels = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC'] 183 | # 打印acc 184 | print("test_acc:%.2f%%\n" % (100 * correct / total)) 185 | # 打印每个类别的acc 186 | for i in range(7): 187 | if class_total[i] > 0: 188 | print('Test Accuracy of %5s: %.2f%% (%2d/%2d)' % ( 189 | cm_plot_labels[i], 100 * class_correct[i] / class_total[i], 190 | class_correct[i], class_total[i])) 191 | else: 192 | print('Test Accuracy of %5s: N/A (no training examples)' % cm_plot_labels[i]) 193 | print('') 194 | 195 | # 计算混淆矩阵 196 | cm = confusion_matrix(y, x) 197 | print('') 198 | 199 | # 计算BMC 200 | balanced_multiclass_accuracy(cm) 201 | print('') 202 | 203 | # 可视化混淆矩阵 204 | plot_confusion_matrix(cm, cm_plot_labels, title='Confusion Matrix') 205 | print('') 206 | 207 | # 打印分类报告 208 | report = classification_report(y, x, target_names=cm_plot_labels) 209 | print(report) 210 | 211 | ################################################################################ 212 | # 函数入口 213 | ################################################################################ 214 | if __name__ == '__main__': 215 | # 测试集地址 216 | test_path = CONFIG.testdataset_path 217 | 218 | # 集成模型配置. 1)MFFNet 2)Inception 3)ResNet 4)DenseNet 5)SeNet 219 | args = { # '1' : 'train_models/mffnet_model.pt', 220 | # '2': CONFIG.inception_model, 221 | # '3': CONFIG.resnet_model, 222 | '4': CONFIG.densenet_model, 223 | '5': CONFIG.senet_model} 224 | 225 | # 测试集成模型 226 | test(test_path, agumentation = False, **args) -------------------------------------------------------------------------------- /服务端/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 训练模型. 4 | """ 5 | from __future__ import print_function, division 6 | from torchsummary import summary 7 | import torch 8 | import torch.utils.data 9 | from torch.utils.data import DataLoader 10 | from data import SkinDiseaseDataset 11 | from utils import ImbalancedDatasetSampler, EarlyStopping, plot_loss 12 | from torchvision import transforms 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import numpy as np 16 | import argparse 17 | from models import inception, resnet, densenet, senet 18 | import CONFIG 19 | from model import MFFNet 20 | 21 | ################################################################################ 22 | # 训练模型 23 | ################################################################################ 24 | def train(train_path, val_path, **kwargs): 25 | """ 26 | 训练模型 27 | :param 28 | train_path(str) -- 训练集地址 29 | val_path(str) -- 验证集地址 30 | : kwargs 31 | model(int) -- 训练模型 32 | epoch(int) -- 训练轮数 33 | batch_size(int) -- 训练批次大小 34 | learn_rate(int) -- 学习率 35 | :return 36 | 返回训练集损失(list)、验证集损失(list) 37 | """ 38 | # 设置超参数 39 | lrs = [1e-3, 1e-4, 1e-5] 40 | # 学习率 41 | LR = lrs[kwargs['learn_rate'] - 1] 42 | # 训练轮数 43 | EPOCH = kwargs['epoch'] 44 | # 批次大小 45 | BATCH_SIZE = kwargs['batch_size'] 46 | 47 | # 数据处理 48 | normalize = transforms.Normalize( 49 | mean=[0.76209545, 0.54330575, 0.5679443], 50 | std=[0.14312604, 0.154518, 0.17225058] 51 | ) 52 | train_transform = transforms.Compose([ 53 | transforms.RandomCrop(224), 54 | transforms.RandomHorizontalFlip(), 55 | transforms.RandomVerticalFlip(), 56 | transforms.RandomRotation(degrees = 180), 57 | transforms.ColorJitter(brightness = 0.1, contrast = 0.1, saturation = 0.1), 58 | transforms.ToTensor(), 59 | normalize] 60 | ) 61 | val_transform = transforms.Compose([ 62 | transforms.RandomCrop(224), 63 | transforms.ToTensor(), 64 | normalize] 65 | ) 66 | 67 | # 加载数据 68 | #定义trainloader 69 | train_dataset = SkinDiseaseDataset(train_path, transforms = train_transform, agumentation = False) 70 | train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True) # sampler = ImbalancedDatasetSampler(train_dataset) 71 | 72 | #定义valloader 73 | val_dataset = SkinDiseaseDataset(val_path, transforms = val_transform, agumentation = False) 74 | val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE) 75 | 76 | # 选择运行的cpu或gpu 77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 78 | 79 | # 加载模型 80 | if kwargs['model'] == 1: 81 | model = MFFNet() 82 | elif kwargs['model'] == 2: 83 | # 299 x 299 84 | model = inception() 85 | elif kwargs['model'] == 3: 86 | model = resnet() 87 | elif kwargs['model'] == 4: 88 | model = densenet() 89 | elif kwargs['model'] == 5: 90 | model = senet() 91 | # # 断点训练,加载模型权重 92 | # model.load_state_dict(torch.load(CONFIG.best_model)) 93 | model = model.to(device) 94 | 95 | # 定义损失函数 96 | # N / n,权重为各类别频率的倒数 97 | weight = torch.Tensor([9., 1.5, 19.48, 30.62, 9.11, 86.86, 71.]) 98 | weight = weight.to(device) 99 | criterion = nn.CrossEntropyLoss(weight = weight) 100 | 101 | # 定义优化器 102 | # optimizer = optim.SGD(model.parameters(), lr = LR, weight_decay = 1e-5) 103 | optimizer = optim.Adam(model.parameters(), lr = LR, weight_decay = 1e-5) 104 | 105 | # 定义学习率调度策略 106 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 7, verbose = True) 107 | 108 | # 可视化模型 109 | if kwargs['model'] == 2: 110 | summary(model, (3, 299, 299)) 111 | else: 112 | summary(model, (3, 224, 224)) 113 | print(model) 114 | 115 | # 在模型训练过程中,跟踪每一轮平均训练集损失 116 | avg_train_losses = [] 117 | # 在模型训练过程中,跟踪每一轮平均验证集损失 118 | avg_valid_losses = [] 119 | 120 | # EarlyStopping机制 121 | early_stopping = EarlyStopping(patience = 12, verbose = True) 122 | 123 | # 训练模型 124 | for epoch in range(EPOCH): 125 | # 训练模式 126 | model.train() 127 | # 损失 128 | sum_loss = 0.0 129 | # 预测正确样本数 130 | correct = 0 131 | # 总样本数 132 | total = 0 133 | # 迭代次数 134 | cnt = 0 135 | for data in train_loader: 136 | cnt += 1 137 | 138 | # 加载数据 139 | image, label = data 140 | image, label = image.to(device), label.to(device) 141 | 142 | # 梯度置零 143 | optimizer.zero_grad() 144 | 145 | # 前向传播、后向传播 146 | output = model(image) 147 | loss = criterion(output, label) 148 | # inceptionV3 149 | # loss = criterion(output, label) + 0.4 * criterion(aux, label) 150 | loss.backward() 151 | optimizer.step() 152 | 153 | # 计算loss and acc 154 | sum_loss += loss.item() 155 | _, a = torch.max(output.detach(), 1) 156 | b = label.detach() 157 | total += label.size(0) 158 | correct += (a == b).sum() 159 | # 打印loss和acc 160 | print('[ %d/%d ] train_loss:%.2f train_acc:%.2f%%' % (epoch + 1, EPOCH, sum_loss / cnt, 100 * correct / total)) 161 | avg_train_losses.append(sum_loss / cnt) 162 | 163 | # 验证模式 164 | model.eval() 165 | # 损失 166 | sum_loss = 0.0 167 | # 预测正确样本数 168 | correct = 0 169 | # 总样本数 170 | total = 0 171 | # 迭代次数 172 | cnt = 0 173 | for data in val_loader: 174 | cnt += 1 175 | 176 | # 加载数据 177 | image, label = data 178 | image, label = image.to(device), label.to(device) 179 | 180 | # 前向传播 181 | output = model(image) 182 | loss = criterion(output, label) 183 | 184 | # 计算loss和acc 185 | sum_loss += loss.item() 186 | _, a = torch.max(output.detach(), 1) 187 | b = label.detach() 188 | total += label.size(0) 189 | correct += (a == b).sum() 190 | # 打印loss和acc 191 | print(" val_loss:%.2f val_acc:%.2f%%" % (sum_loss / cnt, 100 * correct / total)) 192 | avg_valid_losses.append(sum_loss / cnt) 193 | 194 | # earlyStopping机制 195 | early_stopping(sum_loss / cnt, model) 196 | # 学习率调度机制 197 | scheduler.step(sum_loss / cnt) 198 | 199 | # 保存模型 200 | torch.save(model.state_dict(), CONFIG.intermediate_model + '/checkpoint_%d.pt' % (epoch + 1)) 201 | 202 | # 判断是否停止训练 203 | if early_stopping.early_stop: 204 | print("Early stopping") 205 | break 206 | 207 | return avg_train_losses, avg_valid_losses 208 | 209 | ################################################################################ 210 | # 函数入口 211 | ################################################################################ 212 | if __name__ == '__main__': 213 | # 训练集地址 214 | train_path = CONFIG.traindataset_path 215 | # 验证集地址 216 | val_path = CONFIG.testdataset_path 217 | 218 | # 解析参数 219 | parser = argparse.ArgumentParser(description='train model') 220 | parser.add_argument('-m', '--model', type = int, default = 1, help = 'choose models, 1)MFFNet 2)Inception 3)ResNet 4)DenseNet 5)SeNet, default 1', 221 | choices = [1, 2, 3, 4, 5]) 222 | parser.add_argument('-e', '--epoch', type = int, default = 100, help = 'set train epoches, default 100') 223 | parser.add_argument('-b', '--batch_size', type = int, default = 32, help = 'set batch size, default 32') 224 | parser.add_argument('-lr', '--learn_rate', type = int, default = 2, help = 'choose learn rate, 1)1e-3, 2)1e-4, 3)1e-5, default 1', 225 | choices = [1, 2, 3]) 226 | args = parser.parse_args() 227 | # 打印参数 228 | print('model', args.model) 229 | print('epoch', args.epoch) 230 | print('batch_size', args.batch_size) 231 | print('learn_rate', args.learn_rate) 232 | 233 | args = {'model' : args.model, 'epoch' : args.epoch, \ 234 | 'batch_size' : args.batch_size, 'learn_rate' : args.learn_rate} 235 | 236 | # 训练模型 237 | avg_train_losses, avg_valid_losses = train(train_path, val_path, **args) 238 | 239 | # 存储训练集loss、验证集loss 240 | with open(CONFIG.loss_log, 'wt') as f: 241 | f.write(str(avg_train_losses)) 242 | f.write('\n') 243 | f.write(str(avg_valid_losses)) 244 | 245 | # 可视化损失图 246 | plot_loss(avg_train_losses, avg_valid_losses) 247 | 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /服务端/predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 模型预测结果. 4 | """ 5 | from __future__ import print_function, division 6 | import os 7 | from PIL import Image 8 | import torch 9 | from torchvision import transforms 10 | from models import inception, resnet, densenet, senet 11 | from model import MFFNet 12 | import numpy as np 13 | from utils import plot_image 14 | import CONFIG 15 | 16 | ################################################################################ 17 | # 获取预测图片路径,并以文本文件存储 18 | ################################################################################ 19 | def image_path(path): 20 | """ 21 | 获取预测图片路径,并以文本文件存储 22 | :param 23 | path(str) -- 存储图片目录地址 24 | """ 25 | print("start image_path") 26 | # 判断路径是否是有效目录 27 | if not os.path.isdir(path): 28 | print('It is not a directory') 29 | return 30 | 31 | # 获取图片路径 32 | paths = [] 33 | for img_path in os.listdir(path): 34 | if img_path.__contains__('.jpg'): 35 | paths.append(os.path.join(path, img_path)) 36 | 37 | # 存储图片路径 38 | with open(os.path.join(path, 'image_path.txt'), 'wt') as f: 39 | for img_path in paths: 40 | f.write(img_path + '\n') 41 | print("stop image_path") 42 | 43 | ################################################################################ 44 | # 图片处理,单张图片不复制 45 | ################################################################################ 46 | def process_data(path): 47 | """ 48 | 图片处理,单张图片不复制 49 | 50 | :param 51 | path(str) -- 图片路径文件地址 52 | :return(tensor) 53 | 图片数据,N X 3 X 224 X 224 54 | """ 55 | # 图片处理 56 | normalize = transforms.Normalize( 57 | mean=[0.485, 0.456, 0.406], 58 | std=[0.229, 0.224, 0.225] 59 | ) 60 | transform = transforms.Compose([ 61 | transforms.RandomCrop(224), 62 | transforms.ToTensor(), 63 | normalize] 64 | ) 65 | 66 | # 读取文件 67 | with open(path, 'rt') as f: 68 | lines = f.readlines() 69 | 70 | # 加载图片 71 | data = torch.zeros((0, 3, 224, 224)) 72 | for img_path in lines: 73 | img_path = img_path.strip() 74 | # 加载图片 75 | img = Image.open(img_path) 76 | # 裁剪图片 77 | img = img.resize((300, 300)) 78 | img = transform(img) 79 | img = img.view(1, 3, 224, 224) 80 | data = torch.cat((data, img), 0) 81 | 82 | return data 83 | 84 | ################################################################################ 85 | # 图片处理,单张图片复制 86 | ################################################################################ 87 | def process_data_agument(path, num = 16): 88 | """ 89 | 图片处理,单张图片不复制 90 | :param 91 | path(str) -- 图片路径地址 92 | num(int) -- 单张图片复制次数 93 | :return(tensor) 94 | 图片数据,N X 3 X 224 X 224 95 | """ 96 | # 图片处理 97 | normalize = transforms.Normalize( 98 | mean=[0.485, 0.456, 0.406], 99 | std=[0.229, 0.224, 0.225] 100 | ) 101 | transform = transforms.Compose([ 102 | transforms.RandomCrop(224), 103 | transforms.ToTensor(), 104 | normalize] 105 | ) 106 | 107 | # 加载图片 108 | data = torch.zeros((0, 3, 224, 224)) 109 | img = Image.open(path) 110 | # 裁剪图片 111 | img = img.resize((300, 300)) 112 | for i in range(num): 113 | data = torch.cat((data, transform(img).view(1, 3, 224, 224)), 0) 114 | return data 115 | 116 | 117 | ################################################################################ 118 | # 模型预测,单张图片不复制 119 | ################################################################################ 120 | def predict(data, **kwargs): 121 | """ 122 | 模型预测,单张图片不复制 123 | :param 124 | data(tensor) -- 图片数据 125 | :kwargs 126 | 127 | :return(numpy) 128 | 预测结果 129 | """ 130 | # 选择运行的cpu或gpu 131 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 132 | 133 | # 加载模型 134 | models = [] 135 | for k, v in kwargs.items(): 136 | if k == '1': 137 | model = MFFNet() 138 | elif k == '2': 139 | # 299 x 299 140 | model = inception(pretrained = False) 141 | elif k == '3': 142 | model = resnet(pretrained = False) 143 | elif k == '4': 144 | model = densenet(pretrained = False) 145 | elif k == '5': 146 | model = senet(pretrained = False) 147 | # 加载权重 148 | model.load_state_dict(torch.load(v)) 149 | model = model.to(device) 150 | # 测试模式 151 | model.eval() 152 | models.append(model) 153 | 154 | # 使用平均策略获取集成模型输出 155 | data = data.to(device) 156 | output = None 157 | # 平均策略 158 | for model in models: 159 | if output is None: 160 | output = model(data).detach() 161 | else: 162 | output += model(data).detach() 163 | output = output / len(models) 164 | _, a = torch.max(output, 1) 165 | a = a.cpu().detach().numpy() 166 | 167 | # 预测结果 168 | return a 169 | 170 | 171 | ################################################################################ 172 | # 模型预测,单张图片复制 173 | ################################################################################ 174 | def predict_agument(data, **kwargs): 175 | """ 176 | 模型预测,单张图片复制 177 | :param 178 | data(tensor) -- 图片数据 179 | :kwargs 180 | 181 | :return(numpy) 182 | 预测结果 183 | """ 184 | # 选择运行的cpu或gpu 185 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 186 | 187 | # 加载模型 188 | models = [] 189 | for k, v in kwargs.items(): 190 | if k == '1': 191 | model = MFFNet() 192 | elif k == '2': 193 | # 299 x 299 194 | model = inception(pretrained = False) 195 | elif k == '3': 196 | model = resnet(pretrained = False) 197 | elif k == '4': 198 | model = densenet(pretrained = False) 199 | elif k == '5': 200 | model = senet(pretrained = False) 201 | # 加载权重 202 | model.load_state_dict(torch.load(v)) 203 | model = model.to(device) 204 | # 测试模式 205 | model.eval() 206 | models.append(model) 207 | 208 | # 使用平均策略获取集成模型输出 209 | data = data.to(device) 210 | sum = None 211 | # 平均策略 212 | for model in models: 213 | output = model(data) 214 | output = output.detach() 215 | val = torch.zeros(7) 216 | for i in range(output.size(0)): 217 | val = val + output[i] 218 | val = val / output.size(0) 219 | 220 | if sum is None: 221 | sum = val 222 | else: 223 | sum += val 224 | val = sum / len(models) 225 | _, a = torch.max(val, 0) 226 | 227 | return a.item() 228 | 229 | ################################################################################ 230 | # 可视化预测结果,单张图片不复制 231 | ################################################################################ 232 | def plot(path, **kwargs): 233 | """ 234 | 可视化预测结果,单张图片不复制 235 | :param 236 | path(str) -- 图片路径文件地址 237 | :kwargs 238 | """ 239 | # 模型预测 240 | cm_plot_labels = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC'] 241 | data = process_data(path) 242 | outcome = predict(data, kwargs) 243 | 244 | # 可视化预测结果 245 | title = [] 246 | for val in outcome: 247 | title.append(cm_plot_labels[val]) 248 | with open(path, 'rt') as f: 249 | lines = f.readlines() 250 | ims = np.zeros((0, 224, 224, 3)) 251 | for line in lines: 252 | line = line.strip() 253 | img = Image.open(line) 254 | img = np.array(img) 255 | img = img[np.newaxis, :, :, :] 256 | ims = np.concatenate((ims, img), 0) 257 | plot_image(ims, title) 258 | 259 | 260 | ################################################################################ 261 | # 可视化预测结果,单张图片复制 262 | ################################################################################ 263 | def plot_agument(path, **kwargs): 264 | """ 265 | 可视化预测结果,单张图片复制 266 | :param 267 | path(str) -- 图片路径文件地址 268 | :kwargs 269 | """ 270 | # 模型预测 271 | cm_plot_labels = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC'] 272 | with open(path, 'rt') as f: 273 | lines = f.readlines() 274 | outcome = [] 275 | for line in lines: 276 | line = line.strip() 277 | data = process_data_agument(line) 278 | output = predict_agument(data, kwargs) 279 | outcome.append(output) 280 | 281 | # 可视化预测结果 282 | title = [] 283 | for val in outcome: 284 | title.append(cm_plot_labels[val]) 285 | ims = np.zeros((0, 224, 224, 3)) 286 | for line in lines: 287 | line = line.strip() 288 | img = Image.open(line) 289 | img = np.array(img) 290 | img = img[np.newaxis, :, :, :] 291 | ims = np.concatenate((ims, img), 0) 292 | 293 | plot_image(ims, title) 294 | 295 | 296 | ################################################################################ 297 | # 函数入口 298 | ################################################################################ 299 | if __name__ == '__main__': 300 | # 预测图片目录 301 | path = CONFIG.predict_dir 302 | 303 | # 模型配置 1)MFFNet 2)Inception 3)ResNet 4)DenseNet 5)SeNet 304 | args = { # '1' : 'train_models/mffnet_model.pt', 305 | # '2': CONFIG.inception_model, 306 | # '3': CONFIG.resnet_model, 307 | '4': CONFIG.densenet_model, 308 | '5': CONFIG.senet_model} 309 | 310 | # 获取预测图片路径 311 | # image_path(path) 312 | path = os.path.join(path, 'image_path.txt') 313 | 314 | # 可视化预测结果,单张图片不复制 315 | plot(path, args) 316 | # 可视化预测结果,单张图片复制 317 | plot_agument(path, args) 318 | -------------------------------------------------------------------------------- /服务端/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: -*- 2 | """ 3 | 常见工具函数. 4 | balanced_multiclass_accuracy 5 | --BMC(Balanced Multiclass Accuracy)评价指标. 6 | ImbalancedDatasetSampler 7 | --概率抽样. 8 | EarlyStopping 9 | --EarlyStopping机制. 10 | plot_confusion_matrix 11 | -- 绘制混淆矩阵. 12 | compute_mean_std 13 | -- 计算平均值(mean)和标准差(std). 14 | plot_loss 15 | -- 绘制损失图 16 | plot_image 17 | -- 绘制图片 18 | """ 19 | from __future__ import print_function, division 20 | import torch 21 | from data import SkinDiseaseDataset 22 | import numpy as np 23 | import CONFIG 24 | import matplotlib.pyplot as plt 25 | import itertools 26 | import random 27 | from PIL import Image 28 | import math 29 | 30 | ################################################################################ 31 | # BMC(Balanced Multiclass Accuracy)评价指标. 32 | # 即混淆矩阵各类召回率(Recall)的和的平均 33 | ################################################################################ 34 | def balanced_multiclass_accuracy(cm): 35 | """ 36 | BMC(Balanced Multiclass Accuracy)评价指标. 37 | param: 38 | cm(numpy) -- 混淆矩阵 39 | return: 40 | 无返回值,直接输出结果 41 | """ 42 | # 类别个数 43 | n = len(cm) 44 | # 各类召回率(Recall)的和 45 | recalls = 0. 46 | # 打印每个类别的精确率(Precision)和召回率(Recall) 47 | for i in range(len(cm[0])): 48 | rowsum, colsum = sum(cm[i]), sum(cm[r][i] for r in range(n)) 49 | try: 50 | print('%d ' %i, 'precision: %.2f' % (cm[i][i] / float(colsum)), 'recall: %.2f' % (cm[i][i] / float(rowsum))) 51 | recalls += (cm[i][i] / float(rowsum)) 52 | except ZeroDivisionError: 53 | print('%d ' %i, 'precision: %s' % '0', 'recall: %s' % '0') 54 | # 打印BMC值 55 | print('balanced_multiclass_accuracy: %.2f' % (recalls / n)) 56 | 57 | ################################################################################ 58 | # 概率抽样. 59 | # 即从给定的不平衡数据集索引列表中随机抽样元素 60 | ################################################################################ 61 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 62 | """ 63 | 概率抽样. 64 | """ 65 | def __init__(self, dataset, indices=None, num_samples=None): 66 | """ 67 | 初始化函数 68 | param: 69 | dataset -- 数据集(torch.utils.data.Dataset) 70 | indices -- 索引列表(list,optional) 71 | num_samples -- 取样数量(int,optional) 72 | """ 73 | # 若索引列表(indices)未提供,则考虑所有元素 74 | self.indices = list(range(len(dataset))) \ 75 | if indices is None else indices 76 | 77 | # 若取样数量(num_samples)未提供,则每次迭代取样数量为总样本数量 78 | self.num_samples = len(self.indices) \ 79 | if num_samples is None else num_samples 80 | 81 | # 数据集中各类分布 82 | label_to_count = {} 83 | for idx in self.indices: 84 | label = self._get_label(dataset, idx) 85 | if label in label_to_count: 86 | label_to_count[label] += 1 87 | else: 88 | label_to_count[label] = 1 89 | 90 | # 每个样本的权重 91 | weights = [1.0 / label_to_count[self._get_label(dataset, idx)] 92 | for idx in self.indices] 93 | self.weights = torch.DoubleTensor(weights) 94 | 95 | def _get_label(self, dataset, idx): 96 | """ 97 | 返回指定索引的标签 98 | 99 | :param 100 | dataset -- 数据集(torch.utils.data.Dataset) 101 | idx -- 索引(int) 102 | :return 103 | 返回指定索引的标签 104 | """ 105 | dataset_type = type(dataset) 106 | if dataset_type is SkinDiseaseDataset: 107 | return dataset[idx][1] 108 | else: 109 | raise NotImplementedError 110 | 111 | def __iter__(self): 112 | """ 113 | 返回概率抽样后索引列表 114 | """ 115 | return (self.indices[i] for i in torch.multinomial( 116 | self.weights, self.num_samples, replacement = True)) 117 | 118 | def __len__(self): 119 | """ 120 | 获取取样数量 121 | """ 122 | return self.num_samples 123 | 124 | ################################################################################ 125 | # EarlyStopping机制. 126 | # 如果给定patience后,验证集损失还没有改善,则停止训练. 127 | ################################################################################ 128 | class EarlyStopping: 129 | """ 130 | EarlyStopping机制. 131 | """ 132 | def __init__(self, patience = 7, verbose = False): 133 | """ 134 | 初始化函数 135 | param: 136 | patience(int) -- 在上一次验证集损失改善后等待多少epoch 137 | verbose(bool) -- 是否打印信息 138 | :param verbose: 139 | """ 140 | self.patience = patience 141 | self.verbose = verbose 142 | self.counter = 0 143 | self.best_score = None 144 | self.early_stop = False 145 | self.val_loss_min = np.Inf 146 | 147 | def __call__(self, val_loss, model): 148 | 149 | score = -val_loss 150 | 151 | if self.best_score is None: 152 | self.best_score = score 153 | self.save_checkpoint(val_loss, model) 154 | elif score <= self.best_score: 155 | self.counter += 1 156 | print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience)) 157 | if self.counter >= self.patience: 158 | self.early_stop = True 159 | else: 160 | self.best_score = score 161 | self.save_checkpoint(val_loss, model) 162 | self.counter = 0 163 | 164 | def save_checkpoint(self, val_loss, model): 165 | """ 166 | 当验证集损失下降时,存储模型 167 | param: 168 | val_loss -- 验证集损失 169 | model -- 模型 170 | """ 171 | if self.verbose: 172 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(self.val_loss_min, val_loss)) 173 | torch.save(model.state_dict(), CONFIG.best_model) 174 | self.val_loss_min = val_loss 175 | 176 | ################################################################################ 177 | # 绘制混淆矩阵. 178 | ################################################################################ 179 | def plot_confusion_matrix(cm, classes, 180 | normalize = False, 181 | title = 'Confusion matrix', 182 | cmap = plt.cm.Blues): 183 | """ 184 | 绘制混淆矩阵 185 | :param 186 | cm(numpy) -- 混淆矩阵 187 | classes(list) --类别名称 188 | normalize(bool) -- 是否归一化 189 | title(str) -- 标题 190 | cmap -- 颜色 191 | """ 192 | # 是否归一化 193 | if normalize: 194 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 195 | print("Normalized confusion matrix") 196 | else: 197 | print('Confusion matrix, without normalization') 198 | 199 | fig = plt.figure(figsize=(10, 10)) 200 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 201 | plt.title(title) 202 | plt.colorbar() 203 | tick_marks = np.arange(len(classes)) 204 | plt.xticks(tick_marks, classes, rotation = 45) 205 | plt.yticks(tick_marks, classes) 206 | 207 | # 绘制混淆矩阵 208 | fmt = '.2f' if normalize else 'd' 209 | thresh = cm.max() / 2. 210 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 211 | plt.text(j, i, format(cm[i, j], fmt), 212 | horizontalalignment="center", 213 | color="white" if cm[i, j] > thresh else "black") 214 | 215 | plt.ylabel('True label') 216 | plt.xlabel('Predicted label') 217 | plt.tight_layout() 218 | # 显示混淆矩阵 219 | plt.show() 220 | # 保存混淆矩阵图 221 | fig.savefig(CONFIG.confusion_matrix_image, bbox_inches='tight') 222 | 223 | ################################################################################ 224 | # 计算平均值(mean)和标准差(std). 225 | # 正则化,即(0,255)-->(0,1) 226 | # cv2.imread()-->BGR-->0~255-->(H,W,C) 227 | # Image.open()-->RGB-->0~255-->(H,W,C) 228 | ################################################################################ 229 | def compute_mean_std(): 230 | 231 | # 训练集地址 232 | train_txt_path = CONFIG.traindataset_path 233 | 234 | # 挑选多少图片进行计算 235 | CNum = CONFIG.cnum 236 | 237 | img_h, img_w = 300, 300 238 | imgs = np.zeros([img_h, img_w, 3, 0]) 239 | means, stdevs = [], [] 240 | 241 | with open(train_txt_path, 'rt') as f: 242 | # 读取数据 243 | data = f.readlines() 244 | lines = data[0].strip() 245 | lines = lines[1: len(lines) - 1] 246 | lines = lines.split(',') 247 | for i in range(len(lines)): 248 | lines[i] = lines[i].strip() 249 | lines[i] = lines[i].strip('\'') 250 | # shuffle , 随机挑选图片 251 | random.shuffle(lines) 252 | 253 | for i in range(CNum): 254 | img_path = lines[i] 255 | 256 | # # cv2.imread读取 257 | # img = cv2.imread(img_path) 258 | # img = cv2.resize(img, (img_w, img_h)) 259 | 260 | # PIL Image.open读取 261 | img = Image.open(img_path) 262 | img = img.resize((img_w, img_h)) 263 | img = np.array(img) 264 | 265 | img = img[:, :, :, np.newaxis] 266 | imgs = np.concatenate((imgs, img), axis = 3) 267 | print(i) 268 | 269 | imgs = imgs.astype(np.float32)/255. 270 | 271 | 272 | for i in range(3): 273 | # 拉成一行 274 | pixels = imgs[:,:,i,:].ravel() 275 | means.append(np.mean(pixels)) 276 | stdevs.append(np.std(pixels)) 277 | 278 | # # cv2 读取的图像格式为BGR,PIL/Skimage读取到的都是RGB不用转 279 | # means.reverse() # BGR --> RGB 280 | # stdevs.reverse() 281 | 282 | # 打印平均值和标准差 283 | print("normMean = {}".format(means)) 284 | print("normStd = {}".format(stdevs)) 285 | print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs)) 286 | 287 | ################################################################################ 288 | # 绘制损失图. 289 | ################################################################################ 290 | def plot_loss(train_loss, val_loss): 291 | """ 292 | 绘制损失图 293 | 294 | :param 295 | train_loss(list) -- 训练集损失 296 | val_loss(list) -- 验证集损失 297 | """ 298 | # 可视化训练集、验证集损失 299 | fig = plt.figure(figsize=(10, 8)) 300 | plt.plot(range(1, len(train_loss) + 1), train_loss, label='Training Loss') 301 | plt.plot(range(1, len(val_loss) + 1), val_loss, label='Validation Loss') 302 | 303 | # 寻找验证集损失最小的索引 304 | minposs = val_loss.index(min(val_loss)) + 1 305 | plt.axvline(minposs, linestyle='--', color='r', label='Early Stopping Checkpoint') 306 | 307 | # 寻找最大损失值 308 | max_val = max(max(train_loss), max(val_loss)) 309 | max_val = math.ceil(max_val) 310 | # 寻找最小损失值 311 | min_val = min(min(train_loss), min(val_loss)) 312 | min_val = math.floor(min_val) 313 | 314 | plt.xlabel('epochs') 315 | plt.ylabel('loss') 316 | plt.ylim(min_val, max_val) # consistent scale 317 | plt.xlim(0, len(train_loss) + 1) # consistent scale 318 | plt.grid(True) 319 | plt.legend() 320 | plt.tight_layout() 321 | # 显示图片 322 | plt.show() 323 | # 存储图片 324 | fig.savefig(CONFIG.loss_image, bbox_inches='tight') 325 | 326 | ################################################################################ 327 | # 绘制图片. 328 | ################################################################################ 329 | def plot_image(ims, figsize = (12,6), rows = 5, interp = False, titles = None): 330 | """ 331 | 绘制图片 332 | 333 | :param 334 | ims(numpy) -- 图片数据 335 | figsize -- 画布大小 336 | rows(int) -- 行数 337 | interp -- 填充图片方式 338 | titles(list) -- 图片标题 339 | """ 340 | # 转换图片 341 | if type(ims[0]) is np.ndarray: 342 | ims = np.array(ims).astype(np.uint8) 343 | if (ims.shape[-1] != 3): 344 | ims = ims.transpose((0,2,3,1)) 345 | 346 | # 设置画布大小 347 | f = plt.figure(figsize = figsize) 348 | 349 | # 获取列数 350 | cols = len(ims)//rows if len(ims) % 2 == 0 else len(ims)//rows + 1 351 | 352 | # 绘制图片 353 | for i in range(len(ims)): 354 | sp = f.add_subplot(rows, cols, i+1) 355 | sp.axis('Off') 356 | if titles is not None: 357 | sp.set_title(titles[i], fontsize=16) 358 | plt.imshow(ims[i], interpolation=None if interp else 'none') 359 | 360 | -------------------------------------------------------------------------------- /服务端/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 124 | 125 | 126 | 127 | plot_image 128 | 129 | 130 | 131 | 149 | 150 | 151 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 |