├── 1.jpg ├── 2.jpg ├── Model.py ├── README.md ├── __pycache__ ├── Model.cpython-36.pyc ├── config.cpython-36.pyc ├── datasets.cpython-36.pyc └── test.cpython-36.pyc ├── config.py ├── dandelion'.jpg ├── datasets.py ├── example ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg └── dandelion'.jpg ├── input_data.rar ├── logs ├── Vgg.txt └── myModel.txt ├── predict_gui.py ├── resnet.txt ├── test.py ├── train.py └── utils ├── __pycache__ └── utils.cpython-36.pyc └── utils.py /1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/1.jpg -------------------------------------------------------------------------------- /2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/2.jpg -------------------------------------------------------------------------------- /Model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from torch import nn,optim 5 | from config import config 6 | 7 | 8 | 9 | def l2_norm(x): 10 | norm = torch.norm(x,p =2 ,dim =1 ,keepdim= True) 11 | x = torch.div(x,norm) 12 | return x 13 | # 自己搭建一个简单卷积神经网络 14 | class myModel(nn.Module): 15 | def __init__(self,num_classes): 16 | super(myModel,self).__init__() 17 | self.layer1 = nn.Sequential( 18 | nn.Conv2d(3,16,3), #in_channels out_channels kernel_size 19 | nn.BatchNorm2d(16), 20 | nn.ReLU(True), 21 | nn.MaxPool2d(kernel_size= 2,stride = 2) #149 22 | ) 23 | self.layer2 = nn.Sequential( 24 | nn.Conv2d(16,32,3,2), #74 # 25 | nn.BatchNorm2d(32), 26 | nn.ReLU(True), 27 | nn.MaxPool2d(kernel_size =2,stride=2) #37 28 | 29 | ) 30 | self.layer3 = nn.Sequential( 31 | nn.Conv2d(32,32,3,2), #18 32 | nn.BatchNorm2d(32), 33 | nn.ReLU(True), 34 | 
nn.MaxPool2d(kernel_size= 2, stride = 2) #9 35 | ) 36 | self.fc1 = nn.Sequential( 37 | nn.Linear(2592,120), 38 | nn.ReLU(True) 39 | ) 40 | self.fc2 = nn.Sequential( 41 | nn.Linear(120,84), 42 | nn.ReLU(True), 43 | nn.Linear(84,num_classes) 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.layer1(x) 48 | x = self.layer2(x) 49 | x = self.layer3(x) 50 | x = x.view(x.size(0),-1) 51 | x = self.fc1(x) 52 | x = self.fc2(x) 53 | return x 54 | 55 | 56 | 57 | class ResNet18(nn.Module): 58 | def __init__(self,model,num_classes = 1000): 59 | super(ResNet18,self).__init__() 60 | self.backbone = model 61 | 62 | self.fc1 = nn.Linear(512,1024) 63 | self.dropout = nn.Dropout(0.5) 64 | self.fc2 = nn.Linear(1024,num_classes) 65 | def forward(self, x): 66 | x = self.backbone.conv1(x) 67 | x= self.backbone.bn1 (x) 68 | x = self.backbone.relu(x) 69 | x= self.backbone.maxpool(x) 70 | 71 | x= self.backbone.layer1(x) 72 | x = self.backbone.layer2(x) 73 | x= self.backbone.layer3(x) 74 | x = self.backbone.layer4(x) 75 | 76 | x = self.backbone.avgpool(x) 77 | 78 | x= x.view(x.size(0),-1) 79 | x= l2_norm(x) 80 | x = self.dropout(x) 81 | x = self.fc1(x) 82 | x = l2_norm(x) 83 | x = self.dropout(x) 84 | x = self.fc2(x) 85 | return x 86 | class ResNet101(nn.Module): 87 | def __init__(self,model,num_classes =1000): 88 | super(ResNet101,self).__init__() 89 | self.backbone = model 90 | 91 | self.fc1 = nn.Linear(2048,2048) 92 | self.dropout = nn.Dropout(0.5) 93 | self.fc2 = nn.Linear(2048,num_classes) 94 | 95 | def forward(self,x): 96 | x = self.backbone.conv1(x) 97 | x = self.backbone.bn1(x) 98 | x = self.backbone.relu(x) 99 | x = self.backbone.maxpool(x) 100 | 101 | x = self.backbone.layer1(x) 102 | x = self.backbone.layer2(x) 103 | x = self.backbone.layer3(x) 104 | x = self.backbone.layer4(x) 105 | 106 | x = self.backbone.avgpool(x) 107 | 108 | x = x.view(x.size(0),-1) 109 | x = l2_norm(x) 110 | x = self.dropout(x) 111 | x = self.fc1(x) 112 | x = l2_norm(x) 113 | x = self.dropout(x) 114 | x = 
self.fc2(x) 115 | 116 | return x 117 | 118 | def get_net(): 119 | #backbone = torchvision.models.resnet18(pretrained=True) 120 | #models = ResNet18(backbone,config.num_classes) 121 | backbone = torchvision.models.resnet101(pretrained=True) 122 | models = ResNet101(backbone, config.num_classes) 123 | return models -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-image-classification 2 | 1、项目介绍: 3 | -------- 4 | 适合小白入门的图像分类项目,从熟悉到熟练图像分类的流程,搭建自己的分类网络结构以及在 pytorch 中运用经典的分类网络。利用gui图形化界面进行测试单张图片。代码注释清楚,很容易理解。详细介绍可以访问[`我的博客`](https://blog.csdn.net/weixin_43962659/article/details/103381731) 5 | 6 | 2、环境: 7 | ----- 8 | * pytorch 1.2.0 9 | * python3 以上 10 | * wxpython :安装方式 conda install wxpython 11 | * opencv-python 12 | 13 | 3、数据准备: 14 | ----- 15 | 下载数据集即四类花的分类,然后解压放到文件夹data里。 16 | 文件夹树结构: 17 | * ./pytorch-image-classification 18 | * data 19 | * input_data 20 | * example 21 | * checkpoints 22 | * logs 23 | * utils 24 | * train.py Model.py READEME.md 等根目录的文件 25 | 26 | 4、快速开始: 27 | --------- 28 | 下载本项目:git clone https://github.com/lilei1128/pytorch-image-classification.git 29 | 修改config.py中设置的路径和其他参数。 30 | 运行train.py进行训练,也可自行修改model中的网络结构再训练。 31 | 最后进行测试。 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /__pycache__/Model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/__pycache__/Model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/__pycache__/datasets.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/__pycache__/test.cpython-36.pyc -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Project-wide settings; adjust them to your own setup. 2 | class MyConfigs(): 3 | 4 | data_folder = './data/flowers/' 5 | test_data_folder = "" 6 | model_name = "resnet" #Vgg ResNet152 myModel (also names the checkpoint: weights + model_name + '.pth') 7 | weights = "./checkpoints/" 8 | logs = "./logs/" 9 | example_folder = "./example/" 10 | freeze = True 11 | # 12 | epochs = 300 13 | batch_size = 16 14 | img_height = 227 # network input height and width 15 | img_width = 227 16 | num_classes = 20 # NOTE(review): datasets.get_files and test.classes use 4 flower classes (labels 0-3); the extra outputs go unused. Changing this invalidates existing checkpoints. 17 | lr = 1e-2 18 | lr_decay = 1e-4 19 | weight_decay = 2e-4 20 | ratio = 0.2 21 | config = MyConfigs() 22 | -------------------------------------------------------------------------------- /dandelion'.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/dandelion'.jpg -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 |
import torch 2 | from torch.utils.data import Dataset,DataLoader 3 | import cv2 4 | import os 5 | from tqdm import tqdm 6 | from config import config 7 | from glob import glob 8 | import os 9 | from torchvision import transforms 10 | import numpy as np 11 | import random 12 | from shutil import copy 13 | from PIL import Image 14 | import math 15 | 16 | np.random.seed(666) #设置随机种子 为了保证每次划分训练集和测试机的是相同的 17 | 18 | 19 | ''' 20 | # 1. 对于mini_data 数据集的解析 21 | def parse_data_config(data_path): 22 | files = [] 23 | # 24 | for img in os.listdir(data_path): 25 | image = data_path + img 26 | label = img.split("__")[0][3:] 27 | files.append((image,label)) 28 | return files 29 | 30 | #划分训练集和测试集 31 | # ratio 为划分为测试集的比例 32 | def divide_data(data_path,ratio): 33 | files = parse_data_config(data_path) 34 | temp = np.array(files) 35 | test_data = [] 36 | train_data = [] 37 | for i in range(config.num_classes): 38 | temp_data = [] 39 | for data in temp: 40 | if data[1] == str(i): 41 | temp_data.append(data) 42 | np.random.shuffle(np.array(temp_data)) 43 | test_data =test_data + temp_data[:int(ratio * len(temp_data))] 44 | train_data = train_data + temp_data[int(ratio*len(temp_data))+1:] 45 | # np.random.shuffle(temp) 46 | # test_data = files[:int(ratio * len(files))] 47 | # train_data = files[int(ratio*len(files))+1:] 48 | 49 | # 从训练集中挑选 10 中图片保存到 example 文件夹中 50 | if not os.path.exists(config.example_folder): 51 | os.mkdir(config.example_folder) 52 | else: 53 | for i in os.listdir(config.example_folder): 54 | os.remove(os.path.join(config.example_folder+i)) 55 | for i in range(10): 56 | index = random.randint(0,len(test_data)-1) # 随机生成图片的索引 57 | copy(test_data[index][0],config.example_folder) # 将挑选的图像复制到example文件夹 58 | 59 | return test_data, train_data 60 | ''' 61 | # 2. 
对于flowers 数据集的解析 62 | def get_files(file_dir,ratio): 63 | roses = [] 64 | labels_roses = [] 65 | tulips = [] 66 | labels_tulips = [] 67 | dandelion = [] 68 | labels_dandelion=[] 69 | sunflowers = [] 70 | labels_sunflowers = [] 71 | for file in os.listdir(file_dir +'roses'): 72 | roses.append(file_dir + 'roses' + '/' + file) 73 | labels_roses.append(0) 74 | for file in os.listdir(file_dir + 'tulips'): 75 | tulips.append(file_dir + 'tulips' + '/' + file) 76 | labels_tulips.append(1) 77 | for file in os.listdir(file_dir + 'dandelion'): 78 | tulips.append(file_dir + 'dandelion' + '/' +file) 79 | labels_dandelion.append(2) 80 | for file in os.listdir(file_dir + 'sunflowers'): 81 | sunflowers.append(file_dir + 'sunflowers' + '/' +file) 82 | labels_sunflowers.append(3) 83 | 84 | image_list = np.hstack((roses ,tulips, dandelion, sunflowers)) 85 | labels_list = np.hstack((labels_roses, labels_tulips, labels_dandelion, labels_sunflowers)) 86 | temp = np.array([image_list, labels_list]) 87 | temp = temp.transpose() 88 | np.random.shuffle(temp) 89 | all_image_list = list(temp[:,0]) 90 | all_label_list = list(temp[:,1]) 91 | all_label_list = [int(i) for i in all_label_list] 92 | length = len(all_image_list) 93 | n_test = int(math.ceil(length * ratio)) 94 | n_train = length - n_test 95 | 96 | tra_image = all_image_list[0:n_train] 97 | tra_label = all_label_list[0:n_train] 98 | 99 | test_image = all_image_list[n_train:-1] 100 | test_label = all_label_list[n_train:-1] 101 | 102 | train_data = [(tra_image[i],tra_label[i]) for i in range(len(tra_image))] 103 | test_data = [(test_image[i],test_label[i]) for i in range(len(test_image))] 104 | # print("train_data = ",test_image) 105 | # print("test_data = " , test_label) 106 | return test_data,train_data 107 | 108 | #这个数据集类的作用就是加载训练和测试时的数据 109 | class datasets(Dataset): 110 | def __init__(self,data,transform = None,test = False): 111 | imgs = [] 112 | labels = [] 113 | self.test = test 114 | self.len = len(data) 115 | self.data = data 
116 | self.transform = transform 117 | for i in self.data: 118 | imgs.append(i[0]) 119 | self.imgs = imgs 120 | labels.append(int(i[1]) ) #pytorch中交叉熵需要从0开始 121 | self.labels = labels 122 | def __getitem__(self,index): 123 | if self.test: 124 | filename = self.imgs[index] 125 | filename = filename 126 | img_path = self.imgs[index] 127 | img = cv2.imread(img_path) 128 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 129 | img = cv2.resize(img, (config.img_width, config.img_height)) 130 | img = transforms.ToTensor()(img) 131 | return img,filename 132 | else: 133 | img_path = self.imgs[index] 134 | label = self.labels[index] 135 | #label = int(label) 136 | img = cv2.imread(img_path) 137 | img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 138 | img = cv2.resize(img,(config.img_width,config.img_height)) 139 | # img = transforms.ToTensor()(img) 140 | 141 | if self.transform is not None: 142 | img = Image.fromarray(img) 143 | img = self.transform(img) 144 | 145 | else: 146 | img = transforms.ToTensor()(img) 147 | return img,label 148 | 149 | def __len__(self): 150 | return len(self.data)#self.len 151 | 152 | def collate_fn(batch): #表示如何将多个样本拼接成一个batch 153 | imgs = [] 154 | label = [] 155 | for sample in batch: 156 | imgs.append(sample[0]) 157 | label.append(sample[1]) 158 | 159 | return torch.stack(imgs, 0),label 160 | 161 | 162 | #用于调试代码 163 | if __name__ == '__main__': 164 | test_data,_ = get_files(config.data_folder,0.2) 165 | for i in (test_data): 166 | print(i) 167 | print(len(test_data)) 168 | 169 | transform = transforms.Compose([transforms.ToTensor()]) 170 | data = datasets(test_data,transform = transform) 171 | #print(data[0]) -------------------------------------------------------------------------------- /example/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/example/1.jpg 
-------------------------------------------------------------------------------- /example/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/example/2.jpg -------------------------------------------------------------------------------- /example/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/example/3.jpg -------------------------------------------------------------------------------- /example/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/example/4.jpg -------------------------------------------------------------------------------- /example/dandelion'.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/example/dandelion'.jpg -------------------------------------------------------------------------------- /input_data.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/input_data.rar -------------------------------------------------------------------------------- /logs/Vgg.txt: -------------------------------------------------------------------------------- 1 | Get Better top1 : 5.147058823529412 saving weights to ./checkpoints/Vgg.pth 2 | Get Better top1 : 12.867647058823529 saving weights to ./checkpoints/Vgg.pth 3 | Get Better top1 : 15.073529411764707 saving weights to ./checkpoints/Vgg.pth 4 | Get Better top1 : 
7.352941176470588 saving weights to ./checkpoints/Vgg.pth 5 | Get Better top1 : 15.441176470588236 saving weights to ./checkpoints/Vgg.pth 6 | Get Better top1 : 16.544117647058822 saving weights to ./checkpoints/Vgg.pth 7 | Get Better top1 : 17.647058823529413 saving weights to ./checkpoints/Vgg.pth 8 | Get Better top1 : 37.86764705882353 saving weights to ./checkpoints/Vgg.pth 9 | Get Better top1 : 38.23529411764706 saving weights to ./checkpoints/Vgg.pth 10 | Get Better top1 : 42.64705882352941 saving weights to ./checkpoints/Vgg.pth 11 | Get Better top1 : 43.01470588235294 saving weights to ./checkpoints/Vgg.pth 12 | Get Better top1 : 5.147058823529412 saving weights to ./checkpoints/Vgg.pth 13 | Get Better top1 : 6.25 saving weights to ./checkpoints/Vgg.pth 14 | Get Better top1 : 10.294117647058824 saving weights to ./checkpoints/Vgg.pth 15 | Get Better top1 : 13.970588235294118 saving weights to ./checkpoints/Vgg.pth 16 | Get Better top1 : 22.794117647058822 saving weights to ./checkpoints/Vgg.pth 17 | Get Better top1 : 23.897058823529413 saving weights to ./checkpoints/Vgg.pth 18 | Get Better top1 : 32.35294117647059 saving weights to ./checkpoints/Vgg.pth 19 | Get Better top1 : 33.455882352941174 saving weights to ./checkpoints/Vgg.pth 20 | Get Better top1 : 40.80882352941177 saving weights to ./checkpoints/Vgg.pth 21 | Get Better top1 : 47.794117647058826 saving weights to ./checkpoints/Vgg.pth 22 | Get Better top1 : 61.39705882352941 saving weights to ./checkpoints/Vgg.pth 23 | Get Better top1 : 63.60294117647059 saving weights to ./checkpoints/Vgg.pth 24 | Get Better top1 : 5.147058823529412 saving weights to ./checkpoints/Vgg.pth 25 | Get Better top1 : 6.985294117647059 saving weights to ./checkpoints/Vgg.pth 26 | Get Better top1 : 7.352941176470588 saving weights to ./checkpoints/Vgg.pth 27 | Get Better top1 : 20.955882352941178 saving weights to ./checkpoints/Vgg.pth 28 | Get Better top1 : 22.426470588235293 saving weights to ./checkpoints/Vgg.pth 29 | 
Get Better top1 : 31.61764705882353 saving weights to ./checkpoints/Vgg.pth 30 | Get Better top1 : 32.720588235294116 saving weights to ./checkpoints/Vgg.pth 31 | Get Better top1 : 39.705882352941174 saving weights to ./checkpoints/Vgg.pth 32 | Get Better top1 : 11.764705882352942 saving weights to ./checkpoints/Vgg.pth 33 | Get Better top1 : 30.514705882352942 saving weights to ./checkpoints/Vgg.pth 34 | Get Better top1 : 31.985294117647058 saving weights to ./checkpoints/Vgg.pth 35 | Get Better top1 : 34.19117647058823 saving weights to ./checkpoints/Vgg.pth 36 | Get Better top1 : 44.48529411764706 saving weights to ./checkpoints/Vgg.pth 37 | Get Better top1 : 57.35294117647059 saving weights to ./checkpoints/Vgg.pth 38 | Get Better top1 : 58.088235294117645 saving weights to ./checkpoints/Vgg.pth 39 | Get Better top1 : 63.970588235294116 saving weights to ./checkpoints/Vgg.pth 40 | Get Better top1 : 10.294117647058824 saving weights to ./checkpoints/Vgg.pth 41 | Get Better top1 : 22.426470588235293 saving weights to ./checkpoints/Vgg.pth 42 | Get Better top1 : 40.0735294117647 saving weights to ./checkpoints/Vgg.pth 43 | Get Better top1 : 47.794117647058826 saving weights to ./checkpoints/Vgg.pth 44 | Get Better top1 : 5.147058823529412 saving weights to ./checkpoints/Vgg.pth 45 | Get Better top1 : 15.808823529411764 saving weights to ./checkpoints/Vgg.pth 46 | Get Better top1 : 21.323529411764707 saving weights to ./checkpoints/Vgg.pth 47 | Get Better top1 : 12.867647058823529 saving weights to ./checkpoints/Vgg.pth 48 | Get Better top1 : 29.044117647058822 saving weights to ./checkpoints/Vgg.pth 49 | Get Better top1 : 47.05882352941177 saving weights to ./checkpoints/Vgg.pth 50 | Get Better top1 : 34.57236842105263 saving weights to ./checkpoints/Vgg.pth 51 | Get Better top1 : 41.151315789473685 saving weights to ./checkpoints/Vgg.pth 52 | Get Better top1 : 51.546052631578945 saving weights to ./checkpoints/Vgg.pth 53 | Get Better top1 : 26.666666658301104 
saving weights to ./checkpoints/Vgg.pth 54 | Get Better top1 : 30.471491211339046 saving weights to ./checkpoints/Vgg.pth 55 | Get Better top1 : 37.54385963239168 saving weights to ./checkpoints/Vgg.pth 56 | Get Better top1 : 45.94298242267809 saving weights to ./checkpoints/Vgg.pth 57 | Get Better top1 : 46.75438594818115 saving weights to ./checkpoints/Vgg.pth 58 | Get Better top1 : 63.38815789473684 saving weights to ./checkpoints/Vgg.pth 59 | Get Better top1 : 65.84429821215178 saving weights to ./checkpoints/Vgg.pth 60 | Get Better top1 : 66.51315789473684 saving weights to ./checkpoints/Vgg.pth 61 | Get Better top1 : 66.84210526315789 saving weights to ./checkpoints/Vgg.pth 62 | Get Better top1 : 69.14473684210526 saving weights to ./checkpoints/Vgg.pth 63 | Get Better top1 : 73.75 saving weights to ./checkpoints/Vgg.pth 64 | Get Better top1 : 27.160087710932682 saving weights to ./checkpoints/Vgg.pth 65 | Get Better top1 : 40.51535084373072 saving weights to ./checkpoints/Vgg.pth 66 | Get Better top1 : 49.91228063482987 saving weights to ./checkpoints/Vgg.pth 67 | Get Better top1 : 53.83771926478336 saving weights to ./checkpoints/Vgg.pth 68 | Get Better top1 : 53.848684210526315 saving weights to ./checkpoints/Vgg.pth 69 | Get Better top1 : 54.66008768583599 saving weights to ./checkpoints/Vgg.pth 70 | Get Better top1 : 57.64254379272461 saving weights to ./checkpoints/Vgg.pth 71 | Get Better top1 : 58.60745610688862 saving weights to ./checkpoints/Vgg.pth 72 | Get Better top1 : 68.01535084373073 saving weights to ./checkpoints/Vgg.pth 73 | Get Better top1 : 71.12938589798777 saving weights to ./checkpoints/Vgg.pth 74 | Get Better top1 : 72.44517537167198 saving weights to ./checkpoints/Vgg.pth 75 | Get Better top1 : 72.45614031741493 saving weights to ./checkpoints/Vgg.pth 76 | Get Better top1 : 73.45394736842105 saving weights to ./checkpoints/Vgg.pth 77 | Get Better top1 : 77.07236842105263 saving weights to ./checkpoints/Vgg.pth 78 | Get Better top1 : 
28.333333316602204 saving weights to ./checkpoints/Vgg.pth 79 | Get Better top1 : 33.77192979109915 saving weights to ./checkpoints/Vgg.pth 80 | Get Better top1 : 37.36842105263158 saving weights to ./checkpoints/Vgg.pth 81 | Get Better top1 : 37.37938594818115 saving weights to ./checkpoints/Vgg.pth 82 | Get Better top1 : 38.026315789473685 saving weights to ./checkpoints/Vgg.pth 83 | Get Better top1 : 38.85964910607589 saving weights to ./checkpoints/Vgg.pth 84 | Get Better top1 : 39.18859647449694 saving weights to ./checkpoints/Vgg.pth 85 | Get Better top1 : 39.68201752712852 saving weights to ./checkpoints/Vgg.pth 86 | Get Better top1 : 40.997807000812735 saving weights to ./checkpoints/Vgg.pth 87 | Get Better top1 : 41.820175421865365 saving weights to ./checkpoints/Vgg.pth 88 | Get Better top1 : 26.666666658301104 saving weights to ./checkpoints/Vgg.pth 89 | Get Better top1 : 35.89912279028641 saving weights to ./checkpoints/Vgg.pth 90 | Get Better top1 : 48.56359647449694 saving weights to ./checkpoints/Vgg.pth 91 | Get Better top1 : 48.92543852956671 saving weights to ./checkpoints/Vgg.pth 92 | Get Better top1 : 56.151315789473685 saving weights to ./checkpoints/Vgg.pth 93 | Get Better top1 : 66.05263157894737 saving weights to ./checkpoints/Vgg.pth 94 | Get Better top1 : 67.00657894736842 saving weights to ./checkpoints/Vgg.pth 95 | Get Better top1 : 67.17105263157895 saving weights to ./checkpoints/Vgg.pth 96 | Get Better top1 : 68.32236842105263 saving weights to ./checkpoints/Vgg.pth 97 | Get Better top1 : 69.96710526315789 saving weights to ./checkpoints/Vgg.pth 98 | Get Better top1 : 71.9517543190404 saving weights to ./checkpoints/Vgg.pth 99 | Get Better top1 : 73.76096484535618 saving weights to ./checkpoints/Vgg.pth 100 | Get Better top1 : 74.27631578947368 saving weights to ./checkpoints/Vgg.pth 101 | Get Better top1 : 71.6337718963623 saving weights to ./checkpoints/Vgg.pth 102 | Get Better top1 : 75.43859642430355 saving weights to 
./checkpoints/Vgg.pth 103 | Get Better top1 : 77.5767543190404 saving weights to ./checkpoints/Vgg.pth 104 | Get Better top1 : 36.20614034251163 saving weights to ./checkpoints/Vgg.pth 105 | Get Better top1 : 26.666666658301104 saving weights to ./checkpoints/Vgg.pth 106 | Get Better top1 : 36.403508738467565 saving weights to ./checkpoints/Vgg.pth 107 | Get Better top1 : 40.19736842105263 saving weights to ./checkpoints/Vgg.pth 108 | Get Better top1 : 45.767543842917995 saving weights to ./checkpoints/Vgg.pth 109 | Get Better top1 : 46.94078947368421 saving weights to ./checkpoints/Vgg.pth 110 | Get Better top1 : 50.86622805344431 saving weights to ./checkpoints/Vgg.pth 111 | Get Better top1 : 55.67982452794125 saving weights to ./checkpoints/Vgg.pth 112 | Get Better top1 : 62.291666532817636 saving weights to ./checkpoints/Vgg.pth 113 | Get Better top1 : 65.25219284860711 saving weights to ./checkpoints/Vgg.pth 114 | Get Better top1 : 65.59210526315789 saving weights to ./checkpoints/Vgg.pth 115 | Get Better top1 : 66.07456126965974 saving weights to ./checkpoints/Vgg.pth 116 | Get Better top1 : 26.666666658301104 saving weights to ./checkpoints/Vgg.pth 117 | Get Better top1 : 31.282894736842106 saving weights to ./checkpoints/Vgg.pth 118 | Get Better top1 : 38.04824558057283 saving weights to ./checkpoints/Vgg.pth 119 | Get Better top1 : 39.85745610688862 saving weights to ./checkpoints/Vgg.pth 120 | Get Better top1 : 42.00657894736842 saving weights to ./checkpoints/Vgg.pth 121 | Get Better top1 : 42.01754379272461 saving weights to ./checkpoints/Vgg.pth 122 | Get Better top1 : 51.56798242267809 saving weights to ./checkpoints/Vgg.pth 123 | Get Better top1 : 56.688596424303554 saving weights to ./checkpoints/Vgg.pth 124 | Get Better top1 : 60.26315789473684 saving weights to ./checkpoints/Vgg.pth 125 | Get Better top1 : 62.401315789473685 saving weights to ./checkpoints/Vgg.pth 126 | Get Better top1 : 62.741228003250924 saving weights to ./checkpoints/Vgg.pth 
127 | Get Better top1 : 64.90131578947368 saving weights to ./checkpoints/Vgg.pth 128 | Get Better top1 : 66.05263157894737 saving weights to ./checkpoints/Vgg.pth 129 | Get Better top1 : 66.53508768583599 saving weights to ./checkpoints/Vgg.pth 130 | Get Better top1 : 67.86184210526316 saving weights to ./checkpoints/Vgg.pth 131 | Get Better top1 : 68.02631578947368 saving weights to ./checkpoints/Vgg.pth 132 | Get Better top1 : 68.33333326640881 saving weights to ./checkpoints/Vgg.pth 133 | Get Better top1 : 69.17763157894737 saving weights to ./checkpoints/Vgg.pth 134 | Get Better top1 : 69.34210526315789 saving weights to ./checkpoints/Vgg.pth 135 | Get Better top1 : 69.47368421052632 saving weights to ./checkpoints/Vgg.pth 136 | Get Better top1 : 70.3179824226781 saving weights to ./checkpoints/Vgg.pth 137 | Get Better top1 : 71.3157894736842 saving weights to ./checkpoints/Vgg.pth 138 | Get Better top1 : 71.83114021702816 saving weights to ./checkpoints/Vgg.pth 139 | Get Better top1 : 71.96271926478336 saving weights to ./checkpoints/Vgg.pth 140 | Get Better top1 : 68.65131578947368 saving weights to ./checkpoints/Vgg.pth 141 | Get Better top1 : 69.32017537167198 saving weights to ./checkpoints/Vgg.pth 142 | Get Better top1 : 70.14254379272461 saving weights to ./checkpoints/Vgg.pth 143 | Get Better top1 : 71.14035084373073 saving weights to ./checkpoints/Vgg.pth 144 | Get Better top1 : 71.30482452794125 saving weights to ./checkpoints/Vgg.pth 145 | Get Better top1 : 72.13815789473684 saving weights to ./checkpoints/Vgg.pth 146 | Get Better top1 : 26.666666658301104 saving weights to ./checkpoints/Vgg.pth 147 | Get Better top1 : 31.94078947368421 saving weights to ./checkpoints/Vgg.pth 148 | Get Better top1 : 42.5 saving weights to ./checkpoints/Vgg.pth 149 | Get Better top1 : 47.60964905588251 saving weights to ./checkpoints/Vgg.pth 150 | Get Better top1 : 52.5328947368421 saving weights to ./checkpoints/Vgg.pth 151 | Get Better top1 : 53.366228003250924 
saving weights to ./checkpoints/Vgg.pth 152 | Get Better top1 : 57.971491161145664 saving weights to ./checkpoints/Vgg.pth 153 | Get Better top1 : 59.956140317414935 saving weights to ./checkpoints/Vgg.pth 154 | Get Better top1 : 62.92763157894737 saving weights to ./checkpoints/Vgg.pth 155 | Get Better top1 : 65.88815789473684 saving weights to ./checkpoints/Vgg.pth 156 | Get Better top1 : 67.36842105263158 saving weights to ./checkpoints/Vgg.pth 157 | Get Better top1 : 68.83771926478336 saving weights to ./checkpoints/Vgg.pth 158 | Get Better top1 : 69.97807010851409 saving weights to ./checkpoints/Vgg.pth 159 | Get Better top1 : 70.15350873846756 saving weights to ./checkpoints/Vgg.pth 160 | Get Better top1 : 72.96052631578948 saving weights to ./checkpoints/Vgg.pth 161 | Get Better top1 : 74.11184210526316 saving weights to ./checkpoints/Vgg.pth 162 | Get Better top1 : 75.75657894736842 saving weights to ./checkpoints/Vgg.pth 163 | Get Better top1 : 76.58991221377724 saving weights to ./checkpoints/Vgg.pth 164 | Get Better top1 : 77.5657894736842 saving weights to ./checkpoints/Vgg.pth 165 | Get Better top1 : 78.07017537167198 saving weights to ./checkpoints/Vgg.pth 166 | Get Better top1 : 78.39912274009303 saving weights to ./checkpoints/Vgg.pth 167 | -------------------------------------------------------------------------------- /logs/myModel.txt: -------------------------------------------------------------------------------- 1 | Get Better top1 : [4.6296296] saving weights to ./checkpoints/myModel.pth 2 | Get Better top1 : [15.046296] saving weights to ./checkpoints/myModel.pth 3 | Get Better top1 : 2.7777777777777777 saving weights to ./checkpoints/myModel.pth 4 | Get Better top1 : 12.847222222222221 saving weights to ./checkpoints/myModel.pth 5 | Get Better top1 : 8.333333333333334 saving weights to ./checkpoints/myModel.pth 6 | Get Better top1 : 9.375 saving weights to ./checkpoints/myModel.pth 7 | 
-------------------------------------------------------------------------------- /predict_gui.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from matplotlib import pyplot as plt 3 | import cv2 4 | import wx #图形化界面 conda install wxpython 5 | import numpy as np 6 | import os 7 | import matplotlib.pyplot as plt 8 | import torch 9 | import Model 10 | from test import * 11 | class ClassificationFrame(wx.Frame): 12 | 13 | def __init__(self,*args,**kw): 14 | super(ClassificationFrame,self).__init__(*args,**kw) 15 | pnl = wx.Panel(self) 16 | self.pnl = pnl 17 | st = wx.StaticText(pnl,label = "花朵识别",pos = (200,0)) 18 | font = st.GetFont() 19 | font.PointSize+=10 20 | font = font.Bold() 21 | st.SetFont(font) 22 | 23 | btn = wx.Button(pnl,-1,'select') 24 | btn.Bind(wx.EVT_BUTTON,self.OnSelect) 25 | 26 | self.makeMenuBar() 27 | 28 | self.CreateStatusBar() 29 | self.SetStatusText("Welcome to flowers world") 30 | 31 | def makeMenuBar(self): 32 | fileMenu = wx.Menu() 33 | helloItem = fileMenu.Append(-1, "&Hello...\tCtrl-H","Help string shown in status bar for this menu item") 34 | fileMenu.AppendSeparator() 35 | exitItem = fileMenu.Append(wx.ID_EXIT) 36 | helpMenu = wx.Menu() 37 | aboutItem = helpMenu.Append(wx.ID_ABOUT) 38 | 39 | menuBar = wx.MenuBar() 40 | menuBar.Append(fileMenu, "&File") 41 | menuBar.Append(helpMenu, "Help") 42 | 43 | self.SetMenuBar(menuBar) 44 | 45 | self.Bind(wx.EVT_MENU, self.OnHello, helloItem) 46 | self.Bind(wx.EVT_MENU, self.OnExit, exitItem) 47 | self.Bind(wx.EVT_MENU, self.OnAbout, aboutItem) 48 | 49 | 50 | def OnExit(self, event): 51 | self.Close(True) 52 | 53 | def OnHello(self, event): 54 | wx.MessageBox("Hello again from wxPython") 55 | 56 | def OnAbout(self, event): 57 | """Display an About Dialog""" 58 | wx.MessageBox("This is a wxPython Hello World sample", 59 | "About Hello World 2", 60 | wx.OK | wx.ICON_INFORMATION) 61 | 62 | def OnSelect(self, event): 63 | wildcard = "image 
source(*.jpg)|*.jpg|" \ 64 | "Compile Python(*.pyc)|*.pyc|" \ 65 | "All file(*.*)|*.*" 66 | dialog = wx.FileDialog(None, "Choose a file", os.getcwd(), 67 | "", wildcard, wx.ID_OPEN) 68 | if dialog.ShowModal() == wx.ID_OK: 69 | model = Model.get_net() 70 | checkpoint = torch.load(config.weights + config.model_name + '.pth') 71 | model.load_state_dict(checkpoint["state_dict"]) 72 | print(dialog.GetPath()) 73 | img = cv2.imread(dialog.GetPath()) 74 | 75 | result = test_one_image(img,model) 76 | result_text = wx.StaticText(self.pnl, label=result, pos=(320, 0)) 77 | font = result_text.GetFont() 78 | font.PointSize += 8 79 | result_text.SetFont(font) 80 | self.initimage(name= dialog.GetPath()) 81 | 82 | # 生成图片控件 83 | def initimage(self, name): 84 | imageShow = wx.Image(name, wx.BITMAP_TYPE_ANY) 85 | sb = wx.StaticBitmap(self.pnl, -1, imageShow.ConvertToBitmap(), pos=(200,100), size=(400,400)) 86 | return sb 87 | 88 | 89 | if __name__ == '__main__': 90 | 91 | app = wx.App() 92 | frm = ClassificationFrame(None, title='flower wolrd', size=(1000,600)) 93 | frm.Show() 94 | app.MainLoop() -------------------------------------------------------------------------------- /resnet.txt: -------------------------------------------------------------------------------- 1 | ./data/flowers/dandelion/1515samples2.jpg, dandelion 2 | ./data/flowers/dandelion/2308samples2.jpg, dandelion 3 | ./data/flowers/dandelion/2099samples2.jpg, roses 4 | ./data/flowers/sunflowers/2657samples3.jpg, sunflowers 5 | ./data/flowers/tulips/1384samples1.jpg, tulips 6 | ./data/flowers/dandelion/2212samples2.jpg, dandelion 7 | ./data/flowers/roses/615samples0.jpg, tulips 8 | ./data/flowers/dandelion/1736samples2.jpg, dandelion 9 | ./data/flowers/roses/147samples0.jpg, sunflowers 10 | ./data/flowers/roses/152samples0.jpg, roses 11 | ./data/flowers/tulips/779samples1.jpg, roses 12 | ./data/flowers/tulips/997samples1.jpg, sunflowers 13 | ./data/flowers/roses/44samples0.jpg, dandelion 14 | 
./data/flowers/dandelion/1498samples2.jpg, dandelion 15 | ./data/flowers/tulips/1390samples1.jpg, dandelion 16 | ./data/flowers/dandelion/1521samples2.jpg, dandelion 17 | ./data/flowers/dandelion/2307samples2.jpg, dandelion 18 | ./data/flowers/roses/7samples0.jpg, sunflowers 19 | ./data/flowers/dandelion/1526samples2.jpg, dandelion 20 | ./data/flowers/sunflowers/2575samples3.jpg, dandelion 21 | ./data/flowers/dandelion/1852samples2.jpg, roses 22 | ./data/flowers/roses/159samples0.jpg, roses 23 | ./data/flowers/tulips/1147samples1.jpg, tulips 24 | ./data/flowers/roses/601samples0.jpg, roses 25 | ./data/flowers/sunflowers/2550samples3.jpg, sunflowers 26 | ./data/flowers/roses/225samples0.jpg, roses 27 | ./data/flowers/dandelion/1933samples2.jpg, dandelion 28 | ./data/flowers/tulips/1220samples1.jpg, dandelion 29 | ./data/flowers/dandelion/2225samples2.jpg, sunflowers 30 | ./data/flowers/dandelion/2094samples2.jpg, dandelion 31 | ./data/flowers/sunflowers/2499samples3.jpg, tulips 32 | ./data/flowers/dandelion/1789samples2.jpg, dandelion 33 | ./data/flowers/sunflowers/2602samples3.jpg, dandelion 34 | ./data/flowers/tulips/1230samples1.jpg, tulips 35 | ./data/flowers/dandelion/1698samples2.jpg, dandelion 36 | ./data/flowers/dandelion/1967samples2.jpg, dandelion 37 | ./data/flowers/tulips/994samples1.jpg, dandelion 38 | ./data/flowers/sunflowers/2778samples3.jpg, sunflowers 39 | ./data/flowers/dandelion/1805samples2.jpg, dandelion 40 | ./data/flowers/dandelion/1627samples2.jpg, dandelion 41 | ./data/flowers/tulips/1085samples1.jpg, dandelion 42 | ./data/flowers/tulips/1038samples1.jpg, tulips 43 | ./data/flowers/roses/48samples0.jpg, tulips 44 | ./data/flowers/tulips/985samples1.jpg, tulips 45 | ./data/flowers/tulips/1202samples1.jpg, roses 46 | ./data/flowers/tulips/1387samples1.jpg, tulips 47 | ./data/flowers/tulips/1354samples1.jpg, sunflowers 48 | -------------------------------------------------------------------------------- /test.py: 
# --------------------------------------------------------------------------------
import torch
import cv2
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.autograd import Variable
from config import config
from datasets import *
import Model
from utils.utils import accuracy

