# -*- coding: utf-8 -*-
"""
Created on Thu May 14 20:01:53 2020

resnet18.py -- classify images with a pretrained ResNet model.

Dataset: hymenoptera_data
Download: https://download.pytorch.org/tutorial/hymenoptera_data.zip

@author:
"""
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models

# torchvision.models provides many pretrained models (resnet, vgg, alexnet, ...).
data_dir = "./datasets/hymenoptera_data"
model_name = "resnet"
num_classes = 2
batch_size = 32
num_epochs = 8
input_size = 224
lr = 1e-3
momentum = 0.9
is_fixed = True
use_pretrained = True
is_train = True
is_test = True

# Per-channel statistics that torchvision's ImageNet-pretrained models were
# trained with; inputs to those models should be normalized the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Evaluation
def test(model, test_loader, loss_func):
    """Evaluate ``model`` on ``test_loader`` and return its accuracy.

    Switches the model to eval mode, runs every batch without gradient
    tracking, prints the loss and accuracy averaged over
    ``len(test_loader.dataset)``, and returns the accuracy.
    """
    model.eval()
    loss_val = 0.0
    corrects = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = loss_func(outputs, labels)

            # Index of the largest output along dim 1 is the predicted class.
            _, predicts = torch.max(outputs, 1)

            # Scale by the batch size so the final division by the dataset
            # size gives a per-sample average (assumes loss_func returns the
            # batch mean, as the default nn.CrossEntropyLoss used below does).
            loss_val += loss.item() * images.size(0)
            corrects += torch.sum(predicts.view(-1) == labels.view(-1)).item()

    test_loss = loss_val / len(test_loader.dataset)
    test_acc = corrects / len(test_loader.dataset)

    print("Test Loss: {}, Test Acc: {}".format(test_loss, test_acc))

    return test_acc


