├── README.md
├── dataset
│   └── split_data.py
└── torch-classification
    ├── AlexNet
    │   ├── AlexNet.pth
    │   ├── __pycache__
    │   │   └── model.cpython-38.pyc
    │   ├── class_indices.json
    │   ├── model.py
    │   ├── predict.py
    │   └── train.py
    └── resNet
        ├── model.py
        ├── predict.py
        └── train.py
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bryant6/landslide-detection/8df896a728e82f6d40d05f9d41bddf672534ba12/README.md
--------------------------------------------------------------------------------
/dataset/split_data.py:
--------------------------------------------------------------------------------
"""
#-*-coding:utf-8-*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2021/7/20 15:02
"""
import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        # If the folder already exists, delete it first and then recreate it
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # Make the random split reproducible
    random.seed(0)

    # Move 10% of the dataset into the validation set
    split_rate = 0.1

    # Point to the extracted Bijie-landslide-dataset folder
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "Bijie-landslide-dataset")
    origin_slide_path = os.path.join(data_root, "images")
    assert os.path.exists(origin_slide_path), "path '{}' does not exist.".format(origin_slide_path)

    slide_class = [cla for cla in os.listdir(origin_slide_path)
                   if os.path.isdir(os.path.join(origin_slide_path, cla))]

    # Create the folder that holds the training set
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in slide_class:
        # Create a folder for each class
        mk_file(os.path.join(train_root, cla))

    # Create the folder that holds the validation set
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in slide_class:
        # Create a folder for each class
        mk_file(os.path.join(val_root, cla))

    for cla in slide_class:
        cla_path = os.path.join(origin_slide_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # Randomly sample the images that go to the validation set
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # Copy files assigned to the validation set into its directory
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # Copy files assigned to the training set into its directory
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
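split_data.py assumes the Bijie dataset has been extracted next to it, with one sub-folder per class under images/. As a quick check that the 90/10 split came out as expected, a minimal sketch (not part of the repo; it reuses the folder names from split_data.py and is meant to be run from the dataset/ directory after the split):

# Sketch: count the images per class in the train/ and val/ folders created by split_data.py.
import os

data_root = os.path.join(os.getcwd(), "Bijie-landslide-dataset")
for split in ("train", "val"):
    split_dir = os.path.join(data_root, split)
    for cla in sorted(os.listdir(split_dir)):
        n = len(os.listdir(os.path.join(split_dir, cla)))
        print("{:5s} / {:15s}: {} images".format(split, cla, n))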
/torch-classification/AlexNet/AlexNet.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bryant6/landslide-detection/8df896a728e82f6d40d05f9d41bddf672534ba12/torch-classification/AlexNet/AlexNet.pth
--------------------------------------------------------------------------------
/torch-classification/AlexNet/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Bryant6/landslide-detection/8df896a728e82f6d40d05f9d41bddf672534ba12/torch-classification/AlexNet/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/torch-classification/AlexNet/class_indices.json:
--------------------------------------------------------------------------------
{
 "0": "landslide",
 "1": "non-landslide"
}
--------------------------------------------------------------------------------
/torch-classification/AlexNet/model.py:
--------------------------------------------------------------------------------
"""
#-*-coding:utf-8-*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2021/7/21 15:20
"""
import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),

            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),

            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
--------------------------------------------------------------------------------
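The 128 * 6 * 6 input size of the first fully connected layer follows from the conv/pool stack applied to a 224×224 image. A small sketch (not part of the repo; it assumes it is run from the AlexNet/ directory so model.py is importable) to confirm the shapes:

# Sketch: verify that a 3x224x224 input yields a 128x6x6 feature map,
# matching the Linear(128 * 6 * 6, 2048) layer in the classifier.
import torch
from model import AlexNet

net = AlexNet(num_classes=2)
x = torch.randn(1, 3, 224, 224)
print(net.features(x).shape)  # expected: torch.Size([1, 128, 6, 6])
print(net(x).shape)           # expected: torch.Size([1, 2])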
/torch-classification/AlexNet/predict.py:
--------------------------------------------------------------------------------
"""
#-*-coding:utf-8-*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2021/7/21 16:48
"""
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    image_path = "../../dataset/Bijie-landslide-dataset/images/landslide/df018.png"
    image = Image.open(image_path)
    plt.imshow(image)

    image = data_transform(image)
    image = torch.unsqueeze(image, dim=0)

    json_path = "./class_indices.json"
    assert os.path.exists(json_path), "file {} does not exist.".format(json_path)
    with open(json_path, "r") as json_file:
        class_dict = json.load(json_file)

    model = AlexNet(num_classes=2).to(device)
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(image.to(device)))
        predict = torch.softmax(output, dim=0).cpu()
        predict_class = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_dict[str(predict_class)],
                                                 predict[predict_class].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
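One caveat when predicting on PNGs: if a dataset image carries an alpha channel or is grayscale, the three-channel Normalize above will fail. A hedged one-line guard, not in the original script, is to force RGB on load:

# Optional guard (assumption: some dataset PNGs may be RGBA or grayscale).
# Converting to RGB keeps the three-channel mean/std in Normalize valid.
from PIL import Image

image_path = "../../dataset/Bijie-landslide-dataset/images/landslide/df018.png"  # same path as predict.py
image = Image.open(image_path).convert("RGB")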
/torch-classification/AlexNet/train.py:
--------------------------------------------------------------------------------
"""
#-*-coding:utf-8-*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2021/7/21 16:13
"""
import os
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using %s device." % device)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # must be (224, 224), not 224
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "dataset", "Bijie-landslide-dataset")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    validate_num = len(validate_dataset)
    print("using {} images for training, {} images for validation".format(train_num, validate_num))

    landslide_list = train_dataset.class_to_idx
    class_dict = dict((val, key) for key, val in landslide_list.items())
    json_str = json.dumps(class_dict, indent=1)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    net = AlexNet(num_classes=2, init_weights=True)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = "./AlexNet.pth"
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accuracy = acc / validate_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accuracy))
        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(net.state_dict(), save_path)

    print("Training Finished!")


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
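Overall validation accuracy can hide weak recall on the landslide class if the two classes are imbalanced. A hedged sketch (not part of the original script) of a helper that reports per-class accuracy; it assumes a trained model plus the validate_loader, class_dict and device built in train.py's main():

# Sketch: per-class validation accuracy for the landslide / non-landslide classes.
import torch

def per_class_accuracy(net, loader, device, class_dict):
    correct, total = {}, {}
    net.eval()
    with torch.no_grad():
        for images, labels in loader:
            preds = torch.max(net(images.to(device)), dim=1)[1].cpu()
            for label, pred in zip(labels.tolist(), preds.tolist()):
                total[label] = total.get(label, 0) + 1
                correct[label] = correct.get(label, 0) + int(pred == label)
    for idx in sorted(total):
        print("{}: {:.3f} ({}/{})".format(class_dict[idx], correct[idx] / total[idx],
                                          correct[idx], total[idx]))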
/torch-classification/resNet/model.py:
--------------------------------------------------------------------------------
"""
#-*-coding:utf-8-*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2021/7/20 14:48
"""
import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    # residual block for ResNet-18 / ResNet-34
    expansion = 1  # for the 18- and 34-layer networks, both convs in a block use the same number of kernels (e.g. 64)

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        """
        :param in_channel: depth of the input feature map
        :param out_channel: depth of the output feature map (number of kernels)
        :param stride: stride of the first 3x3 conv
        :param downsample: downsample branch (the dashed shortcut), a 1x1 conv layer
        """
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel,
                               out_channels=out_channel,
                               kernel_size=3,   # kernel size
                               stride=stride,
                               padding=1,       # padding
                               bias=False)      # no bias needed when followed by BatchNorm
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel,
                               out_channels=out_channel,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x  # shortcut value
        if self.downsample is not None:
            identity = self.downsample(x)  # if a downsample branch exists, use its output as the shortcut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    Note: in the original paper, on the main branch of the dashed residual block,
    the first 1x1 conv has stride 2 and the second 3x3 conv has stride 1.
    The official PyTorch implementation instead uses stride 1 for the first 1x1 conv
    and stride 2 for the 3x3 conv, which improves top-1 accuracy by roughly 0.5%.
    See ResNet v1.5: https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True,
                 groups=1, width_per_group=64):
        """
        :param block: BasicBlock (18, 34) or Bottleneck (50, 101, 152)
        :param blocks_num: number of residual blocks in each stage
        :param num_classes: number of output classes
        :param include_top: keep the classification head (set False to reuse the backbone)
        :param groups: number of groups for grouped convolution (ResNeXt)
        :param width_per_group: width per group (ResNeXt)
        """
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """
        :param block: residual block type
        :param channel: number of kernels in the residual blocks of this stage
        :param block_num: number of residual blocks in this stage
        :param stride: stride of the first block
        :return: the stage as an nn.Sequential
        """
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)
--------------------------------------------------------------------------------
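A quick sanity check on the backbone wiring for the 2-class setup used in this repo — a sketch, not part of the original code, assuming it is run from the resNet/ directory:

# Sketch: build resnet50 for landslide / non-landslide and confirm the output shape.
import torch
from model import resnet50

net = resnet50(num_classes=2)
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    print(net(x).shape)  # expected: torch.Size([2, 2])
# Bottleneck.expansion = 4, so the final fc layer takes 512 * 4 = 2048 input features.
print(net.fc)            # expected: Linear(in_features=2048, out_features=2, bias=True)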
/torch-classification/resNet/predict.py:
--------------------------------------------------------------------------------
"""
#-*-coding:utf-8-*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2021/7/20 14:50
"""
import json
import torch
from torchvision import transforms
from PIL import Image

from model import resnet50


def main():
    # deterministic preprocessing for inference (random crop/flip augmentations belong in training)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    with open("./class_indices.json", "r") as json_file:
        class_dict = json.load(json_file)

    net = resnet50(num_classes=2)
    net.load_state_dict(torch.load('resNet50.pth', map_location='cpu'))

    img = Image.open('../../dataset/Aba_dataset/Aba_3/landslide_jpeg/140.jpeg')
    img = transform(img)
    img = torch.unsqueeze(img, dim=0)

    net.eval()
    with torch.no_grad():
        outputs = torch.squeeze(net(img))
        predicts = torch.softmax(outputs, dim=0)
        print(outputs)
        print(predicts)
        predict = torch.argmax(predicts).numpy()
        print("class: {}, prob: {:.3}".format(class_dict[str(predict)],
                                              predicts[predict].numpy()))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/torch-classification/resNet/train.py:
--------------------------------------------------------------------------------
"""
#-*-coding:utf-8-*-
# @author: wangyu a beginner programmer, striving to be the strongest.
# @date: 2021/7/20 14:50
"""
import json
import os
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from torchvision import transforms, datasets

from model import resnet50


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    print("using {} device.".format(device))

    transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "dataset", "Bijie-landslide-dataset")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    print("image path:", image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=transform["train"])
    train_num = len(train_dataset)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=transform["val"])
    validate_num = len(validate_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, validate_num))

    # {'landslide': 0, 'non-landslide': 1}
    landslide_list = train_dataset.class_to_idx
    # print(landslide_list)
    class_dict = dict((val, key) for key, val in landslide_list.items())
    # print(class_dict)  # {0: 'landslide', 1: 'non-landslide'}
    json_str = json.dumps(class_dict, indent=1)
    # print(json_str)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = 0
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=nw)
    # val_data_iter = iter(validate_loader)
    # val_image, val_label = next(val_data_iter)

    # transfer learning: load ImageNet-pretrained weights, then replace the final fc layer for 2 classes
    net = resnet50()
    weight_path = "./resnet50-pre.pth"
    net.load_state_dict(torch.load(weight_path, map_location=device))
    inchannel = net.fc.in_features
    net.fc = nn.Linear(inchannel, 2)
    net.to(device)

    loss_func = nn.CrossEntropyLoss()
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    EPOCHS = 10
    best_acc = 0.0
    save_path = './resNet50.pth'
    train_steps = len(train_loader)  # train_num / batch_size
    for epoch in range(EPOCHS):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            output = net(images.to(device))
            loss = loss_func(output, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, EPOCHS, loss)

        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predicts = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predicts, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, EPOCHS)

        val_accurate = acc / validate_num
        print("[epoch %d] train_loss: %.3f  val_accuracy: %.3f" %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print("Training Finished!")


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
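resNet/train.py expects an ImageNet-pretrained checkpoint at ./resnet50-pre.pth. A hedged way to fetch it (a sketch, assuming the download URL from the comment in model.py is still valid and that network access is available):

# Sketch: download the pretrained ResNet-50 weights referenced in model.py
# and save them where train.py expects to find them.
import torch

url = "https://download.pytorch.org/models/resnet50-19c8e357.pth"
torch.hub.download_url_to_file(url, "./resnet50-pre.pth")
state_dict = torch.load("./resnet50-pre.pth", map_location="cpu")
print(len(state_dict), "tensors loaded")  # quick sanity check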