classes = {0: "roses", 1: "tulips", 2: "dandelion", 3: "sunflowers"}

# Chinese display names for the same label ids, used by test_one_image.
_FLOWER_NAMES_ZH = {0: '玫瑰花', 1: '郁金香', 2: '蒲公英', 3: '向日葵'}

# Single-image preprocessing: same ToTensor + ImageNet Normalize that the
# training transform in train.py applies, so inference sees the input
# distribution the network was trained on.
_INFER_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])


def evaluate(test_loader, model, criterion):
    """Evaluate *model* on *test_loader*.

    Returns:
        (avg_loss, avg_top1): loss and top-1 accuracy (in percent) averaged over
        all samples. Each batch is weighted by its size, so a smaller last batch
        does not skew the result. Assumes *criterion* uses the default
        ``reduction='mean'`` (as the ``nn.CrossEntropyLoss()`` in train.py does).

    Raises:
        ValueError: if *test_loader* yields no samples.
    """
    # Run on whatever device the model lives on instead of forcing CUDA.
    device = next(model.parameters()).device
    n_samples = 0
    loss_total = 0.0
    correct_total = 0.0
    model.eval()
    with torch.no_grad():  # no autograd graph needed for evaluation
        for ims, label in test_loader:
            input_test = ims.to(device)
            target_test = torch.from_numpy(np.array(label)).long().to(device)
            output_test = model(input_test)
            loss = criterion(output_test, target_test)
            top1_test = accuracy(output_test, target_test, topk=(1,))
            batch = target_test.size(0)
            n_samples += batch
            loss_total += loss.item() * batch
            correct_total += top1_test[0].item() * batch
    if n_samples == 0:
        raise ValueError("evaluate: test_loader yielded no samples")
    return loss_total / n_samples, correct_total / n_samples


