├── download.sh
└── README.md

/download.sh:
--------------------------------------------------------------------------------
#! /bin/bash

# Mirror the dataset directory from Volodymyr Mnih's page.
# The reject patterns are quoted so the shell passes them to wget
# instead of expanding them against files in the current directory.
wget \
    --recursive \
    --page-requisites \
    --html-extension \
    --no-parent \
    --reject '*.zip' \
    --reject '*.html' \
    --reject '*.txt' \
    http://www.cs.toronto.edu/~vmnih/data/

mv www.cs.toronto.edu/~vmnih/data ./
rm -rf www.cs.toronto.edu

# Keep the original test maps as map_orig and fetch the replacement archive.
# "?dl=1" makes Dropbox return the file itself ("?dl=0" returns a preview page),
# and -O saves it under the name that tar expects.
mv data/mass_merged/test/map data/mass_merged/test/map_orig
cd data/mass_merged/test
wget -O multi_test_map.tar.gz "https://www.dropbox.com/s/yk6d4garyz3nm19/multi_test_map.tar.gz?dl=1"
tar zxvf multi_test_map.tar.gz
rm -f multi_test_map.tar.gz
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# mass_road_extraction
Massachusetts Road extraction by u-net using fastaiv1

The previous post covered [U-Net segmentation of the CamVid dataset](https://blog.csdn.net/qq_39337332/article/details/105335457). Road extraction can be treated as a simplified segmentation problem: every pixel is assigned to one of two classes, `['0', '1']` (background or road).

The `Massachusetts roads` dataset comes from Volodymyr Mnih's [PhD project](https://www.cs.toronto.edu/~vmnih/data/). It collects aerial (remote-sensing) images of Massachusetts with separate labels for buildings and roads; only the road data is used here.
![Sample from the Massachusetts roads dataset](https://img-blog.csdnimg.cn/20200407195356331.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM5MzM3MzMy,size_16,color_FFFFFF,t_70 =300x300)
The script for downloading the data (`download.sh`) is in the GitHub repository linked below.

Many of the downloaded images are broken: large parts of the aerial image are missing. I removed these by hand beforehand. Each original `tiff` image is also too large to train on directly, so I cropped it randomly: 15 pairs of `256x256` tiles (image and matching mask) are cut from each image and used as training data.
![Preprocessing illustration](https://img-blog.csdnimg.cn/20200407195849271.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM5MzM3MzMy,size_16,color_FFFFFF,t_70)
The cropped data is available on Baidu Netdisk:

> Link: https://pan.baidu.com/s/1l5OytHjQzx2VxYOmq_nEYg Extraction code: k4kk

Now we can train with fastai. First, import the required packages:

```python
from fastai.vision import *
```
Select the GPU used for training:

```python
torch.cuda.set_device(1)  # GPU indices are 0-based, so this is the second GPU
```
Our lab server has four 12 GiB GPUs; here I use GPU 1. As a side note, fastai does not support multi-GPU training directly at the moment, but the official fastai docs provide a distributed training script, described here:

https://docs.fast.ai/distributed.html

I will try it later.

```python
# pathlib does not expand '~' by itself, so expand it explicitly
mass_roads = Path('~/road_extraction/data/mass_roads_crop').expanduser()
mass_roads.ls()
```
**OUT**

```python
[PosixPath('~/road_extraction/data/mass_roads_crop/train'),
 PosixPath('~/road_extraction/data/mass_roads_crop/valid')]
```
Although the dataset comes with a predefined validation split, it contains very few images, so I ended up holding out a random 20% of the training set as the validation set instead.

```python
mass_roads_train = mass_roads / 'train'
mass_roads_valid = mass_roads / 'valid'
mass_roads_train.ls()
```
**OUT**

```python
[PosixPath('~/road_extraction/data/mass_roads_crop/train/map'),
 PosixPath('~/road_extraction/data/mass_roads_crop/train/sat')]
```
The training folder contains two subfolders, `sat` and `map`: `sat` holds the aerial images and `map` holds the corresponding masks. Let's inspect the contents, here using the `valid` folder, which has the same layout:

```python
sorted((mass_roads_valid / 'map').ls())[:5]
```
**OUT**

```python
[PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/map/0_10228690_15.tif'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/map/0_10978735_15.tif'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/map/0_10978795_15.tif'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/map/0_18028945_15.tif'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/map/0_21929020_15.tif')]
```

```python
sorted((mass_roads_valid / 'sat').ls())[:5]
```
**OUT**

```python
[PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/sat/0_10228690_15.tiff'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/sat/0_10978735_15.tiff'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/sat/0_10978795_15.tiff'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/sat/0_18028945_15.tiff'),
 PosixPath('/home/bir2160400081/road_extraction/data/mass_roads_crop/valid/sat/0_21929020_15.tiff')]
```
Apart from the folder, an image and its mask differ only in the file extension (`.tiff` vs `.tif`), so the labeling function can be written as:

```python
# Map .../sat/<name>.tiff to its mask .../map/<name>.tif
get_y_fnc = lambda x: x.parent.parent / 'map' / f'{x.stem}.tif'
```
Look at one image/mask pair:

```python
train_sat = mass_roads_train / 'sat'

img_sat = sorted(train_sat.ls())[17]
img_map = get_y_fnc(img_sat)

img_sat = open_image(img_sat)
img_map = open_image(img_map)

_, axs = plt.subplots(1, 2, figsize=(5, 5))
img_sat.show(ax=axs[0])
img_map.show(ax=axs[1])
```
![An aerial image and its road mask](https://img-blog.csdnimg.cn/20200407202805663.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM5MzM3MzMy,size_16,color_FFFFFF,t_70 =300x140)
Next, create fastai's training data. First, subclass `SegmentationItemList`; much like a PyTorch `Dataset`, it is an indexable object with a length. The custom label class opens masks with `open_mask(fn, div=True)`, which divides the pixel values by 255. Since the masks store road pixels as 255, this turns them into 1 so that they match the class indices `0` and `1`. Without this step, fastai fails with confusing mask errors.

```python
class MySegmentationLabelList(SegmentationLabelList):
    # div=True divides mask values by 255, mapping {0, 255} to {0, 1}
    def open(self, fn):
        return open_mask(fn, div=True)

class MySegmentationItemList(SegmentationItemList):
    # Use the custom label class above for the masks
    _label_cls, _square_show_res = MySegmentationLabelList, False
```
Define the training and validation data:

```python
src = (MySegmentationItemList.from_folder(mass_roads_train / 'sat')
       # Load the x data (aerial images) from the folder
       # .split_by_folder(train='train', valid='valid')
       .split_by_rand_pct(0.2)
       # Randomly hold out 20% of the images as the validation set
       .label_from_func(get_y_fnc, classes=['0', '1'])
       # Label each image with its mask from the map folder
      )
```
Convert it to a `DataBunch`:

```python
tfms = get_transforms()
bs = 32
size = 256
data = (src.transform(tfms, size=size, tfm_y=True)
        # Apply the default augmentations to both images and masks
        .databunch(bs=bs, path=mass_roads_train)
        # Create a DataBunch
        .normalize(imagenet_stats)
        # Normalize with ImageNet statistics to match the pretrained ResNet encoder
       )
```
Let's see what the data looks like:

```python
data.show_batch(figsize=(5, 5))
```
![A batch of training tiles with road masks](https://img-blog.csdnimg.cn/20200407203309460.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM5MzM3MzMy,size_16,color_FFFFFF,t_70)
Even after removing many bad images by hand, some of the random crops still contain blank areas. Such tiles are rare, though, and do not affect training much.

Finally, define a U-Net learner that uses a ResNet-18 encoder to extract features for segmentation:

```python
learn = unet_learner(data, models.resnet18, metrics=dice, wd=1e-2)
```
`dice` is a metric commonly used in medical image segmentation. For a predicted mask $A$ and a ground-truth mask $B$, it is defined as:
$$
Dice = \frac{2|A \cap B|}{|A| + |B|}
$$
The ratio $\frac{|A \cap B|}{|A \cup B|}$ is a different metric, the IoU (Jaccard index); the two are related by $Dice = \frac{2 \cdot IoU}{1 + IoU}$.

As usual, run the learning-rate finder to look for a good learning rate:

```python
learn.lr_find()
learn.recorder.plot()
```
**OUT**
![Learning-rate finder plot](https://img-blog.csdnimg.cn/20200407210115908.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM5MzM3MzMy,size_16,color_FFFFFF,t_70)
I settled on `lr=1e-5`:

```python
lr = 1e-5
learn.fit_one_cycle(4, slice(lr), pct_start=0.8)
```
Results after training for a few epochs:
![Training results](https://img-blog.csdnimg.cn/20200407210224856.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM5MzM3MzMy,size_16,color_FFFFFF,t_70)
After saving and exporting the model, let's run a prediction and see how well it works:

```python
learn.path = Path("/home/bir2160400081/fast.ai/mass_road")
learn.save("mass-road-stage-1")
learn.export()

# Predict the mask for sample 300 of the training split
mask_pred = learn.predict(data.train_ds[300][0])

# Left to right: input image, ground-truth mask, predicted mask
_, axs = plt.subplots(1, 3, figsize=(5, 5))

Image.show(data.train_ds[300][0], ax=axs[0])
Image.show(data.train_ds[300][1], ax=axs[1])
Image.show(mask_pred[0], ax=axs[2])
```
![Input image, ground-truth mask, and predicted mask](https://img-blog.csdnimg.cn/2020040721032181.png)

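The comparison above is visual only. As an optional extra step that is not part of the original workflow, a single predicted mask could be scored against its ground truth with the Dice coefficient, $2|A \cap B| / (|A| + |B|)$. The NumPy sketch below assumes both masks are already available as 0/1 arrays of the same shape; `dice_coefficient` and `iou_score` are hypothetical helper names for illustration, not fastai functions, and this is not how fastai's own `dice` metric is implemented. It also computes IoU to check the relation Dice = 2·IoU / (1 + IoU).

```python
import numpy as np


def dice_coefficient(pred, target):
    """Dice = 2|A ∩ B| / (|A| + |B|) for two binary (0/1) masks of equal shape."""
    a = np.asarray(pred).astype(bool)
    b = np.asarray(target).astype(bool)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    intersection = int(np.logical_and(a, b).sum())
    total = int(a.sum()) + int(b.sum())
    if total == 0:
        # Both masks are empty (e.g. a crop without roads): count as a perfect match.
        return 1.0
    return 2.0 * intersection / total


def iou_score(pred, target):
    """IoU (Jaccard index) = |A ∩ B| / |A ∪ B| for two binary masks."""
    a = np.asarray(pred).astype(bool)
    b = np.asarray(target).astype(bool)
    union = int(np.logical_or(a, b).sum())
    if union == 0:
        return 1.0
    return int(np.logical_and(a, b).sum()) / union


# Toy 4x4 example: the true road is column 1, the prediction covers columns 1 and 2.
target = np.zeros((4, 4), dtype=np.uint8)
target[:, 1] = 1
pred = np.zeros((4, 4), dtype=np.uint8)
pred[:, 1:3] = 1

d = dice_coefficient(pred, target)
j = iou_score(pred, target)
print(round(d, 4))                # 0.6667  (2*4 / (8 + 4))
print(j)                          # 0.5     (4 / 8)
print(round(2 * j / (1 + j), 4))  # 0.6667, matching the Dice value
```

For the fastai v1 prediction above, one would presumably convert both masks first, e.g. `dice_coefficient(mask_pred[0].data.numpy(), data.train_ds[300][1].data.numpy())`; this assumes the segments' `.data` tensors are on the CPU. Scoring a blank tile as 1.0 is a design choice: it avoids a division by zero on crops that contain no road at all.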
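As the figure shows, the result is not very good, possibly because ResNet-18 is too shallow. Next, I plan to try ResNet-34 or ResNet-50 with the distributed training mentioned above. Another option is to search again for a better learning rate, though I'm not optimistic about that approach. The source code and the data-download script are available on GitHub:
--------------------------------------------------------------------------------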