├── README.md
├── main.py
├── requirements.txt
└── src
    ├── __pycache__
    │   ├── model.cpython-312.pyc
    │   └── train.cpython-312.pyc
    ├── cat_to_name.json
    ├── dsa
    ├── evaluate.py
    ├── model.py
    └── train.py


/README.md:
--------------------------------------------------------------------------------
# Flower Classification using CNN

This project implements a flower classification model using classic convolutional neural network (CNN) architectures. The model is trained to classify images into 102 flower categories. An ImageNet-pretrained network (ResNet-152 by default; AlexNet, VGG-16, SqueezeNet, DenseNet-121 and Inception v3 are also supported) is fine-tuned for this task.

## Project Overview

This project uses transfer learning for flower classification. A pre-trained CNN is adapted by replacing its final classifier with a new 102-way output layer, which is then trained on flower images. The goal is a model that recognizes flower species accurately on images it did not see during training.

### Features:
- Image preprocessing and augmentation (random rotation, horizontal/vertical flipping, color jitter, random grayscale)
- Fine-tuning of classic CNN architectures (ResNet, AlexNet, VGG, SqueezeNet, DenseNet, Inception)
- Per-epoch tracking of training and validation loss and accuracy
- Saving the best-performing model (by validation accuracy) to a checkpoint and reloading it
- Predicting flower species for new images

## Dataset

The dataset contains 102 flower categories, each with multiple images. Images are stored in one folder per category, named by numeric category ID. `src/cat_to_name.json` maps these IDs to flower names. Images are preprocessed and standardized before being fed into the network.

- **Training set** (`flower_data/train`): used to train the model.
- **Validation set** (`flower_data/valid`): used to evaluate the model after every epoch and to select the best checkpoint.
- **Test set**: intended for evaluating the final model. The current scripts load only `train/` and `valid/`.

### Preprocessing:
- Resizing and cropping images to the model's 224×224 input size. Validation images are resized to 256 px on the shorter side and then center-cropped. Training images are center-cropped after a random rotation.
- Scaling pixel values to the [0, 1] range (`ToTensor`), then standardizing each channel with the ImageNet mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`.
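- Applying data augmentation to the training set only, for better generalization: random rotation within ±45°, horizontal and vertical flips (p=0.5 each), color jitter, and grayscale conversion (p=0.025).

## Model Architecture

We use transfer learning with pre-trained CNN architectures such as:
- **ResNet** (ResNet-152, the default in `main.py`)
- **VGG** (VGG-16)
- **Inception** (Inception v3, which expects 299×299 inputs)

Each model is adapted to this task by replacing its final fully connected (classifier) layer with one that outputs 102 flower categories. For ResNet, the new head is a `Linear` layer followed by `LogSoftmax`.

## Training

### Steps:
1. Load the pre-trained model (ResNet, VGG, etc.).
2. Freeze all pre-trained layers and train only the newly added final layer (`feature_extract=True` in `main.py`).
3. Set up a loss function and an optimizer. `main.py` uses Adam with `nn.NLLLoss` applied to the `LogSoftmax` output of the ResNet head; together these are equivalent to cross-entropy loss. The other heads in `src/model.py` output raw scores, so they need `nn.CrossEntropyLoss` instead.
4. Apply a learning rate scheduler. `StepLR` multiplies the learning rate by 0.1 every 7 epochs.
5. Train the model on the training set, evaluating it on the validation set after every epoch.
6. Save the model with the best validation accuracy to `CNNFlower.pth`.

### Hyperparameters:
- **Optimizer**: Adam (`lr=1e-2` in `main.py`); SGD can be used instead
- **Learning rate**: decayed by a factor of 10 every 7 epochs (`StepLR(step_size=7, gamma=0.1)`)
- **Loss function**: cross-entropy, implemented as `nn.NLLLoss` on log-probabilities
- **Batch size**: 8
- **Epochs**: 10

## How to Use

1. Clone the repository:
   `git clone https://github.com/wdwcwe/CNNClassifier.git`
2. Install dependencies:
   `pip install -r requirements.txt`
3. Train the model:
   Edit the settings in `main.py` (`data_dir`, `model_name`, `feature_extract`, `batch_size`) to match your dataset and model, then run:
   `python main.py`
4. Evaluate the model:
   `src/evaluate.py` reads `../CNNFlower.pth`, `../flower_data` and `cat_to_name.json` relative to the working directory, and it imports `src.model`. Run it from inside `src/` with the repository root on `PYTHONPATH`, for example in a Unix shell:
   `cd src && PYTHONPATH=.. python evaluate.py`
   It plots eight validation images titled "predicted (true)". A title is green when the prediction is correct and red otherwise.
5. Predict with the trained model:
   Load `CNNFlower.pth` into the model as `src/evaluate.py` does. `process_image()` in `src/evaluate.py` applies the validation preprocessing to a single image file.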
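
When you turn predictions into flower names, note how `torchvision.datasets.ImageFolder` numbers classes. It assigns indices after sorting the class folder names as strings, and exposes the sorted list as `dataset.classes`. A predicted index `i` therefore refers to the folder `classes[i]`, not to the folder named `str(i)`. The sketch below is a minimal illustration under that assumption. It uses a hypothetical helper, `index_to_flower_name` (not part of this repository), and a four-entry excerpt of `src/cat_to_name.json`. With the real data you would pass the `classes` list of the training `ImageFolder` instead of a hand-written folder list.

```python
# Hypothetical helper (not part of this repository): turn a predicted ImageFolder
# class index into a flower name using the mapping in src/cat_to_name.json.
def index_to_flower_name(index, class_folders, cat_to_name):
    # ImageFolder assigns indices after sorting the class folder names as strings,
    # so "10" and "100" come before "2"; index i is NOT the folder named str(i).
    classes = sorted(class_folders)
    return cat_to_name[classes[index]]


# Small excerpt of src/cat_to_name.json; one folder per category ID.
cat_to_name = {
    "1": "pink primrose",
    "2": "hard-leaved pocket orchid",
    "10": "globe thistle",
    "100": "blanket flower",
}
folders = ["1", "2", "10", "100"]

for i in range(len(folders)):
    print(i, index_to_flower_name(i, folders, cat_to_name))
# 0 pink primrose
# 1 globe thistle
# 2 blanket flower
# 3 hard-leaved pocket orchid
```

Indexing `cat_to_name` with `str(i)` directly would label index 1 as "pink primrose" and fail with a `KeyError` for index 0. This is why the sorted folder list has to sit between the model output and the JSON lookup.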
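
The learning-rate schedule listed under Hyperparameters can be worked out ahead of time. This sketch assumes `scheduler.step()` is called exactly once per epoch with no arguments, which is the usage `StepLR` expects. Under that assumption, the rate used during epoch `e` (counting from 0) is `base_lr * gamma ** (e // step_size)`, which is StepLR's closed form. `step_lr_schedule` is a hypothetical helper, not a PyTorch function.

```python
# Hypothetical helper (not part of PyTorch or this repository): the learning rate
# StepLR gives each epoch, via its closed form base_lr * gamma ** (epoch // step_size).
def step_lr_schedule(base_lr, gamma, step_size, num_epochs):
    return [base_lr * gamma ** (epoch // step_size) for epoch in range(num_epochs)]


# Values from main.py: Adam with lr=1e-2, StepLR(step_size=7, gamma=0.1), 10 epochs.
for epoch, lr in enumerate(step_lr_schedule(1e-2, 0.1, 7, 10)):
    print("epoch {}: lr = {:.4g}".format(epoch, lr))
# epochs 0-6 print lr = 0.01, epochs 7-9 print lr = 0.001
```

With 10 epochs the decay affects only epochs 7 to 9. A longer run would see a second decay, to 1e-4, starting at epoch 14.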

## Results

![img.png](img.png)
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
import os
import torch
from torch import nn
import torch.optim as optim

from torchvision import transforms, models, datasets


from src.model import initialize_model
from src.train import train_model

# Build a CNN for flower classification using transfer learning on a ResNet backbone
# 1. Data location
data_dir = './flower_data/'
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
# 2. Data preprocessing
# 2.1 Data augmentation: add diversity to the training data to improve generalization,
#     and normalize the image tensors
data_transforms = {
    # Training data
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation by an angle in [-45, 45] degrees
        transforms.CenterCrop(224),  # crop the 224x224 center of the image
        transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip with probability 0.5
        transforms.RandomVerticalFlip(p=0.5),  # vertical flip with probability 0.5
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # brightness, contrast, saturation, hue
        transforms.RandomGrayscale(p=0.025),  # grayscale with probability 0.025 (3 channels, R=G=B)
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean, std
    ]),
    # Validation data
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# 2.2 Build the data loaders with DataLoader
batch_size = 8  # batch size
# Load each split with ImageFolder, applying the matching transforms from data_transforms
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
# Build the data loaders
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size, shuffle=True) for x in ['train', 'valid']}
datasets_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}

model_name = 'resnet'  # options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# Freeze the pretrained weights and train only the new classifier head
feature_extract = True
# Train on the GPU if one is available
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU...')
else:
    print('CUDA is available. Training on GPU...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
model_ft = model_ft.to(device)  # move the model to the GPU (or CPU)
# print(model_ft)
filename = 'CNNFlower.pth'
# List the trainable parameters; with feature_extract=True this should be
# only the newly initialized fully connected layer
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
# Optimizer and loss: passing params_to_update means only the trainable layers are updated
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Every 7 epochs, multiply the learning rate by 0.1
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The ResNet head in src/model.py already ends with LogSoftmax(), so use nn.NLLLoss() instead of
# nn.CrossEntropyLoss() (which combines LogSoftmax() and nn.NLLLoss()). The other heads output
# raw scores and would need nn.CrossEntropyLoss() instead.
criterion = nn.NLLLoss()
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(
    model_ft, dataloaders, criterion, optimizer_ft, scheduler, device, 10, model_name == "inception", filename)
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
# Python version requirement: 3.12.4
# (pip cannot install the interpreter itself, so this is not an installable requirement)
# python==3.12.4

# PyTorch and its dependencies
# The +cu118 builds are served from the PyTorch package index, not from PyPI
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1+cu118  # CUDA 11.8 build
torchvision==0.15.2+cu118  # torchvision version matching this PyTorch version
torchaudio==2.0.2+cu118  # optional audio support library

# Alternatively, use the CPU builds if you do not need CUDA support:
# torch==2.0.1+cpu
# torchvision==0.15.2+cpu
# torchaudio==2.0.2+cpu

# Image processing
Pillow==9.2.0
matplotlib==3.7.2  # plotting
numpy==1.24.2  # numerical computing

# Data processing and augmentation
scikit-learn==1.3.0  # machine learning utilities
pandas==2.1.1  # data analysis
--------------------------------------------------------------------------------

/src/__pycache__/model.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdwcwe/CNNClassifier/e7007dce3ab50c2974101b6136ec931fbfad6f82/src/__pycache__/model.cpython-312.pyc
--------------------------------------------------------------------------------

/src/__pycache__/train.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdwcwe/CNNClassifier/e7007dce3ab50c2974101b6136ec931fbfad6f82/src/__pycache__/train.cpython-312.pyc
--------------------------------------------------------------------------------

/src/cat_to_name.json:
--------------------------------------------------------------------------------
{"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil",
"22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"}
--------------------------------------------------------------------------------

/src/dsa:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/src/evaluate.py:
--------------------------------------------------------------------------------
import json
import os

# Allow duplicate OpenMP runtimes to load (works around "OMP: Error #15");
# set before importing NumPy/PyTorch so it is in place when OpenMP initializes
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder

from src.model import initialize_model


def process_image(image_path):
    # Read a single image to predict on, as 3-channel RGB
    img = Image.open(image_path).convert('RGB')
    # Resize so the shorter side is 256 px; thumbnail keeps the aspect ratio and
    # can only shrink, so the longer side is bounded by a large value
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))
    # Center-crop to 224x224 so every input has the same size
    left_margin = (img.width - 224) / 2
    top_margin = (img.height - 224) / 2
    right_margin = left_margin + 224
    bottom_margin = top_margin + 224
    img = img.crop((left_margin, top_margin, right_margin, bottom_margin))
    # Same normalization as the validation transforms
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])  # ImageNet mean
    std = np.array([0.229, 0.224, 0.225])  # ImageNet std
    img = (img - mean) / std

    # PyTorch expects the color channel first: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img


def im_convert(tensor):
    """Convert a normalized (C, H, W) image tensor back to a displayable (H, W, C) array."""

    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    # Undo the normalization
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image


# Paths are relative to the src/ directory
data_dir = '../flower_data/'
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
model_name = 'resnet'  # options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# Same setting as main.py (freeze the pretrained backbone); it has no effect on inference
feature_extract = True
# Use the GPU if one is available
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Running on CPU...')
else:
    print('CUDA is available. Running on GPU...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transforms = {
    # Training data
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation by an angle in [-45, 45] degrees
        transforms.CenterCrop(224),  # crop the 224x224 center of the image
        transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip with probability 0.5
        transforms.RandomVerticalFlip(p=0.5),  # vertical flip with probability 0.5
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # brightness, contrast, saturation, hue
        transforms.RandomGrayscale(p=0.025),  # grayscale with probability 0.025 (3 channels, R=G=B)
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean, std
    ]),
    # Validation data
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# 2.2 Build the data loaders with DataLoader
batch_size = 8  # batch size
# Load each split with ImageFolder, applying the matching transforms from data_transforms
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
# Build the data loaders
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size, shuffle=True) for x in ['train', 'valid']}
img_dataset = ImageFolder(train_dir)
# Class folder names, in the order ImageFolder uses for label indices
class_names = img_dataset.classes
# 3. Load the model
model, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU (or CPU)
model = model.to(device)
filename = "../CNNFlower.pth"
# Load the checkpoint saved by train_model; map_location lets a GPU checkpoint load on CPU
checkpoint = torch.load(filename, map_location=device)
best_acc = checkpoint['best_acc']
model.load_state_dict(checkpoint['state_dict'])  # load the trained weights into the model
# Take one batch from the validation loader and predict on it
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)
model.eval()
with torch.no_grad():
    output = model(images.to(device))
# Predicted class indices
_, preds_tensor = torch.max(output, 1)
preds = np.squeeze(preds_tensor.cpu().numpy())
# Plot the images with their predicted and true names
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2
with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # ImageFolder indices follow the sorted folder names: map index -> folder name -> flower name
    pred_name = cat_to_name[class_names[preds[idx]]]
    true_name = cat_to_name[class_names[labels[idx].item()]]
    ax.set_title("{} ({})".format(pred_name, true_name),
                 color=("green" if pred_name == true_name else "red"))
plt.show()
--------------------------------------------------------------------------------

/src/model.py:
--------------------------------------------------------------------------------
from torch import nn
from torchvision import models


# Freeze all parameters when feature extracting, so only the newly added head is trained
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


# Initialize the model
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Build the requested architecture; each one needs a slightly different head replacement.
    # Inputs: model name, number of classes, whether to freeze the pretrained weights.
    # Returns: the model instance and its expected input image size.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet-152
        model_ft = models.resnet152(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        # The head outputs log-probabilities, to be paired with nn.NLLLoss
        model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes),
                                    nn.LogSoftmax(dim=1))
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG-16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3
        # Be careful: it expects (299, 299) sized images and has an auxiliary output
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        raise ValueError("Invalid model name: {}".format(model_name))

    return model_ft, input_size
--------------------------------------------------------------------------------

/src/train.py:
--------------------------------------------------------------------------------
import copy
import time

import torch


def train_model(model, dataloaders, criterion, optimizer, scheduler, device, num_epoch, is_inception, filename):
    since = time.time()
    best_acc = 0  # best validation accuracy so far
    model.to(device)
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]  # record the initial learning rate
    best_model_wts = copy.deepcopy(model.state_dict())  # copy of the best model weights
    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch, num_epoch - 1))
        print('-' * 10)
        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0  # sum of the losses over all batches in this phase
            running_corrects = 0  # number of correct predictions over all batches in this phase
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Zero the gradients
                optimizer.zero_grad()
                # Track gradients only in the training phase, to save memory during validation
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        # Inception v3 returns main and auxiliary outputs in training mode
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            time_elapsed = time.time() - since
            print("Time elapsed {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # Step the scheduler once per epoch. ReduceLROnPlateau needs the monitored
                # metric; epoch-based schedulers such as StepLR (used in main.py) take no
                # argument, and a loss passed to them would be misread as an epoch number.
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(epoch_loss)
                else:
                    scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
    time_elapsed = time.time() - since
    # Print the total training time
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)  # restore the weights with the best validation accuracy
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
--------------------------------------------------------------------------------