def test(test_loader, model):
    """Predict every image in *test_loader* and write "<filename>, <class>" lines to <model_name>.txt."""
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    device = next(model.parameters()).device
    softmax = nn.Softmax(1)
    # The with-statement guarantees the prediction file is flushed and closed.
    with open("%s.txt" % config.model_name, 'w') as predict_file, torch.no_grad():
        for images, filenames in tqdm(test_loader):
            probs = softmax(model(images.to(device)))
            # argmax per row, so batch sizes > 1 are handled too.
            pred_labels = probs.argmax(1).cpu().tolist()
            for filename, pred_label in zip(filenames, pred_labels):
                predict_file.write(filename + ', ' + classes[pred_label] + '\n')


def test_one_image(image, model):
    """Classify one image and return a Chinese description with the top-class probability.

    Args:
        image: image array in BGR channel order, as returned by ``cv2.imread``.
        model: the classification network (put into eval mode here).
    """
    model.eval()
    device = next(model.parameters()).device
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # cv2.resize takes dsize=(width, height). This yields height=img_width and
    # width=img_height, the same result as train.py's
    # transforms.Resize((img_width, img_height)), which takes (height, width).
    image = cv2.resize(image, (config.img_height, config.img_width))
    img = _INFER_TRANSFORM(image).unsqueeze(0).to(device)  # add the batch dimension

    with torch.no_grad():
        probs = nn.Softmax(1)(model(img)).cpu().numpy()
    pred_label = int(np.argmax(probs))
    name = _FLOWER_NAMES_ZH.get(pred_label, classes.get(pred_label, str(pred_label)))
    return '这是%s的概率为:%.4f' % (name, probs[0][pred_label])


