├── .gitignore ├── CAM.py ├── README.md ├── WebUI.py ├── dataset.py ├── model.py ├── resnet.py ├── results ├── csv │ ├── MultCNN_bs12_Is256_test_F1.csv │ ├── MultCNN_bs12_Is256_test_acc.csv │ ├── MultCNN_bs12_Is256_test_loss.csv │ ├── MultCNN_bs12_Is256_train_F1.csv │ ├── MultCNN_bs12_Is256_train_acc.csv │ ├── MultCNN_bs12_Is256_train_loss.csv │ ├── MultResNet_bs12_Is256_test_F1.csv │ ├── MultResNet_bs12_Is256_test_acc.csv │ ├── MultResNet_bs12_Is256_test_loss.csv │ ├── MultResNet_bs12_Is256_train_F1.csv │ ├── MultResNet_bs12_Is256_train_acc.csv │ └── MultResNet_bs12_Is256_train_loss.csv └── png │ ├── MultCNN_bs12_Is256_test_F1.png │ ├── MultCNN_bs12_Is256_test_acc.png │ ├── MultCNN_bs12_Is256_test_loss.png │ ├── MultCNN_bs12_Is256_train_F1.png │ ├── MultCNN_bs12_Is256_train_acc.png │ ├── MultCNN_bs12_Is256_train_loss.png │ ├── MultResNet_bs12_Is256_test_F1.png │ ├── MultResNet_bs12_Is256_test_acc.png │ ├── MultResNet_bs12_Is256_test_loss.png │ ├── MultResNet_bs12_Is256_train_F1.png │ ├── MultResNet_bs12_Is256_train_acc.png │ └── MultResNet_bs12_Is256_train_loss.png ├── run.py ├── templates └── index.html ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | logs/* 3 | weight/* 4 | __pycache__/* 5 | .vscode/* 6 | -------------------------------------------------------------------------------- /CAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | import cv2 7 | 8 | # 定义 GradCAM 方法的实现函数。 9 | 10 | import numpy as np 11 | 12 | 13 | class GradCAM: 14 | def __init__(self, model, target_layers): 15 | self.model = model 16 | self.target_layers = target_layers 17 | self.feature_maps, self.grads = [], {} 18 | 19 | self.forward_hook, self.backward_hook = [], [] 20 | 21 | self.model.eval() 22 | for i in 
range(len(self.target_layers)): 23 | for j in range(len(self.target_layers[i])): 24 | self.forward_hook.append( 25 | self.target_layers[i][j].register_forward_hook( 26 | self.save_feature_maps 27 | ) 28 | ) 29 | self.backward_hook.append( 30 | self.target_layers[i][j].register_backward_hook(self.save_grads) 31 | ) 32 | 33 | def save_feature_maps(self, module, input, output): 34 | self.feature_maps.append((id(module), output.detach())) 35 | 36 | def save_grads(self, module, grad_input, grad_output): 37 | self.grads[id(module)] = grad_output[0].detach() 38 | 39 | def __call__(self, input_tensor, class_idx=None): 40 | self.feature_maps = [] 41 | self.grads = {} 42 | output = self.model(input_tensor) 43 | if class_idx is None: 44 | class_idx = np.argmax(output.cpu().detach().numpy()) 45 | 46 | self.model.zero_grad() 47 | output[0, class_idx].backward() 48 | self.feature_maps[0], self.feature_maps[1] = ( 49 | self.feature_maps[0 : len(self.target_layers[0])], 50 | self.feature_maps[len(self.target_layers[0]) :], 51 | ) 52 | cams = [[], []] 53 | for i in range(len(self.target_layers)): 54 | for j in range(len(self.target_layers[i])): 55 | weights = self.grads[self.feature_maps[i][j][0]].mean( 56 | dim=(-2, -1), keepdim=True 57 | ) 58 | cams[i].append( 59 | (weights * self.feature_maps[i][j][1]).sum(dim=1, keepdim=True) 60 | ) 61 | cams[i][j] = nn.functional.relu(cams[i][j]) 62 | 63 | return cams, class_idx #cams is [[sxcams],[mxcams]] 64 | 65 | 66 | if __name__ == "__main__": 67 | # 加载预训练模型 68 | model = models.resnet50(pretrained=True) 69 | model.eval() 70 | 71 | # 加载测试图像 72 | img_path = "test.jpg" 73 | img = Image.open(img_path).resize((224, 224), Image.ANTIALIAS) 74 | 75 | # 图像预处理 76 | preprocess = transforms.Compose( 77 | [ 78 | transforms.Resize((224, 224)), 79 | transforms.ToTensor(), 80 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 81 | ] 82 | ) 83 | img_tensor = preprocess(img) 84 | img_tensor = img_tensor.unsqueeze(0) # 添加 batch 
维度 85 | 86 | # 最后,可以使用定义的 GradCAM 方法来生成 CAM 图像,并将 CAM 图像与原始图像叠加以便于可视化。 87 | 88 | # 定义 GradCAM 方法 89 | gradcam = GradCAM(model, model.layer4[2].conv3) 90 | 91 | # 生成 CAM 图像 92 | cam = gradcam(img_tensor) 93 | 94 | # 将 CAM 图像与原始图像叠加 95 | cam = nn.functional.interpolate( 96 | cam, size=img_tensor.shape[-2:], mode="bilinear", align_corners=False 97 | ) 98 | cam = cam.cpu().detach().numpy()[0, 0] 99 | cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) 100 | cam = np.uint8(255 * cam) 101 | cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET) 102 | 103 | result = cv2.addWeighted( 104 | cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR), 0.5, cam, 0.5, 0 105 | ) 106 | 107 | cv2.imshow("gray_scale", result) 108 | cv2.waitKey(0) 109 | print("OVER") 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YanyxGraduationProject 2 | 基于医学舌象和目象的病证诊断和可解释性研究 3 | 多通道卷积神经网络分类、VGGNet、ResNet 4 | GradCAM特征图可视化 5 | flask框架搭建webui 6 | -------------------------------------------------------------------------------- /WebUI.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, make_response, render_template, jsonify 2 | from flask_cors import CORS 3 | from test import * 4 | import base64 5 | import os,io,cv2 6 | from PIL import Image 7 | import numpy as np 8 | app = Flask(__name__) 9 | CORS(app, supports_credentials=True) 10 | 11 | size=(128,128) 12 | def filestorage2img(filestoragelist): # 从缓冲区列表中读取文件并转化为图片 13 | images = [] 14 | for filestorage in filestoragelist: 15 | images.append(Image.open(io.BytesIO(filestorage.read()))) 16 | return images 17 | def array2base64str(img): # 将array转化为base64str序列 18 | img=Image.fromarray(img) 19 | buf=io.BytesIO() 20 | img.save(buf,format='png') 21 | img=base64.b64encode(buf.getbuffer()).decode("ascii") 22 | return img 23 | 24 | @app.route("/", methods=["GET"]) 25 | 
def home(): 26 | return render_template("index.html") 27 | 28 | 29 | @app.route("/upload", methods=["POST"]) 30 | def upload(): 31 | sx_files = request.files.getlist("sx[]") 32 | mx_files = request.files.getlist("mx[]") 33 | sx_img = filestorage2img(sx_files) 34 | mx_img = filestorage2img(mx_files) 35 | 36 | # patient_path = r"patient/Feiyinxing/303" 37 | 38 | # images = [Image.open(os.path.join(patient_path + ".jpg"))] + [ 39 | # Image.open(os.path.join(patient_path, i)) for i in os.listdir(patient_path) 40 | # ] 41 | # sx_img,mx_img=images[0:1],images[1:] 42 | 43 | results, images, label = deal(sx_img + mx_img, size=size) 44 | # results=[[cv2.resize(np.array(i), dsize=size, interpolation=cv2.INTER_CUBIC) for j in range(10)] for i in sx_img+mx_img] 45 | # images=[cv2.resize(np.array(i), dsize=size, interpolation=cv2.INTER_CUBIC) for i in sx_img+mx_img] 46 | 47 | for i in range(len(results)): 48 | for j in range(len(results[i])): 49 | results[i][j]=array2base64str(results[i][j]) 50 | 51 | images=[array2base64str(img) for img in images] 52 | return jsonify({'results': results,'images':images,'label':label}) 53 | 54 | if __name__ == "__main__": 55 | app.run(debug=True,host="0.0.0.0",port=7777) 56 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | from sklearn.model_selection import train_test_split 5 | from PIL import Image 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | import cv2 9 | 10 | 11 | def get_data(root_dir): 12 | tongue_path = os.path.join(root_dir, "GXY_tongue") 13 | eyes_path = os.path.join(root_dir, "GXY-eyes") 14 | tongue = [os.path.join(tongue_path, i) for i in os.listdir(tongue_path)] 15 | eyes = [os.path.join(eyes_path, i) for i in os.listdir(eyes_path)] 16 | 17 | tongue = [os.path.join(i, j) for i in tongue for 
j in os.listdir(i)] 18 | eyes = [os.path.join(i, j) for i in eyes for j in os.listdir(i)] 19 | 20 | tongue = [os.path.join(i, j) for i in tongue for j in os.listdir(i)] 21 | eyes = [os.path.join(i, j) for i in eyes for j in os.listdir(i)] 22 | 23 | tongue.sort(key=lambda x: x.rsplit("/")[-1]) 24 | eyes.sort(key=lambda x: x.rsplit("/")[-1]) 25 | return list(zip(tongue, eyes)) 26 | 27 | label={"Feiyinxu": torch.tensor(0), "Yinxu": torch.tensor(1)} 28 | class MyDataset(Dataset): 29 | def __init__(self, x, train=True, transform=None): 30 | super().__init__() 31 | self.data = x 32 | self.train = train 33 | self.transform = transform 34 | self.label = label 35 | # t=[i for i in os.listdir(root_dir) if os.path.isdir(i)] 36 | 37 | def __getitem__(self, index): 38 | tonguepath, eyespath = self.data[index] 39 | 40 | label = os.path.split(os.path.split(tonguepath)[0])[1] 41 | tongue = np.array(Image.open(tonguepath)) 42 | eyes = [ 43 | np.array(Image.open(os.path.join(eyespath, i))) 44 | for i in os.listdir(eyespath) 45 | ] 46 | if self.transform is not None: 47 | tongue = self.transform(tongue) 48 | eyes = [self.transform(i) for i in eyes] 49 | # cv2.imshow("test", tongue) 50 | return [tongue]+eyes, self.label[label] 51 | 52 | def __len__(self): 53 | return len(self.data) 54 | 55 | 56 | if __name__ == "__main__": 57 | 58 | data = get_data( 59 | r"/media/codelearner/E2EE175BEE1726F7/Users/QuickLearner/Documents/python/graduationProject/data" 60 | ) 61 | a = MyDataset( 62 | x=data, 63 | transform=transforms.Compose( 64 | [ 65 | transforms.ToPILImage(), 66 | transforms.Resize((128, 128)), 67 | transforms.ToTensor(), 68 | ] 69 | ), 70 | ) 71 | cnt=1 72 | for (tongue, eyes), label in a: 73 | if len(eyes)!=10: 74 | print(cnt) 75 | cnt+=1 76 | # cv2.imshow("gray_scale", eyes[0]) 77 | # cv2.waitKey(0) 78 | # print(label) 79 | -------------------------------------------------------------------------------- /model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from resnet import Bottleneck, ResNet 4 | 5 | 6 | def get_feature_classifier(model): 7 | features = nn.Sequential( 8 | *[ 9 | model.conv1, 10 | model.bn1, 11 | model.relu, 12 | model.maxpool, 13 | model.layer1, 14 | model.layer2, 15 | model.layer3, 16 | ] 17 | ) 18 | classifier = nn.Sequential( 19 | *[ 20 | model.layer4, 21 | ] 22 | ) 23 | return features, classifier 24 | 25 | 26 | def jointImages(images): 27 | mx1 = torch.cat(tuple(images[0:5]), dim=3) 28 | mx2 = torch.cat(tuple(images[5:]), dim=3) 29 | mx = torch.cat((mx1, mx2), dim=2) 30 | return mx 31 | 32 | 33 | class MultResNet(nn.Module): 34 | def __init__(self, sx_model_index, mx_model_index): 35 | super().__init__() 36 | model1 = ResNet(Bottleneck, sx_model_index) 37 | model2 = ResNet(Bottleneck, mx_model_index) 38 | self.features1, self.classifier1 = get_feature_classifier(model1) 39 | self.features2, self.classifier2 = get_feature_classifier(model2) 40 | 41 | self.features = [ 42 | self.features1, 43 | self.features2, 44 | ] 45 | self.classifier = [ 46 | self.classifier1, 47 | self.classifier2, 48 | ] 49 | 50 | self.sxfc = self.getClassifier() 51 | self.mxfc0 = self.getClassifier() 52 | self.mxfc1 = self.getClassifier() 53 | self.mxfc2 = self.getClassifier() 54 | self.mxfc3 = self.getClassifier() 55 | self.mxfc4 = self.getClassifier() 56 | self.mxfc5 = self.getClassifier() 57 | self.mxfc6 = self.getClassifier() 58 | self.mxfc7 = self.getClassifier() 59 | self.mxfc8 = self.getClassifier() 60 | self.mxfc9 = self.getClassifier() 61 | self.mxfc = [ 62 | self.mxfc0, 63 | self.mxfc1, 64 | self.mxfc2, 65 | self.mxfc3, 66 | self.mxfc4, 67 | self.mxfc5, 68 | self.mxfc6, 69 | self.mxfc7, 70 | self.mxfc8, 71 | self.mxfc9, 72 | ] 73 | 74 | def forward(self, x): 75 | xx = [x[0], jointImages(x[1:])] 76 | 77 | out = [] 78 | for i in range(len(self.features)): 79 | 
out.append(self.features[i](xx[i])) 80 | out[i] = self.classifier[i](out[i]) 81 | 82 | feature_size = out[0].shape[-2:] 83 | out[0] = nn.functional.avg_pool2d(out[0], feature_size) 84 | out[1] = nn.functional.avg_pool2d(out[1], feature_size, feature_size) 85 | 86 | out[0] = out[0].view(out[0].shape[0], -1) 87 | out[0] = self.sxfc(out[0]) 88 | out[1] = out[1].view(out[1].shape[0], out[1].shape[1], -1) 89 | out[1] = torch.chunk(out[1], 10, 2) 90 | 91 | out[1] = [self.mxfc[i](out[1][i].view(out[0].shape[0], -1)) for i in range(len(out[1]))] 92 | out[1] = sum(out[1]) 93 | out = 5 * out[0] + out[1] 94 | return out 95 | 96 | def getClassifier(self): 97 | return nn.Sequential( 98 | nn.Linear(2048, 256), 99 | nn.Dropout(0.5), 100 | nn.ReLU(), 101 | nn.Linear(256, 2), 102 | ) 103 | 104 | 105 | class MultCNN(nn.Module): 106 | def __init__(self, *index): 107 | super(MultCNN, self).__init__() 108 | self.sxcnn = self.VGGPipeline() 109 | self.mxcnn = self.VGGPipeline() 110 | self.sxfc = self.getClassifier() 111 | self.mxfc0 = self.getClassifier() 112 | self.mxfc1 = self.getClassifier() 113 | self.mxfc2 = self.getClassifier() 114 | self.mxfc3 = self.getClassifier() 115 | self.mxfc4 = self.getClassifier() 116 | self.mxfc5 = self.getClassifier() 117 | self.mxfc6 = self.getClassifier() 118 | self.mxfc7 = self.getClassifier() 119 | self.mxfc8 = self.getClassifier() 120 | self.mxfc9 = self.getClassifier() 121 | self.mxfc = [ 122 | self.mxfc0, 123 | self.mxfc1, 124 | self.mxfc2, 125 | self.mxfc3, 126 | self.mxfc4, 127 | self.mxfc5, 128 | self.mxfc6, 129 | self.mxfc7, 130 | self.mxfc8, 131 | self.mxfc9, 132 | ] 133 | 134 | def forward(self, x): 135 | xx = [x[0], jointImages(x[1:])] 136 | 137 | out = [] 138 | out.append(self.sxcnn(xx[0])) 139 | out.append(self.mxcnn(xx[1])) 140 | 141 | feature_size = out[0].shape[-2:] 142 | out[0] = nn.functional.avg_pool2d(out[0], feature_size) 143 | out[1] = nn.functional.avg_pool2d(out[1], feature_size, feature_size) 144 | 145 | out[0] = 
out[0].view(out[0].shape[0], -1) 146 | out[0] = self.sxfc(out[0]) 147 | out[1] = out[1].view(out[1].shape[0], out[1].shape[1], -1) 148 | out[1] = torch.chunk(out[1], 10, 2) 149 | 150 | out[1] = [self.mxfc[i](out[1][i].view(out[0].shape[0], -1)) for i in range(len(out[1]))] 151 | out[1] = sum(out[1]) 152 | out = out[0] + out[1] 153 | return out 154 | 155 | def getClassifier(self): 156 | return nn.Sequential( 157 | nn.Linear(512, 64), 158 | nn.Dropout(0.5), 159 | nn.ReLU(), 160 | nn.Linear(64, 2), 161 | ) 162 | 163 | def VGGPipeline(self): 164 | return nn.Sequential( 165 | nn.Conv2d(3, 64, 3, 1, 1), 166 | nn.BatchNorm2d(64), 167 | nn.ReLU(inplace=True), 168 | nn.Conv2d(64, 64, 3, 1, 1), 169 | nn.BatchNorm2d(64), 170 | nn.ReLU(inplace=True), 171 | nn.MaxPool2d(2, 2), 172 | nn.Conv2d(64, 128, 3, 1, 1), 173 | nn.BatchNorm2d(128), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(128, 128, 3, 1, 1), 176 | nn.BatchNorm2d(128), 177 | nn.ReLU(inplace=True), 178 | nn.MaxPool2d(2, 2), 179 | nn.Conv2d(128, 256, 3, 1, 1), 180 | nn.BatchNorm2d(256), 181 | nn.ReLU(inplace=True), 182 | nn.Conv2d(256, 256, 3, 1, 1), 183 | nn.BatchNorm2d(256), 184 | nn.ReLU(inplace=True), 185 | nn.MaxPool2d(2, 2), 186 | nn.Conv2d(256, 512, 3, 1, 1), 187 | nn.BatchNorm2d(512), 188 | nn.ReLU(inplace=True), 189 | nn.Conv2d(512, 512, 3, 1, 1), 190 | nn.BatchNorm2d(512), 191 | nn.ReLU(inplace=True), 192 | nn.MaxPool2d(2, 2), 193 | nn.Conv2d(512, 512, 3, 1, 1), 194 | nn.BatchNorm2d(512), 195 | nn.ReLU(inplace=True), 196 | nn.Conv2d(512, 512, 3, 1, 1), 197 | nn.BatchNorm2d(512), 198 | nn.ReLU(inplace=True), 199 | nn.MaxPool2d(2, 2), 200 | ) 201 | 202 | 203 | model = {"MultResNet": MultResNet, "MultCNN": MultCNN} 204 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | def conv3x3(in_channels, out_channels, stride=1): 5 | return nn.Conv2d( 6 
| in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False 7 | ) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 14 | super(BasicBlock, self).__init__() 15 | # conv1 16 | self.conv1 = conv3x3(in_channels, out_channels, stride) 17 | self.bn1 = nn.BatchNorm2d(out_channels) 18 | self.relu = nn.ReLU(inplace=True) 19 | # conv2 20 | self.conv2 = conv3x3(out_channels, out_channels) 21 | self.bn2 = nn.BatchNorm2d(out_channels) 22 | # downsample 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | if self.downsample is not None: 35 | residual = self.downsample(residual) 36 | out += residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 45 | super(Bottleneck, self).__init__() 46 | # conv1 1x1 zip channels 47 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(out_channels) 49 | # conv2 3x3 hidden feature extract 50 | self.conv2 = conv3x3(out_channels, out_channels, stride=stride) 51 | self.bn2 = nn.BatchNorm2d(out_channels) 52 | # conv3 1x1 unzip channels 53 | self.conv3 = nn.Conv2d( 54 | out_channels, out_channels * 4, kernel_size=1, bias=False 55 | ) 56 | self.bn3 = nn.BatchNorm2d(out_channels * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv3(out) 72 | out = self.bn3(out) 73 | 74 | 
if self.downsample is not None: 75 | residual = self.downsample(residual) 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class ResNet(nn.Module): 83 | def __init__(self, block, layers, num_classes=2): 84 | super(ResNet, self).__init__() 85 | self.in_channels = 64 86 | # conv1 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 88 | self.bn1 = nn.BatchNorm2d(64) 89 | self.relu = nn.ReLU(inplace=True) 90 | # conv2 91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 92 | self.layer1 = self._make_layer(block, 64, layers[0]) 93 | # conv3 94 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 95 | # conv4 96 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 97 | # conv5 98 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 99 | 100 | self.avgpool = nn.AvgPool2d(4) 101 | self.fc = nn.Linear(512 * block.expansion, num_classes) 102 | 103 | # 初始化权重 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | 112 | def _make_layer(self, block, channels, blocks, stride=1): 113 | downsample = None 114 | if stride != 1 or self.in_channels != channels * block.expansion: 115 | downsample = nn.Sequential( 116 | nn.Conv2d( 117 | self.in_channels, 118 | channels * block.expansion, 119 | kernel_size=1, 120 | stride=stride, 121 | bias=False, 122 | ), 123 | nn.BatchNorm2d(channels * block.expansion), 124 | ) 125 | layers = [] 126 | layers.append(block(self.in_channels, channels, stride, downsample)) 127 | self.in_channels = channels * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.in_channels, channels)) 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.relu(x) 136 | x = self.maxpool(x) 137 | 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | x = self.layer4(x) 142 | x = nn.functional.avg_pool2d(x,x.shape[-2:]) 143 | # x = self.avgpool(x) 144 | x = x.view(x.size(0), -1) 145 | x = self.fc(x) 146 | 147 | return x -------------------------------------------------------------------------------- /results/csv/MultCNN_bs12_Is256_test_F1.csv: -------------------------------------------------------------------------------- 1 | 95.65,88.37,23.08,93.33,80.0,85.71,51.61,76.92,76.92,76.92,76.92,76.92,88.37,88.37,82.93,93.33,82.93,95.65,68.57,95.65,85.71,90.91,88.37,82.93,85.71,85.71,85.71,85.71,80.0,90.91,88.37,90.91,93.33,95.65,90.91,95.65,93.33,73.68,88.37,85.71,85.71,85.71,76.92,76.92,73.68,90.91,85.71,93.33,80.0,88.37,73.68,73.68,85.71,82.93,58.82,85.71,80.0,80.0,80.0,70.27,70.27,82.93,88.37,73.68,88.37,90.91,88.37,82.93,62.86,54.55,66.67,80.0,80.0,76.92,80.0,70.27,76.92,85.71,88.37,82.93,85.71,80.0,88.37,56.25,76.92,76.92,85.71,76.92,76.92,76.92,76.92,82.93,73.68,80.0,54.55,56.25,90.91,82.93,88.37,66.67 2 | 
-------------------------------------------------------------------------------- /results/csv/MultCNN_bs12_Is256_test_acc.csv: -------------------------------------------------------------------------------- 1 | 91.67,79.17,16.67,87.5,66.67,75.0,37.5,62.5,62.5,62.5,62.5,62.5,79.17,79.17,70.83,87.5,70.83,91.67,54.17,91.67,75.0,83.33,79.17,70.83,75.0,75.0,75.0,75.0,66.67,83.33,79.17,83.33,87.5,91.67,83.33,91.67,87.5,58.33,79.17,75.0,75.0,75.0,62.5,62.5,58.33,83.33,75.0,87.5,66.67,79.17,58.33,58.33,75.0,70.83,41.67,75.0,66.67,66.67,66.67,54.17,54.17,70.83,79.17,58.33,79.17,83.33,79.17,70.83,45.83,37.5,50.0,66.67,66.67,62.5,66.67,54.17,62.5,75.0,79.17,70.83,75.0,66.67,79.17,41.67,62.5,62.5,75.0,62.5,62.5,62.5,62.5,70.83,58.33,66.67,37.5,41.67,83.33,70.83,79.17,50.0 2 | -------------------------------------------------------------------------------- /results/csv/MultCNN_bs12_Is256_test_loss.csv: -------------------------------------------------------------------------------- 1 | 0.05,0.05,0.11,0.05,0.07,0.05,0.07,0.06,0.07,0.07,0.06,0.09,0.05,0.06,0.08,0.05,0.05,0.09,0.06,0.03,0.06,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.06,0.05,0.06,0.06,0.05,0.03,0.04,0.06,0.04,0.08,0.06,0.06,0.06,0.06,0.07,0.08,0.07,0.06,0.07,0.06,0.07,0.06,0.08,0.09,0.06,0.07,0.13,0.06,0.08,0.08,0.08,0.12,0.09,0.07,0.09,0.07,0.08,0.06,0.09,0.07,0.11,0.11,0.11,0.1,0.1,0.1,0.1,0.11,0.09,0.07,0.06,0.07,0.08,0.09,0.08,0.13,0.11,0.12,0.08,0.1,0.1,0.1,0.1,0.09,0.09,0.08,0.21,0.14,0.11,0.06,0.09,0.15 2 | -------------------------------------------------------------------------------- /results/csv/MultCNN_bs12_Is256_train_F1.csv: -------------------------------------------------------------------------------- 1 | 
64.44,72.85,68.46,71.85,76.01,77.7,78.7,81.59,79.3,76.47,77.66,77.04,75.1,76.39,79.1,78.07,76.81,75.62,76.34,77.62,80.0,78.0,83.27,81.68,79.1,82.66,82.86,84.56,79.55,84.13,85.61,80.31,79.72,81.54,80.99,84.67,84.25,86.57,85.39,88.81,87.97,86.36,86.86,86.82,90.44,86.14,85.07,86.13,88.48,86.76,87.68,87.45,90.0,89.38,91.95,93.89,93.63,92.66,93.68,92.83,90.3,90.51,87.64,87.36,88.39,87.97,89.3,91.97,91.89,93.89,96.27,97.34,97.01,98.47,94.74,96.95,93.94,93.63,94.46,90.42,92.19,93.89,92.19,91.32,94.12,95.82,98.86,99.25,98.14,96.65,98.86,98.48,97.71,97.73,95.52,93.23,91.67,90.84,92.48,95.56 2 | -------------------------------------------------------------------------------- /results/csv/MultCNN_bs12_Is256_train_acc.csv: -------------------------------------------------------------------------------- 1 | 54.29,60.95,60.95,63.81,69.05,70.48,71.9,75.71,71.9,69.52,70.95,70.48,69.05,67.62,73.33,71.9,69.52,67.14,70.48,69.52,74.76,68.57,78.57,77.14,73.33,77.62,77.14,80.0,74.29,79.52,80.48,75.71,72.86,77.14,74.29,80.0,79.52,82.86,81.43,85.71,84.76,82.86,82.86,83.81,87.62,82.38,80.95,81.9,85.24,82.86,83.81,84.29,87.62,86.19,90.0,92.38,91.9,90.95,91.9,90.95,87.62,87.62,84.29,84.29,85.24,84.76,86.19,89.52,90.0,92.38,95.24,96.67,96.19,98.1,93.33,96.19,92.38,91.9,92.86,88.1,90.0,92.38,90.0,89.05,92.38,94.76,98.57,99.05,97.62,95.71,98.57,98.1,97.14,97.14,94.29,91.43,89.52,88.57,90.48,94.29 2 | -------------------------------------------------------------------------------- /results/csv/MultCNN_bs12_Is256_train_loss.csv: -------------------------------------------------------------------------------- 1 | 
0.08,0.07,0.06,0.05,0.05,0.05,0.05,0.05,0.05,0.04,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.06,0.05,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.04,0.05,0.04,0.05,0.04,0.04,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.04,0.03,0.03,0.03,0.04,0.03,0.03,0.03,0.03,0.02,0.02,0.02,0.02,0.02,0.02,0.03,0.03,0.03,0.03,0.03,0.03,0.03,0.02,0.02,0.02,0.01,0.01,0.01,0.01,0.01,0.01,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.02,0.03,0.02,0.02,0.01 2 | -------------------------------------------------------------------------------- /results/csv/MultResNet_bs12_Is256_test_F1.csv: -------------------------------------------------------------------------------- 1 | 24.0,0.0,0.0,23.08,95.65,88.37,68.57,85.71,85.71,82.93,85.71,82.93,93.33,54.55,73.68,88.37,73.68,88.37,29.63,85.0,87.8,64.71,78.95,87.8,93.02,93.02,93.02,85.0,87.8,85.0,64.71,95.65,85.71,95.65,93.33,80.0,41.38,85.71,76.92,82.93,76.92,76.92,76.92,80.0,80.0,80.0,88.37,66.67,73.68,62.86,70.27,82.93,51.61,95.45,82.93,88.37,85.71,88.37,82.93,80.0,93.33,75.68,41.38,85.71,54.55,45.16,80.0,80.0,90.91,88.37,85.71,85.71,85.71,85.71,85.71,80.0,82.93,82.93,80.0,88.37,62.86,82.93,90.91,88.37,88.37,85.71,88.37,93.33,93.33,93.33,95.65,95.65,90.91,90.91,95.65,85.0,85.71,56.25,66.67,80.0 2 | -------------------------------------------------------------------------------- /results/csv/MultResNet_bs12_Is256_test_acc.csv: -------------------------------------------------------------------------------- 1 | 
20.83,8.33,8.33,16.67,91.67,79.17,54.17,75.0,75.0,70.83,75.0,70.83,87.5,37.5,58.33,79.17,58.33,79.17,20.83,75.0,79.17,50.0,66.67,79.17,87.5,87.5,87.5,75.0,79.17,75.0,50.0,91.67,75.0,91.67,87.5,66.67,29.17,75.0,62.5,70.83,62.5,62.5,62.5,66.67,66.67,66.67,79.17,50.0,58.33,45.83,54.17,70.83,37.5,91.67,70.83,79.17,75.0,79.17,70.83,66.67,87.5,62.5,29.17,75.0,37.5,29.17,66.67,66.67,83.33,79.17,75.0,75.0,75.0,75.0,75.0,66.67,70.83,70.83,66.67,79.17,45.83,70.83,83.33,79.17,79.17,75.0,79.17,87.5,87.5,87.5,91.67,91.67,83.33,83.33,91.67,75.0,75.0,41.67,50.0,66.67 2 | -------------------------------------------------------------------------------- /results/csv/MultResNet_bs12_Is256_test_loss.csv: -------------------------------------------------------------------------------- 1 | 0.08,0.36,0.13,0.22,0.04,0.05,0.1,0.06,0.07,0.07,0.07,0.06,0.07,0.16,0.12,0.05,0.07,0.05,0.13,0.11,0.04,0.08,0.07,0.05,0.05,0.05,0.05,0.05,0.04,0.08,0.1,0.04,0.05,0.05,0.04,0.05,0.15,0.07,0.08,0.06,0.07,0.08,0.07,0.07,0.07,0.07,0.09,0.09,0.12,0.16,0.09,0.06,0.16,0.04,0.07,0.06,0.06,0.05,0.06,0.06,0.05,0.1,0.19,0.15,0.16,0.17,0.05,0.07,0.06,0.06,0.06,0.06,0.07,0.07,0.07,0.07,0.07,0.08,0.08,0.07,0.12,0.2,0.09,0.08,0.06,0.07,0.06,0.06,0.06,0.06,0.06,0.06,0.06,0.06,0.06,0.05,0.14,0.15,0.09,0.11 2 | -------------------------------------------------------------------------------- /results/csv/MultResNet_bs12_Is256_train_F1.csv: -------------------------------------------------------------------------------- 1 | 
62.88,68.44,69.77,72.03,77.86,83.57,82.49,86.23,88.55,90.91,91.32,91.39,86.49,84.67,81.34,80.62,76.69,81.06,80.99,86.79,89.53,92.8,96.97,98.51,97.76,98.88,97.74,99.24,97.38,88.28,76.34,83.52,84.29,88.15,87.45,90.07,89.15,94.03,95.42,98.86,98.5,99.62,99.25,99.24,99.25,98.11,91.67,85.39,84.64,89.47,93.58,95.82,95.85,96.6,98.86,100.0,99.62,99.25,100.0,100.0,98.48,95.09,90.04,90.49,92.13,91.05,94.81,97.69,98.07,98.87,100.0,100.0,100.0,99.62,100.0,99.62,100.0,99.62,98.1,91.54,89.31,88.72,88.81,91.67,96.18,97.32,98.46,98.87,98.48,98.86,100.0,99.62,100.0,100.0,96.6,95.02,91.04,92.61,95.85,97.76 2 | -------------------------------------------------------------------------------- /results/csv/MultResNet_bs12_Is256_train_acc.csv: -------------------------------------------------------------------------------- 1 | 53.33,60.48,62.86,65.24,72.38,78.1,78.57,81.9,85.71,88.57,89.05,89.05,83.33,80.0,76.19,76.19,70.48,76.19,74.29,83.33,86.19,91.43,96.19,98.1,97.14,98.57,97.14,99.05,96.67,85.71,70.48,78.57,80.48,84.76,84.29,86.67,86.67,92.38,94.29,98.57,98.1,99.52,99.05,99.05,99.05,97.62,89.52,81.43,80.48,86.67,91.9,94.76,94.76,95.71,98.57,100.0,99.52,99.05,100.0,100.0,98.1,93.81,87.14,88.1,90.0,89.05,93.33,97.14,97.62,98.57,100.0,100.0,100.0,99.52,100.0,99.52,100.0,99.52,97.62,89.52,86.67,85.71,85.71,89.52,95.24,96.67,98.1,98.57,98.1,98.57,100.0,99.52,100.0,100.0,95.71,93.81,88.57,90.95,94.76,97.14 2 | -------------------------------------------------------------------------------- /results/csv/MultResNet_bs12_Is256_train_loss.csv: -------------------------------------------------------------------------------- 1 | 
0.32,0.14,0.11,0.09,0.07,0.05,0.04,0.03,0.02,0.03,0.02,0.02,0.03,0.04,0.06,0.05,0.07,0.05,0.04,0.03,0.03,0.02,0.01,0.01,0.01,0.01,0.01,0.01,0.01,0.03,0.07,0.04,0.03,0.03,0.03,0.03,0.03,0.02,0.02,0.01,0.01,0.01,0.01,0.0,0.0,0.01,0.02,0.04,0.04,0.03,0.02,0.01,0.01,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.02,0.03,0.02,0.03,0.02,0.01,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.01,0.02,0.02,0.03,0.04,0.02,0.01,0.01,0.01,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.01,0.02,0.03,0.02,0.01,0.01 2 | -------------------------------------------------------------------------------- /results/png/MultCNN_bs12_Is256_test_F1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultCNN_bs12_Is256_test_F1.png -------------------------------------------------------------------------------- /results/png/MultCNN_bs12_Is256_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultCNN_bs12_Is256_test_acc.png -------------------------------------------------------------------------------- /results/png/MultCNN_bs12_Is256_test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultCNN_bs12_Is256_test_loss.png -------------------------------------------------------------------------------- /results/png/MultCNN_bs12_Is256_train_F1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultCNN_bs12_Is256_train_F1.png 
-------------------------------------------------------------------------------- /results/png/MultCNN_bs12_Is256_train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultCNN_bs12_Is256_train_acc.png -------------------------------------------------------------------------------- /results/png/MultCNN_bs12_Is256_train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultCNN_bs12_Is256_train_loss.png -------------------------------------------------------------------------------- /results/png/MultResNet_bs12_Is256_test_F1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultResNet_bs12_Is256_test_F1.png -------------------------------------------------------------------------------- /results/png/MultResNet_bs12_Is256_test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultResNet_bs12_Is256_test_acc.png -------------------------------------------------------------------------------- /results/png/MultResNet_bs12_Is256_test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultResNet_bs12_Is256_test_loss.png -------------------------------------------------------------------------------- /results/png/MultResNet_bs12_Is256_train_F1.png: 
# ---- /results/png/MultResNet_bs12_Is256_train_F1.png ----
# https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultResNet_bs12_Is256_train_F1.png
# ---- /results/png/MultResNet_bs12_Is256_train_acc.png ----
# https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultResNet_bs12_Is256_train_acc.png
# ---- /results/png/MultResNet_bs12_Is256_train_loss.png ----
# https://raw.githubusercontent.com/code-learn-er/YanyxGraduationProject/ceee4d175afe6bf274da8722c756a5912e2cd24f/results/png/MultResNet_bs12_Is256_train_loss.png
# ---- /run.py ----
"""Train every requested model, save its weights and export its metric curves.

For each name in ``--models`` the script builds ``model[name]`` (provided by
``from model import *``), runs ``train.train_test`` once per epoch, saves the
weights to ``weight/<name>.pth`` and writes one PNG plot plus one CSV row per
metric to ``results/png`` and ``results/csv``.

Example::

    python run.py --models MultResNet --batch_size 12 --epochs 5 --size 256 256
"""
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from model import *
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from dataset import get_data
import torch.optim as optim
from torch import nn
from train import train_test
import os
import csv
from dataset import MyDataset
from torch.utils.data import DataLoader

# from PyQt5.QtCore import QLibraryInfo
#
# os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = QLibraryInfo.location(
#     QLibraryInfo.PluginsPath
# )

# Default input resolution (H, W) and the per-stage block counts of the two
# branches. test.py imports these three names, so they must stay module-level.
size = (256, 256)
sx_model_index = [3, 4, 3, 3]
mx_model_index = [3, 4, 3, 3]


def _pick_device():
    """Return ``cuda:1`` when a second GPU exists, else ``cuda:0``, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # Experiments ran on GPU 1 of a multi-GPU machine; falling back to GPU 0
    # avoids an "invalid device ordinal" error on single-GPU machines.
    return torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")


device = _pick_device()

# (title, y-axis label) of each series in ``record``, in the order that
# train.train_test unpacks and appends to them.
_METRICS = (
    ("train_acc", "acc"),
    ("train_F1", "F1"),
    ("train_loss", "loss"),
    ("test_acc", "acc"),
    ("test_F1", "F1"),
    ("test_loss", "loss"),
)


def _result_stem(model_name, batch_size, image_size, title):
    """Return the shared file/plot name, e.g. ``MultCNN_bs12_Is256_test_F1``."""
    return "{0}_bs{1}_Is{2}_{3}".format(model_name, batch_size, image_size[0], title)


def _plot_metric(values, xlabel, ylabel, stem):
    """Plot one metric series and save it as ``results/png/<stem>.png``."""
    plt.plot(values)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(stem)
    plt.savefig(os.path.join("results", "png", stem + ".png"))
    plt.cla()


def _write_metric(values, stem):
    """Write one metric series, rounded to 2 decimals, as a single CSV row."""
    path = os.path.join("results", "csv", stem + ".csv")
    # newline="" is what the csv module requires to avoid blank lines on Windows.
    with open(path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f).writerow([round(value, 2) for value in values])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models", nargs="+", default=["MultResNet", "MultCNN"], type=str
    )
    parser.add_argument("--batch_size", default=12, type=int)
    parser.add_argument("--epochs", default=5, type=int)
    parser.add_argument("--lr", default=0.0005, type=float)
    # type=tuple would turn "256" into ('2', '5', '6'); read two ints instead.
    parser.add_argument(
        "--size", default=size, type=int, nargs=2, metavar=("H", "W")
    )
    opt = parser.parse_args()
    model_names = opt.models
    batch_size = opt.batch_size
    epochs = opt.epochs
    lr = opt.lr
    size = tuple(opt.size)

    root_dir = os.getcwd()
    weight_dir = os.path.join(root_dir, "weight")
    # Every directory written to below must exist before the first save.
    for directory in (
        weight_dir,
        os.path.join("results", "png"),
        os.path.join("results", "csv"),
    ):
        os.makedirs(directory, exist_ok=True)

    # Prepare the data.
    data = get_data(os.path.join(root_dir, "data"))
    # To PIL, resize, then scale each channel to [-1, 1].
    data_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    full_dataset = MyDataset(x=data, transform=data_transform)
    n_train = int(0.9 * len(full_dataset))  # 90 % train / 10 % test
    train_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset, [n_train, len(full_dataset) - n_train]
    )
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=False
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=True, drop_last=False
    )
    # TensorBoard writer.
    writer = SummaryWriter("logs")

    for model_name in model_names:
        xuxu = model[model_name](sx_model_index, mx_model_index).to(device)
        # To resume from weight/<model_name>.pth, load it here with
        # xuxu.load_state_dict(torch.load(path, map_location=device)).

        optimizer = optim.Adam(xuxu.parameters(), lr=lr)
        # NOTE(review): train_test steps the scheduler once per pass (10 per
        # epoch), so T_max=8 makes the LR climb back up within an epoch —
        # confirm this is intended.
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, eta_min=0, T_max=8, last_epoch=-1
        )
        # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
        criterion = nn.CrossEntropyLoss().to(device)
        # One fresh list per metric for this model.
        record = [[] for _ in _METRICS]

        for i in range(epochs):
            train_test(
                i,
                train_loader,
                test_loader,
                batch_size,
                data_transform,
                xuxu,
                optimizer,
                lr_scheduler,
                criterion,
                device,
                writer,
                record,
            )

        torch.save(xuxu.state_dict(), os.path.join(weight_dir, model_name + ".pth"))

        # NOTE(review): each point is one train_test pass (several per epoch),
        # although the x-axis is labelled "epochs".
        for (title, ylabel), values in zip(_METRICS, record):
            stem = _result_stem(model_name, batch_size, size, title)
            _plot_metric(values, "epochs", ylabel, stem)
            _write_metric(values, stem)

    writer.close()

# ---- /templates/index.html ----
# (HTML markup was stripped from this dump; the surviving text is the page
#  title "病症诊断和可解释性研究" — "disease diagnosis and interpretability research".)
64 |
65 |
66 |
67 | 68 |
69 | 70 | 71 |
72 |
73 |
74 |
75 |
76 |
77 | 78 |
79 | 80 | 81 |
82 |
83 |
84 |
85 |
86 |
87 | 88 |
89 |
90 |
91 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import * 3 | 4 | import torchvision.transforms as transforms 5 | from torch import nn 6 | import os 7 | import numpy as np 8 | from PIL import Image 9 | import cv2 10 | from CAM import GradCAM 11 | from run import sx_model_index, mx_model_index, size 12 | 13 | label = {0: "Feiyinxing", 1: "Yinxing"} 14 | # 数据转换器 15 | data_transform = transforms.Compose( 16 | [ 17 | transforms.ToPILImage(), 18 | transforms.Resize(size), 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 21 | ] 22 | ) 23 | root_dir = r"." 24 | batch_size = 1 25 | weight_dir = os.path.join(root_dir, "weight") 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | 28 | xuxu = MultResNet(sx_model_index, mx_model_index).to(device) 29 | model_name = "MultResNet" 30 | 31 | if weight_dir is not None and model_name + ".pth" in os.listdir(weight_dir): 32 | try: 33 | xuxu.load_state_dict(torch.load(os.path.join(weight_dir, model_name + ".pth"))) 34 | except: 35 | print( 36 | "****(= 7 ^ 7 =)---- {}: model structure has been changed!!! 
----(T^T)----".format( 37 | model_name 38 | ) 39 | ) 40 | else: 41 | print( 42 | "\/\/\/(= ^ _ ^ =)//// {}: model load successfully !!!".format(model_name) 43 | ) 44 | 45 | sx_layers = [ 46 | xuxu.features1[0], 47 | xuxu.features1[4][0].conv1, 48 | xuxu.features1[4][1].conv3, 49 | xuxu.features1[4][2].conv3, 50 | xuxu.features1[5][0].conv3, 51 | xuxu.features1[5][1].conv3, 52 | xuxu.features1[5][2].conv3, 53 | xuxu.features1[5][3].conv3, 54 | xuxu.features1[6][0].conv3, 55 | xuxu.features1[6][1].conv3, 56 | xuxu.features1[6][2].conv3, 57 | xuxu.classifier1[0][0].conv3, 58 | xuxu.classifier1[0][1].conv3, 59 | xuxu.classifier1[0][2].conv3, 60 | ] 61 | mx_layers = [ 62 | xuxu.features2[0], 63 | xuxu.features2[4][0].conv1, 64 | xuxu.features2[4][1].conv3, 65 | xuxu.features2[4][2].conv3, 66 | xuxu.features2[5][0].conv3, 67 | xuxu.features2[5][1].conv3, 68 | xuxu.features2[5][2].conv3, 69 | xuxu.features2[5][3].conv3, 70 | xuxu.features2[6][0].conv3, 71 | xuxu.features2[6][1].conv3, 72 | xuxu.features2[6][2].conv3, 73 | xuxu.classifier2[0][0].conv3, 74 | xuxu.classifier2[0][1].conv3, 75 | xuxu.classifier2[0][2].conv3, 76 | ] 77 | target_layers = [sx_layers, mx_layers] 78 | gradcam = GradCAM(xuxu, target_layers) 79 | 80 | 81 | def overlap(cam, image, size): 82 | image = np.array(image) 83 | image = cv2.resize(image, dsize=size, interpolation=cv2.INTER_CUBIC) 84 | cam = nn.functional.interpolate( 85 | cam, size=size, mode="bilinear", align_corners=False 86 | ) 87 | cam = cam.cpu().detach().numpy()[0, 0] 88 | cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) 89 | cam = np.uint8(255 * cam) 90 | cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET) 91 | 92 | result = cv2.addWeighted(cv2.cvtColor(image, cv2.COLOR_RGB2BGR), 0.5, cam, 0.5, 0) 93 | # result = cv2.resize(result, dsize=(512, 512), interpolation=cv2.INTER_CUBIC) 94 | return image, result 95 | 96 | 97 | def deal(images, size=(128, 128)): 98 | images = [image.resize(size, Image.ANTIALIAS) for image in images] 99 | x = 
[data_transform(np.array(image)) for image in images] 100 | x = [i.unsqueeze(0).to(device) for i in x] 101 | original_cam, class_idx = gradcam(x) #得到舌象的类激活图和大张目象的类激活图 102 | sxcams, mxcams = [original_cam[0]], [] 103 | for mxcam in original_cam[1]: 104 | cams = [] 105 | h, w = mxcam.shape[-2:] # 切割类激活图 106 | for i in range(0, h, h // 2): 107 | for j in range(0, w, w // 5): 108 | cams.append(mxcam[:, :, i : i + h // 2, j : j + w // 5]) 109 | mxcams.append(cams) 110 | tt = [[] for i in range(10)] #翻转类激活图 111 | for i in range(len(mxcams)): 112 | for j in range(len(mxcams[i])): 113 | tt[j].append(mxcams[i][j]) 114 | mxcams = tt 115 | cams = sxcams + mxcams 116 | for i in range(len(cams)): 117 | for j in range(len(cams[i])): 118 | images[i], cams[i][j] = overlap(cams[i][j], images[i], size=size) 119 | 120 | return cams, images, label[class_idx] 121 | 122 | 123 | if __name__ == "__main__": 124 | patient_path = r"patient/Feiyinxing/303" 125 | 126 | if not os.path.exists(weight_dir): 127 | os.makedirs(weight_dir) 128 | 129 | # 准备数据 130 | images = [Image.open(os.path.join(patient_path + ".jpg"))] + [ 131 | Image.open(os.path.join(patient_path, i)) for i in os.listdir(patient_path) 132 | ] 133 | # images = [image.resize((128, 128), Image.ANTIALIAS) for image in images] 134 | # x = [data_transform(np.array(image)) for image in images] 135 | 136 | # x = [i.unsqueeze(0).to(device) for i in x] 137 | 138 | # cams, class_idx = gradcam(x) 139 | results, images, label = deal(images, size=(64, 64)) 140 | # print(label[class_idx]) 141 | # 将 CAM 图像与原始图像叠加 142 | 143 | # for i in range(len(cams)): 144 | # images[i], cams[i] = overlap(cams[i], images[i]) 145 | # cv2.imshow("original",cv2.resize(np.array(images[1]), dsize=(512, 512), interpolation=cv2.INTER_CUBIC)) 146 | for i in range(len(results)): 147 | cv2.imshow("images" + str(i), images[i]) 148 | cv2.imshow("result" + str(i), results[i]) 149 | cv2.waitKey(0) 150 | 
# ---- /train.py ----
"""Training/evaluation loop and binary-classification metrics used by run.py."""
from sklearn.model_selection import KFold
import torch
from torch.utils.tensorboard import SummaryWriter
from model import MultCNN
import random

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import MyDataset, get_data
import torch.optim as optim
from torch import nn

# Small epsilon that keeps every metric division finite when a count is 0.
esp = 1e-6


def get_num_correct(preds, labels, TP_TN_FP_FN):
    """Add one batch's confusion-matrix counts to ``TP_TN_FP_FN`` in place.

    Class 0 is the positive class:
    - TP: predicted 0, label 0
    - TN: predicted 1, label 1
    - FP: predicted 0, label 1
    - FN: predicted 1, label 0

    Labels other than 0 or 1 are not counted.

    Args:
        preds: per-class scores; ``argmax(dim=1)`` gives the predicted class.
        labels: ground-truth class indices, one per row of ``preds``.
        TP_TN_FP_FN: list ``[TP, TN, FP, FN]``, updated in place.
    """
    predicted = preds.argmax(dim=1)
    pred_pos, pred_neg = predicted == 0, predicted == 1
    true_pos, true_neg = labels == 0, labels == 1
    # Vectorised counts instead of a Python loop over every sample.
    TP_TN_FP_FN[0] += int((pred_pos & true_pos).sum())
    TP_TN_FP_FN[1] += int((pred_neg & true_neg).sum())
    TP_TN_FP_FN[2] += int((pred_pos & true_neg).sum())
    TP_TN_FP_FN[3] += int((pred_neg & true_pos).sum())


def get_acc_F1(TP_TN_FP_FN):
    """Return ``(accuracy, F1)`` from ``[TP, TN, FP, FN]`` counts (class 0 positive)."""
    TP, TN, FP, FN = TP_TN_FP_FN
    precision = TP / (TP + FP + esp)
    recall = TP / (TP + FN + esp)
    F1 = 2 * precision * recall / (precision + recall + esp)
    acc = (TP + TN) / (TP + TN + FP + FN + esp)
    return acc, F1


def get_train_test_dataloader(
    data, batch_size, train_index, test_index, data_transform
):
    """Split ``data`` by the given indices and wrap each part in a DataLoader.

    Returns:
        ``(train_loader, test_loader)``; both shuffle and keep the last partial batch.
    """
    train_fold = torch.utils.data.dataset.Subset(data, train_index)
    test_fold = torch.utils.data.dataset.Subset(data, test_index)

    train_fold = MyDataset(x=train_fold, transform=data_transform)
    test_fold = MyDataset(x=test_fold, transform=data_transform)

    train_loader = DataLoader(
        dataset=train_fold, batch_size=batch_size, shuffle=True, drop_last=False
    )
    test_loader = DataLoader(
        dataset=test_fold, batch_size=batch_size, shuffle=True, drop_last=False
    )
    return train_loader, test_loader


def _train_pass(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Returns:
        ``(summed per-sample loss, [TP, TN, FP, FN])`` for the pass.
    """
    total_loss, counts = 0.0, [0, 0, 0, 0]
    model.train()
    for images, labels in loader:
        images = [image.to(device) for image in images]
        labels = labels.to(device)
        preds = model(images)
        loss = criterion(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Assumes reduction="mean" (nn.CrossEntropyLoss default, as built in
        # run.py): weight by batch size so sum / sample count is a per-sample mean.
        total_loss += loss.item() * labels.size(0)
        get_num_correct(preds, labels, counts)
    return total_loss, counts


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` over ``loader`` without gradients.

    Returns:
        ``(summed per-sample loss, [TP, TN, FP, FN])``.
    """
    total_loss, counts = 0.0, [0, 0, 0, 0]
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images = [image.to(device) for image in images]
            labels = labels.to(device)
            outputs = model(images)
            # Same batch-size weighting as in _train_pass.
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            get_num_correct(outputs, labels, counts)
    return total_loss, counts


def train_test(
    i, train_loader, test_loader, batch_size, data_transform,
    model, optimizer, lr_scheduler, criterion, device, writer, record, passes=10
):
    """Run ``passes`` train-then-evaluate rounds for epoch ``i`` and log metrics.

    Each round does the following:
    1. Trains one pass over ``train_loader``.
    2. Steps ``lr_scheduler``.
    3. Evaluates the freshly trained model on ``test_loader``.
    4. Prints the metrics and appends them to ``record``.

    Args:
        record: six lists, in this order: train acc, train F1, train loss,
            test acc, test F1, test loss. Accuracy and F1 are in percent.
        passes: number of train/evaluate rounds per call (previously fixed at 10).
        batch_size, data_transform, writer: currently unused; kept so existing
            callers keep working.
    """
    train_acculist, train_F1list, train_losslist, test_acculist, test_F1list, test_losslist = record
    for num in range(passes):
        # Train first, so every logged test metric describes the same weights as
        # the train metric next to it, and the final weights do get evaluated.
        print("训练开始.....")  # "training started"
        train_loss, train_TP_TN_FP_FN = _train_pass(
            model, train_loader, optimizer, criterion, device
        )
        lr_scheduler.step()

        print("验证开始.....")  # "validation started"
        test_loss, test_TP_TN_FP_FN = _evaluate(model, test_loader, criterion, device)

        train_size, test_size = sum(train_TP_TN_FP_FN), sum(test_TP_TN_FP_FN)
        train_acc, train_F1 = get_acc_F1(train_TP_TN_FP_FN)
        test_acc, test_F1 = get_acc_F1(test_TP_TN_FP_FN)
        print(
            "epoch {} num {}: \ntrain_acc: {:.2f}%\ntrain_F1: {:.2f}%\ntrain_loss: {:.2f}\ntest_acc: {:.2f}%\ntest_F1: {:.2f}%\ntest_loss: {:.2f}\n ".format(
                i,
                num,
                train_acc * 100,
                train_F1 * 100,
                train_loss / (train_size + esp),
                test_acc * 100,
                test_F1 * 100,
                test_loss / (test_size + esp),
            )
        )
        train_acculist.append(train_acc * 100)
        train_F1list.append(train_F1 * 100)
        train_losslist.append(train_loss / (train_size + esp))
        test_acculist.append(test_acc * 100)
        test_F1list.append(test_F1 * 100)
        test_losslist.append(test_loss / (test_size + esp))


if __name__ == "__main__":
    root_dir = r"/media/codelearner/E2EE175BEE1726F7/Users/QuickLearner/Documents/python/graduationProject/data"

    # Prepare the data.
    data = get_data(root_dir)
    # To PIL, resize to 128x128, then to a tensor.
    data_transform = transforms.Compose(
        [transforms.ToPILImage(), transforms.Resize((128, 128)), transforms.ToTensor()]
    )
    # TensorBoard writer.
    writer = SummaryWriter("logs")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): run.py builds models as model[name](sx_model_index,
    # mx_model_index); confirm MultCNN still accepts a device argument.
    model = MultCNN(device).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
    criterion = nn.CrossEntropyLoss().to(device)

    # train_test expects ready-made loaders plus a record; use the same 90/10
    # random split as run.py.
    batch_size = 4
    full_dataset = MyDataset(x=data, transform=data_transform)
    n_train = int(0.9 * len(full_dataset))
    train_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset, [n_train, len(full_dataset) - n_train]
    )
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=False
    )
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False
    )
    record = [[] for _ in range(6)]

    epoch = 10
    for i in range(epoch):
        train_test(
            i,
            train_loader,
            test_loader,
            batch_size,
            data_transform,
            model,
            optimizer,
            lr_scheduler,
            criterion,
            device,
            writer,
            record,
        )

    writer.close()