├── Model.pth ├── README ├── ResNet50.py ├── dataload.py ├── loss.png ├── test.txt ├── train.py └── 正确率.png /Model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pingxi1009/ResNet50/41eca474d54a4c75a756cc9cb342d258d0f7af0a/Model.pth -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | 文件说明: 2 | 1、ResNet50.py 是构建模型 ResNet50 3 | 2、dataload.py 是用来加载数据集 4 | 3、train.py 是训练模型 5 | 4、Model.pth 是训练好的模型,可直接加载使用 6 | 5、两张 PNG 图片为记录一次训练的成果 7 | 8 | 使用须知: 9 | 1、我的环境是 (win10 64位) + (Python 3.8.3) + (OpenCV 4.5.1) + (Pytorch 1.7.0+cu110) 10 | 2、要使用的话直接运行 train.py 即可 11 | 3、欢迎沟通交流,一起进步,共同成长1009088103@qq.com -------------------------------------------------------------------------------- /ResNet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | class Bottleneck(nn.Module): 6 | expansion = 4 7 | 8 | def __init__(self, inplanes, planes, stride = 1, downsample=None): 9 | super(Bottleneck, self).__init__() 10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) 11 | self.bn1 = nn.BatchNorm2d(planes) # 数据归一化处理,使其均值为0,方差为1,可有效避免梯度消失 12 | 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) # 数据归一化处理,使其均值为0,方差为1,可有效避免梯度消失 15 | 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = nn.BatchNorm2d(planes * 4) # 数据归一化处理,使其均值为0,方差为1,可有效避免梯度消失 18 | 19 | self.relu = nn.ReLU(inplace=True) 20 | self.downsample = downsample 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class ResNet(nn.Module): 46 | def __init__(self, block, layers, num_classes=1000): 47 | self.inplanes = 64 48 | super(ResNet, self).__init__() 49 | 50 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.bn1 = nn.BatchNorm2d(64) 53 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) 54 | self.layer1 = self._make_layer(block, 64, layers[0]) 55 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 56 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 57 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 58 | self.avgpool = nn.AvgPool2d(7) 59 | self.fc = nn.Linear(512 * block.expansion, num_classes) 60 | 61 | # 遍历所有模块,然后对其中参数进行初始化 62 | for m in self.modules(): # self.modules()采用深度优先遍历的方式,存储了net的所有模块 63 | if isinstance(m, nn.Conv2d): # 判断是不是卷积 64 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 65 | m.weight.data.normal_(0, math.sqrt(2. / n)) # 对权值参数初始化 66 | elif isinstance(m, nn.BatchNorm2d): # 判断是不是数据归一化 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | 70 | def _make_layer(self, block, planes, blocks, stride=1): 71 | downsample = None 72 | if stride != 1 or self.inplanes != planes * block.expansion: 73 | # 当是 3 4 6 3 的第一层时,由于跨层要做一个步长为 2 的卷积 size会变成二分之一,所以此处跳连接 x 必须也是相同维度 74 | downsample = nn.Sequential( # 对跳连接 x 做 1x1 步长为 2 的卷积,保证跳连接的时候 size 一致 75 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 76 | nn.BatchNorm2d(planes * block.expansion) 77 | ) 78 | layers = [] 79 | layers.append(block(self.inplanes, planes, stride, downsample)) # 将跨区的第一个要做步长为 2 的卷积添加到layer里面 80 | self.inplanes = planes * block.expansion 81 | for i in range(1, blocks): # 将除去第一个的剩下的 block 添加到layer里面 82 | layers.append(block(self.inplanes, planes)) 83 | 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | x = self.conv1(x) 88 | x = self.bn1(x) 89 | x = self.relu(x) 90 | x = self.maxpool(x) 91 | 92 | x = self.layer1(x) 93 | x = self.layer2(x) 94 | x = self.layer3(x) 95 | x = self.layer4(x) 96 | 97 | x = self.avgpool(x) 98 | x = x.view(x.size(0), -1) 99 | x = self.fc(x) 100 | 101 | return x 102 | 103 | def resnet50(pretrained=False): 104 | '''Constructs a ResNet-50 model. 105 | Args: 106 | pretrained (bool): If True, returns a model pre-trained on ImageNet 107 | ''' 108 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2) 109 | # if pretrained: # 加载已经生成的模型 110 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 111 | 112 | return model -------------------------------------------------------------------------------- /dataload.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | from PIL import Image 6 | 7 | # 初始化根目录 8 | train_path = 'D:\\DeapLearn Project\\ResNet50\\CatDogData\\train\\' 9 | test_path = 'D:\\DeapLearn Project\\ResNet50\\CatDogData\\test\\' 10 | 11 | # 定义读取文件的格式 12 | # 数据集 13 | class MyDataSet(Dataset): 14 | def __init__(self, data_path:str, transform=None): 15 | super(MyDataSet, self).__init__() 16 | self.data_path = data_path 17 | if transform is None: 18 | self.transform = transforms.Compose( 19 | [ 20 | transforms.Resize(size=(224, 224)), 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 23 | ] 24 | ) 25 | else: 26 | self.transform = transforms 27 | self.path_list = os.listdir(data_path) 28 | 29 | def __getitem__(self, idx:int): 30 | img_path = self.path_list[idx] 31 | if img_path.split('.')[0] == 'dog': 32 | label = 1 33 | else: 34 | label = 0 35 | label = torch.as_tensor(label, dtype=torch.int64) 36 | img_path = os.path.join(self.data_path, img_path) 37 | img = Image.open(img_path) 38 | img = self.transform(img) 39 | return img, label 40 | 41 | def __len__(self)->int: 42 | return len(self.path_list) 43 | 44 | train_ds = MyDataSet(train_path) 45 | 46 | full_ds = train_ds 47 | train_size = int(0.8*len(full_ds)) 48 | test_size = len(full_ds) - train_size 49 | new_train_ds, test_ds = torch.utils.data.random_split(full_ds, [train_size, test_size]) 50 | 51 | # 数据加载 52 | new_train_loader = DataLoader(new_train_ds, batch_size=32, shuffle=True, pin_memory=True, num_workers=0) 53 | new_test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, pin_memory=True, num_workers=0) 54 | -------------------------------------------------------------------------------- /loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pingxi1009/ResNet50/41eca474d54a4c75a756cc9cb342d258d0f7af0a/loss.png -------------------------------------------------------------------------------- /test.txt: -------------------------------------------------------------------------------- 1 | Start write!!! 2 | Accuracy on test set: (3471/5000)69 % 3 | Accuracy on test set: (3907/5000)78 % 4 | Accuracy on test set: (4224/5000)84 % 5 | Accuracy on test set: (4425/5000)88 % 6 | Accuracy on test set: (4412/5000)88 % 7 | Accuracy on test set: (4576/5000)91 % 8 | Accuracy on test set: (4599/5000)91 % 9 | Accuracy on test set: (4572/5000)91 % 10 | Accuracy on test set: (4662/5000)93 % 11 | Accuracy on test set: (4616/5000)92 % 12 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import os 4 | import matplotlib.pyplot as plt 5 | import torch.optim as optim 6 | import math 7 | 8 | LR = 0.0005 # 设置学习率 9 | EPOCH_NUM = 10 # 训练轮次 10 | 11 | # 导入 定义的 ResNet50 和 导入的数据 12 | from ResNet50 import resnet50 13 | from dataload import new_train_loader, new_test_loader 14 | 15 | 16 | def time_since(since): 17 | s = time.time() - since 18 | m = math.floor(s/60) 19 | s -= m*60 20 | return '%dm %ds' % (m, s) 21 | 22 | model = resnet50() 23 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 24 | model.to(device) 25 | 26 | train_data = new_train_loader 27 | test_data = new_test_loader 28 | 29 | criterion = torch.nn.CrossEntropyLoss() 30 | optimizer = optim.Adam(model.parameters(), lr=LR) 31 | 32 | def train(epoch, loss_list): 33 | running_loss = 0.0 34 | for batch_idx, data in enumerate(new_train_loader, 0): 35 | inputs, target = data[0], data[1] 36 | inputs, target = inputs.to(device), target.to(device) 37 | optimizer.zero_grad() 38 | 39 | outputs = model(inputs) 40 | 41 | loss = criterion(outputs, target) 42 | loss.backward() 43 | optimizer.step() 44 | 45 | loss_list.append(loss.item()) 46 | running_loss += loss.item() 47 | if batch_idx % 100 == 99: 48 | print(f'[{time_since(start)}] Epoch {epoch}', end='') 49 | print('[%d, %5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 100)) 50 | running_loss = 0.0 51 | 52 | return loss_list 53 | 54 | def test(): 55 | correct = 0 56 | total = 0 57 | with torch.no_grad(): 58 | for _, data in enumerate(new_test_loader, 0): 59 | inputs, target = data[0], data[1] 60 | inputs, target = inputs.to(device), target.to(device) 61 | outputs = model(inputs) 62 | _, prediction = torch.max(outputs.data, dim=1) 63 | 64 | total += target.size(0) 65 | correct += (prediction == target).sum().item() 66 | print('Accuracy on test set: (%d/%d)%d %%' % (correct, total, 100 * correct / total)) 67 | with open("test.txt", "a") as f: 68 | f.write('Accuracy on test set: (%d/%d)%d %% \n' % (correct, total, 100 * correct / total)) 69 | 70 | if __name__ == '__main__': 71 | start = time.time() 72 | 73 | with open("test.txt", "a") as f: 74 | f.write('Start write!!! \n') 75 | 76 | loss_list = [] 77 | for epoch in range(EPOCH_NUM): 78 | train(epoch, loss_list) 79 | test() 80 | torch.save(model.state_dict(), 'Model.pth') 81 | 82 | x_ori = [] 83 | for i in range(len(loss_list)): 84 | x_ori.append(i) 85 | plt.title("Graph") 86 | plt.plot(x_ori, loss_list) 87 | plt.ylabel("Y") 88 | plt.xlabel("X") 89 | plt.show() 90 | 91 | 92 | -------------------------------------------------------------------------------- /正确率.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pingxi1009/ResNet50/41eca474d54a4c75a756cc9cb342d258d0f7af0a/正确率.png --------------------------------------------------------------------------------