# Training
def train(model, train_loader, test_loader, loss_func, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs and return the best version.

    After every epoch the model is evaluated with :func:`test`; the
    parameters of the epoch with the highest test accuracy are kept and
    loaded back into ``model`` before it is returned.
    """
    best_val_acc = 0.0
    # deepcopy so that later optimizer steps cannot modify the snapshot.
    best_model_params = copy.deepcopy(model.state_dict())
    for _ in range(num_epochs):
        model.train()
        loss_val = 0.0
        corrects = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = loss_func(outputs, labels)

            # Position of the largest output is the predicted class, e.g. in
            # the two-class case a maximum at index 0 means prediction 0.
            _, predicts = torch.max(outputs, 1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss.item() is one value per batch; multiplying by the number
            # of samples in the batch (images.size(0)) accumulates a total
            # that is divided by the dataset size below.
            loss_val += loss.item() * images.size(0)

            # view(-1) flattens both tensors so they compare element-wise;
            # the sum counts predictions equal to their labels.
            corrects += torch.sum(predicts.view(-1) == labels.view(-1)).item()

        # Per-sample average loss and accuracy for this epoch.
        train_loss = loss_val / len(train_loader.dataset)
        train_acc = corrects / len(train_loader.dataset)

        print("Train Loss: {}, Train Acc: {}".format(train_loss, train_acc))

        # Evaluate and keep the parameters of the best epoch so far.
        test_acc = test(model, test_loader, loss_func)
        if best_val_acc < test_acc:
            best_val_acc = test_acc
            best_model_params = copy.deepcopy(model.state_dict())

    # Restore the best parameters seen during training.
    model.load_state_dict(best_model_params)
    return model


def set_parameters_require_grad(model, is_fixed):
    """Freeze every parameter of ``model`` when ``is_fixed`` is true.

    Parameters have ``requires_grad = True`` by default; when the pretrained
    weights are to stay fixed, gradient computation is switched off for them.
    Nothing is changed when ``is_fixed`` is false.
    """
    if is_fixed:
        for parameter in model.parameters():
            parameter.requires_grad = False


def init_model(model_name, num_classes, is_fixed, use_pretrained):
    """Build the network named ``model_name`` with a ``num_classes`` output.

    Only ``'resnet'`` (ResNet-18) is supported.

    Raises:
        ValueError: if ``model_name`` is not a supported model name.
    """
    if model_name != 'resnet':
        # Previously this path fell through to ``return model`` with ``model``
        # unbound and failed with an opaque UnboundLocalError.
        raise ValueError("Unsupported model_name: {!r}".format(model_name))

    # pretrained=True loads the pretrained weights, False leaves the model
    # randomly initialised.
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; kept as-is for compatibility with older versions.
    model = models.resnet18(pretrained=use_pretrained)

    # is_fixed=True  -> backbone parameters are frozen (no gradients);
    # is_fixed=False -> the whole network is fine-tuned.
    set_parameters_require_grad(model, is_fixed)

    # Input feature size of the existing fully connected layer.
    in_features = model.fc.in_features

    # Replace the fully connected layer to match the new task. The new layer
    # is created after the freeze above, so it requires gradients by default.
    model.fc = nn.Linear(in_features, num_classes)

    return model


# Load the data and apply preprocessing.
# The dataset is already laid out in a form ImageFolder can read.
def get_datasets(data_dir, input_size, is_train_data, normalize=True):
    """Return the ``train`` or ``val`` ImageFolder dataset under ``data_dir``.

    Args:
        data_dir: directory containing the ``train`` and ``val`` folders.
        input_size: side length of the crop handed to the model.
        is_train_data: true selects ``train`` with random crop/flip
            augmentation; false selects ``val`` with resize + center crop.
        normalize: when true (default), normalize the tensors with the
            ImageNet mean/std that the pretrained weights expect. Pass
            ``False`` to reproduce the earlier un-normalized inputs, e.g.
            for checkpoints that were trained that way.
    """
    if is_train_data:
        subdir = "train"
        steps = [
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    else:
        subdir = "val"
        steps = [
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
        ]
    if normalize:
        steps.append(transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD))

    return datasets.ImageFolder(os.path.join(data_dir, subdir),
                                transforms.Compose(steps))


# Image display helper (disabled). If re-enabled while normalization is on,
# the displayed colours will be those of the normalized tensor.
# unloader = transforms.ToPILImage()
#
# def image_show(tensor):
#     image = tensor.cpu().clone()
#     image = image.squeeze(0)
#     image = unloader(image)
#     plt.figure()
#     plt.imshow(image)


# Collect the parameters the optimizer should update.
def get_require_updated_params(model, is_fixed):
    """Return the parameters of ``model`` that should be optimized.

    With ``is_fixed`` true, only parameters whose ``requires_grad`` is set
    are returned (as a list); otherwise ``model.parameters()`` is returned.
    """
    if is_fixed:
        return [param for param in model.parameters() if param.requires_grad]
    return model.parameters()


train_images = get_datasets(data_dir, input_size, is_train_data=True)
test_images = get_datasets(data_dir, input_size, is_train_data=False)

train_loader = torch.utils.data.DataLoader(train_images, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_images, batch_size=batch_size)

#image = next(iter(train_loader))[0]
#image_show(image[1])


model = init_model(model_name, num_classes, is_fixed, use_pretrained)
model = model.to(device)

require_update_params = get_require_updated_params(model, is_fixed)

# Hand the parameters that need updating to the optimizer.
optimizer = torch.optim.SGD(require_update_params, lr=lr, momentum=momentum)
# Cross-entropy loss.
loss_func = nn.CrossEntropyLoss()

if is_train:
    model = train(model, train_loader, test_loader, loss_func, optimizer, num_epochs)
    torch.save(model.state_dict(), "resnet.pt")
if is_test:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    model.load_state_dict(torch.load("resnet.pt", map_location=device))
    acc = test(model, test_loader, loss_func)
    print("Best Test Acc: {}".format(acc))