if __name__ == '__main__':

    # 1. Test set (batch_size must stay 1 here).
    test_list, _ = get_files(config.data_folder, config.ratio)
    test_loader = DataLoader(datasets(test_list, transform=None, test=True), batch_size=1,
                             shuffle=False, collate_fn=collate_fn, num_workers=4)

    # 2. Load the model and its weights; map_location allows CPU-only machines.
    model = Model.get_net()
    checkpoint = torch.load(config.weights + config.model_name + '.pth', map_location='cpu')
    model.load_state_dict(checkpoint["state_dict"])
    print("Start Test.......")
    test(test_loader, model)

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch
import torchvision
import os
from config import config
import Model
from torch import optim, nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from datasets import *
from test import *
from utils.utils import *

if __name__ == '__main__':
    # 1. Create the output folders.
    for folder in (config.example_folder, config.weights, config.logs):
        os.makedirs(folder, exist_ok=True)

    # Use the GPU when available; model, criterion and batches all go to `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 2. Model, optimizer and loss.
    model = Model.get_net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9,
                          weight_decay=config.weight_decay)
    criterion = nn.CrossEntropyLoss().to(device)

    # 3. Optionally resume from a checkpoint.
    start_epoch = 0
    current_accuracy = 0
    resume = False  # False: train from scratch
    if resume:
        checkpoint = torch.load(config.weights + config.model_name + '.pth',
                                map_location=device)
        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Restore the best accuracy, so the first resumed epoch is not
        # automatically treated as an improvement.
        current_accuracy = checkpoint.get("accTop1", current_accuracy)

    # 4. Training set (with augmentation) and test set.
    transform = transforms.Compose([
        transforms.RandomResizedCrop(90),
        transforms.ColorJitter(0.05, 0.05, 0.05),
        transforms.RandomRotation(30),
        transforms.RandomGrayscale(p=0.5),
        transforms.Resize((config.img_width, config.img_height)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])

    _, train_list = get_files(config.data_folder, config.ratio)
    input_data = datasets(train_list, transform=transform)
    train_loader = DataLoader(input_data, batch_size=config.batch_size, shuffle=True,
                              collate_fn=collate_fn, pin_memory=False, num_workers=4)

    # Test set: no data augmentation (transform=None).
    test_list, _ = get_files(config.data_folder, config.ratio)
    test_loader = DataLoader(datasets(test_list, transform=None), batch_size=config.batch_size,
                             shuffle=False, collate_fn=collate_fn, num_workers=4)

    train_loss = []
    acc = []
    test_loss = []
    # 5. Train.
    print("------ Start Training ------\n")
    for epoch in range(start_epoch, config.epochs):
        model.train()
        config.lr = lr_step(epoch)
        # Update the learning rate in place instead of building a new optimizer:
        # re-creating SGD every epoch reset its momentum buffers and discarded
        # the optimizer state restored on resume.
        for param_group in optimizer.param_groups:
            param_group["lr"] = config.lr

        loss_epoch = 0
        for index, (images, targets) in enumerate(train_loader):
            images = images.to(device)
            targets = torch.from_numpy(np.array(targets)).long().to(device)
            output = model(images)
            loss = criterion(output, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_epoch += loss.item()

            if (index + 1) % 10 == 0:
                # Number of samples processed so far, including this batch.
                seen = min((index + 1) * config.batch_size, len(train_loader.dataset))
                print("Epoch: {} [{:>3d}/{}]\t Loss: {:.6f} ".format(
                    epoch + 1, seen, len(train_loader.dataset), loss.item()))

        # Evaluate on the test set after every epoch.
        print("\n------ Evaluate ------")
        test_loss1, accTop1 = evaluate(test_loader, model, criterion)
        acc.append(accTop1)
        test_loss.append(test_loss1)
        train_loss.append(loss_epoch / len(train_loader))
        print("Test_epoch: {} Test_accuracy: {:.4}% Test_Loss: {:.6f}".format(
            epoch + 1, accTop1, test_loss1))
        # True when this epoch beats the best test accuracy seen so far.
        save_model = accTop1 > current_accuracy
        current_accuracy = max(current_accuracy, accTop1)
        save_checkpoint({
            "epoch": epoch + 1,
            "model_name": config.model_name,
            "state_dict": model.state_dict(),
            "accTop1": current_accuracy,
            "optimizer": optimizer.state_dict(),
        }, save_model)

# --------------------------------------------------------------------------------
# /utils/__pycache__/utils.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/lilei1128/pytorch-image-classification/b6658abddc8f7636a2baa91dea2f364c171e77ac/utils/__pycache__/utils.cpython-36.pyc
# --------------------------------------------------------------------------------
# /utils/utils.py:
# --------------------------------------------------------------------------------
"""Training helpers: checkpoint saving, top-k accuracy and the learning-rate schedule."""
from config import config
import torch
import os
import shutil


def save_checkpoint(state, save_model):
    """Save the training *state*, keeping the best model at ``<weights><model_name>.pth``.

    The latest state is always written to ``<weights><model_name>_last.pth``.
    When *save_model* is true (this epoch beat the previous best accuracy), or
    when no best file exists yet, that file is copied to
    ``<weights><model_name>.pth``, the path loaded by test.py, predict_gui.py and
    train.py's resume. Previously the best file was overwritten with every
    epoch's weights, even though the log claimed the best weights were saved.

    Args:
        state: dict passed to ``torch.save``; must contain ``"accTop1"``.
        save_model: True if this checkpoint is the new best.
    """
    last_path = config.weights + config.model_name + "_last.pth"
    best_path = config.weights + config.model_name + ".pth"
    torch.save(state, last_path)
    if save_model or not os.path.exists(best_path):
        shutil.copyfile(last_path, best_path)
    if save_model:
        line = "Get Better top1 : %s saving weights to %s" % (state["accTop1"], best_path)
        print(line)
        with open("./logs/%s.txt" % config.model_name, "a") as f:
            print(line, file=f)


def accuracy(output, target, topk=(1, 5)):
    """Compute the top-k precision (in percent) of *output* against *target*.

    Args:
        output: scores of shape (batch_size, num_classes).
        target: class indices of shape (batch_size,).
        topk: the k values to report.

    Returns:
        list with one 1-element tensor per k: the percentage of samples whose
        target class is among the k highest-scoring classes.
    """
    # Clamp to the number of classes so e.g. topk=(1, 5) also works with 4
    # classes; for any k >= num_classes the result is then 100%.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # (batch_size, maxk), best first
    pred = pred.t()  # transpose to (maxk, batch_size)
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # True where the prediction matches

    res = []
    for k in topk:
        # reshape, not view: `correct` comes from a transposed (non-contiguous)
        # tensor, and view(-1) on its first k rows fails for k > 1.
        # .float() turns False/True into 0/1 so the sum counts the hits.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def lr_step(epoch):
    """Return the learning rate for *epoch*.

    The schedule is piecewise constant: 0.01 for epochs < 30, 0.001 for < 80,
    0.0005 for < 120, and 0.0001 afterwards.
    """
    if epoch < 30:
        lr = 0.01
    elif epoch < 80:
        lr = 0.001
    elif epoch < 120:
        lr = 0.0005
    else:
        lr = 0.0001
    return lr
# --------------------------------------------------------------------------------