├── .gitignore
├── README.md
├── dataloaders
│   ├── README.md
│   ├── __init__.py
│   ├── custom_transforms.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── rssrai.py
│   │   └── utils
│   │       ├── RandomCropTiffImage.py
│   │       ├── __init__.py
│   │       ├── calculateMeanStd.py
│   │       ├── createTrainIdLabelImgs.py
│   │       ├── cropTIFFImg.py
│   │       ├── labels.py
│   │       └── stitchTestImg.py
│   └── utils.py
├── models
│   ├── CombineNet.py
│   ├── __init__.py
│   ├── backbone
│   │   ├── UNet.py
│   │   ├── UNetNested.py
│   │   └── __init__.py
│   ├── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── batchnorm.py
│   │   ├── comm.py
│   │   ├── replicate.py
│   │   └── unittest.py
│   └── utils
│       ├── __init__.py
│       ├── layers.py
│       └── utils.py
├── mypath.py
├── requirements.txt
├── train.py
├── train.sh
├── train_combine_net.py
├── train_combine_net.sh
├── utils
│   ├── __init__.py
│   ├── calculate_weights.py
│   ├── loss.py
│   ├── lr_scheduler.py
│   ├── metrics.py
│   ├── save_model_and_params.py
│   ├── saver.py
│   └── summaries.py
├── vis.py
├── vis.sh
├── vis_combine_net.py
└── vis_combine_net.sh
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # ignore pycharm folder
7 | .idea/
8 | 
9 | # ignore running results folder
10 | run/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 2019 Remote Sensing Image Sparse Representation and Intelligent Analysis Competition - Semantic Segmentation Track
2 | Official competition site: [http://rscup.bjxintong.com.cn/#/theme/3](http://rscup.bjxintong.com.cn/#/theme/3)
3 | 
4 | The code is mainly adapted from my other repository: https://github.com/huiyiygy/pytorch-deeplab-xception/tree/huiyiygy
5 | 
6 | The highest kappa score among my final submissions was 0.37187.
7 | 
8 | # Experiment log
9 | The Deeplab experiments were run directly on the code of the repository above with a few modifications, so they are not included in this repository; see that repository if you are interested.
10 | 
11 | | No. | Experiment | Changes | Result |
12 | | ---- | ------------- | ------------------------------------------------------------ | -------------------------- |
13 | | 1 | Deeplab 0 | SGD --lr 0.01 --out-stride 16 --epochs 200 --batch-size 8 | 0.25705 |
14 | | 2 | Deeplab 1 / 2 | Changed the data distribution: shuffled the original train/val and recombined them into a new train/val --batch-size 10 --out-stride 16 | 0.25024 |
15 | | 3 | Deeplab 3 / 4 | --out-stride 8 --epochs 100 --batch-size 4 | 0.12774 |
16 | | 4 | Deeplab 5 | Re-added the random scale-crop step when sampling training data; mixed dataset --out-stride 16 --epochs 200 --batch-size 10 | 0.19396 |
17 | | 5 | Deeplab 6 | Trained from the bundled pretrained model; mixed dataset | 0.20624 |
18 | | 6 | Deeplab 7 | Augmented the original dataset with random scale-crop, expanding it to 10,000 images, and trained on the regenerated dataset --weight-decay 0.001 --out-stride 16 --epochs 100 --batch-size 14 | The slimmed-down model still overfits badly |
19 | | 7 | UNet 0 | SGD --lr 0.01 --weight-decay 0.001 --epochs 200 --batch-size 32 | 0.19 |
20 | | 8 | UNet 1 | Using Adam, no weight decay --learn-rate 0.001 --weight-decay 0 --epochs 1000 --batch-size 32 | 0.28297 |
21 | | 9 | UNet 2 | Added one more conv layer to every encoder block and dropout=0.5 after the last downsampling layer; RandomGammaTransform, RandomBilateralFilter, RandomNoise --batch-size 20 | 0.36428 |
22 | | 10 | UNet 3 | On top of experiment 2, added weight decay and amsgrad; learning rate divided by 10 at 50% and 80% of the total epochs --weight-decay 1e-4 | 0.3582 |
23 | | 11 | UNet 4 | The L2 coefficient of experiment 3 was too large and hurt accuracy, so it was shrunk by another factor of 10; learning-rate decay changed to 0.3x per step; more epochs | 0.29755 |
24 | | 12 | UNet++ | Trained with the UNet++ network --learn-rate 0.001 --weight-decay 0 --epochs 1000 --batch-size 12 | 0.35381 |
25 | | 13 | CombineNet 1 | Single UNet model; predictions from four rotation angles plus horizontal and vertical flips, 6 images in total | 0.37187 |
26 | | 14 | CombineNet 2 | UNet and UNet++, multi-angle prediction | Not submitted in time |
27 | 
28 | # Summary
29 | 
30 | 1. (Allow me to complain first.) The quality of the official annotations is very poor: the data suffers from severe class imbalance, mislabeling, and missing labels. The amount of data provided is also small (10 images in total; what miracle am I supposed to train on that?).
31 | 2. Training with Deeplab overfits badly across the board. Cause: the model is far too complex for so few samples. Even with regularization and data augmentation, results did not improve.
32 | 3. The vanilla UNet avoids the overfitting problem well; it can serve as the baseline, with more layers and other training tricks added step by step to improve accuracy.
33 | 
34 | # Ideas worth trying next
35 | * LovaszSoftmax Loss
36 | * The way small test-time predictions are stitched back into the full image can be further optimized. I currently crop with size 400 and stride 200, so each patch overlaps its neighbors by half; this mitigates the dark-border artifact at patch edges to some extent, but it is still visible.
37 | * If you have a remote-sensing background, try correcting the original annotation images with ENVI. I talked with a strong competitor who, just by fixing the original annotations and then training with the same method, improved his score by 10 points.
38 | 
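39 | A minimal multi-class Lovász-Softmax sketch, adapted from Berman et al.'s reference implementation (untested here and not wired into `utils/loss.py`; `probas` is the flattened softmax output of shape `(P, C)`, `labels` the flattened ground-truth ids of shape `(P,)`):
40 | 
41 | ```python
42 | import torch
43 | 
44 | def lovasz_grad(gt_sorted):
45 |     # gradient of the Lovász extension w.r.t. sorted errors
46 |     gts = gt_sorted.sum()
47 |     intersection = gts - gt_sorted.cumsum(0)
48 |     union = gts + (1 - gt_sorted).cumsum(0)
49 |     jaccard = 1. - intersection / union
50 |     if len(gt_sorted) > 1:
51 |         jaccard[1:] = jaccard[1:] - jaccard[:-1]
52 |     return jaccard
53 | 
54 | def lovasz_softmax_flat(probas, labels, n_classes=16):
55 |     losses = []
56 |     for c in range(n_classes):
57 |         fg = (labels == c).float()   # foreground mask for class c
58 |         if fg.sum() == 0:
59 |             continue                 # skip classes absent from this batch
60 |         errors = (fg - probas[:, c]).abs()
61 |         errors_sorted, perm = torch.sort(errors, 0, descending=True)
62 |         losses.append(torch.dot(errors_sorted, lovasz_grad(fg[perm])))
63 |     return torch.stack(losses).mean()
64 | ```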
--------------------------------------------------------------------------------
/dataloaders/README.md:
--------------------------------------------------------------------------------
1 | # Dataset preparation steps
2 | 1. Define all classes and their corresponding colors in labels.py.
3 | 
4 | 2. Set the path of the color annotation images in createTrainIdLabelImgs.py, then run that script to generate the train_id image for every color image; the model is scored against these images during training.
5 | 
6 | 3. Set the image paths and the crop_size in cropTIFFImg.py and crop all images; the model is trained on the cropped images.
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | from dataloaders.datasets import rssrai
3 | from torch.utils.data import DataLoader
4 | 
5 | 
6 | def make_data_loader(args, **kwargs):
7 |     if args.dataset == 'rssrai2019':
8 |         train_set = rssrai.RssraiSegmentation(args, split='train')
9 |         val_set = rssrai.RssraiSegmentation(args, split='val')
10 |         test_set = rssrai.RssraiSegmentation(args, split='test')
11 |         num_class = train_set.NUM_CLASSES
12 |         train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
13 |         val_loader = DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)
14 |         test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)
15 | 
16 |         return train_loader, val_loader, test_loader, num_class
17 | 
18 |     else:
19 |         raise NotImplementedError
20 | 
--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import random
4 | import numpy as np
5 | import cv2
6 | 
7 | from PIL import Image, ImageOps, ImageFilter
8 | 
9 | 
10 | class Normalize(object):
11 |     """Normalize a tensor image with mean and standard deviation.
12 |     Args:
13 |         mean (tuple): means for each channel.
14 |         std (tuple): standard deviations for each channel.
15 |     """
16 |     def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
17 |         self.mean = mean
18 |         self.std = std
19 | 
20 |     def __call__(self, sample):
21 |         img = sample['image']
22 |         mask = sample['label']
23 |         img = np.array(img).astype(np.float32)
24 |         mask = np.array(mask).astype(np.float32)
25 |         img /= 255.0
26 |         img -= self.mean
27 |         img /= self.std
28 | 
29 |         sample['image'] = img
30 |         sample['label'] = mask
31 | 
32 |         return sample
33 | 
34 | 
35 | class ToTensor(object):
36 |     """Convert ndarrays in sample to Tensors."""
37 | 
38 |     def __call__(self, sample):
39 |         # swap color axis because
40 |         # numpy image: H x W x C
41 |         # torch image: C X H X W
42 |         img = sample['image']
43 |         mask = sample['label']
44 |         img = np.array(img).astype(np.float32).transpose((2, 0, 1))
45 |         mask = np.array(mask).astype(np.float32)
46 | 
47 |         img = torch.from_numpy(img).float()
48 |         mask = torch.from_numpy(mask).float()
49 | 
50 |         sample['image'] = img
51 |         sample['label'] = mask
52 | 
53 |         return sample
54 | 
55 | 
56 | class RandomHorizontalFlip(object):
57 |     def __call__(self, sample):
58 |         img = sample['image']
59 |         mask = sample['label']
60 |         if random.random() < 0.25:
61 |             img = img.transpose(Image.FLIP_LEFT_RIGHT)
62 |             mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
63 | 
64 |         return {'image': img,
65 |                 'label': mask}
66 | 
67 | 
68 | class RandomVerticalFlip(object):
69 |     def __call__(self, sample):
70 |         img = sample['image']
71 |         mask = sample['label']
72 |         if random.random() < 0.25:
73 |             img = img.transpose(Image.FLIP_TOP_BOTTOM)
74 |             mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
75 | 
76 |         return {'image': img,
77 |                 'label': mask}
78 | 
79 | 
80 | class RandomRotate(object):
81 |     def __init__(self,):
82 |         self.degree = [0, 90, 180, 270]
83 | 
84 |     def __call__(self, sample):
85 |         img = sample['image']
86 |         mask = sample['label']
87 |         index = random.randint(0, 3)
88 |         if index != 0:
89 |             rotate_degree = self.degree[index]
90 |             img = img.rotate(rotate_degree, Image.BILINEAR)
91 |             mask = mask.rotate(rotate_degree, Image.NEAREST)
92 | 
93 |         return {'image': img,
94 |                 'label': mask}
95 | 
96 | 
97 | class RandomGammaTransform(object):
98 |     def __call__(self, sample):
99 |         img = sample['image']
100 |         img_np = np.array(img, dtype=np.uint8)
101 |         alpha = np.random.uniform(-np.e, np.e)
102 |         gamma = np.exp(alpha)
103 |         gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
104 |         gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
105 |         img_np = cv2.LUT(img_np, gamma_table)
106 |         img = Image.fromarray(img_np, mode='CMYK')
107 |         sample['image'] = img
108 |         return sample
109 | 
110 | 
111 | class RandomGaussianBlur(object):
112 |     def __call__(self, sample):
113 |         img = sample['image']
114 |         mask = sample['label']
115 |         if random.random() < 0.5:
116 |             img = img.filter(ImageFilter.GaussianBlur(
117 |                 radius=random.randint(2, 5)))
118 | 
119 |         return {'image': img,
120 |                 'label': mask}
121 | 
122 | 
123 | class RandomNoise(object):
124 |     def __call__(self, sample):
125 |         img = sample['image']
126 |         w, h = img.size
127 |         img_np = np.array(img, dtype=np.uint8)
128 |         for i in range(5000):  # number of noise pixels
129 |             x = np.random.randint(0, w)
130 |             y = np.random.randint(0, h)
131 |             img_np[y, x] = 255  # numpy arrays are indexed (row, col), i.e. (y, x)
132 |         img = Image.fromarray(img_np, mode='CMYK')
133 |         sample['image'] = img
134 |         return sample
135 | 
136 | 
137 | class RandomScaleCrop(object):
138 |     def __init__(self, base_size, crop_size, fill=0):
139 |         self.base_size = base_size
140 |         self.crop_size = crop_size
141 |         self.fill = fill
142 | 
143 |     def __call__(self, sample):
144 |         img = sample['image']
145 |         mask = sample['label']
146 |         # random scale (short edge)
147 |         short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
148 |         w, h = img.size
149 |         if h > w:
150 |             ow = short_size
151 |             oh = int(1.0 * h * ow / w)
152 |         else:
153 |             oh = short_size
154 |             ow = int(1.0 * w * oh / h)
155 |         img = img.resize((ow, oh), Image.BILINEAR)
156 |         mask = mask.resize((ow, oh), Image.NEAREST)
157 |         # pad crop
158 |         if short_size < self.crop_size:
159 |             padh = self.crop_size - oh if oh < self.crop_size else 0
160 |             padw = self.crop_size - ow if ow < self.crop_size else 0
161 |             img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
162 |             mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
163 |         # random crop crop_size
164 |         w, h = img.size
165 |         x1 = random.randint(0, w - self.crop_size)
166 |         y1 = random.randint(0, h - self.crop_size)
167 |         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
168 |         mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
169 | 
170 |         return {'image': img,
171 |                 'label': mask}
172 | 
173 | 
174 | class FixScaleCrop(object):
175 |     def __init__(self, crop_size):
176 |         self.crop_size = crop_size
177 | 
178 |     def __call__(self, sample):
179 |         img = sample['image']
180 |         mask = sample['label']
181 |         w, h = img.size
182 |         if w > h:
183 |             oh = self.crop_size
184 |             ow = int(1.0 * w * oh / h)
185 |         else:
186 |             ow = self.crop_size
187 |             oh = int(1.0 * h * ow / w)
188 |         img = img.resize((ow, oh), Image.BILINEAR)
189 |         mask = mask.resize((ow, oh), Image.NEAREST)
190 |         # center crop
191 |         w, h = img.size
192 |         x1 = int(round((w - self.crop_size) / 2.))
193 |         y1 = int(round((h - self.crop_size) / 2.))
194 |         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
195 |         mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
196 | 
197 |         return {'image': img,
198 |                 'label': mask}
199 | 
200 | 
201 | class FixedResize(object):
202 |     def __init__(self, size):
203 |         self.size = (size, size)  # square output; note PIL's resize expects (w, h)
204 | 
205 |     def __call__(self, sample):
206 |         img = sample['image']
207 |         mask = sample['label']
208 | 
209 |         assert img.size == mask.size
210 | 
211 |         img = img.resize(self.size, Image.BILINEAR)
212 |         mask = mask.resize(self.size, Image.NEAREST)
213 | 
214 |         sample['image'] = img
215 |         sample['label'] = mask
216 | 
217 |         return sample
218 | 
--------------------------------------------------------------------------------
/dataloaders/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function:
4 | @author:HuiYi or 会意
5 | @file: __init__.py
6 | @time: 2019/7/26 2:31 PM
7 | """
8 | 
--------------------------------------------------------------------------------
/dataloaders/datasets/rssrai.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | from PIL import Image
4 | from torch.utils import data
5 | from mypath import Path
6 | from torchvision import transforms
7 | from dataloaders import custom_transforms as tr
8 | 
9 | 
10 | class RssraiSegmentation(data.Dataset):
11 |     NUM_CLASSES = 16
12 | 
13 |     def __init__(self, args, root=Path.db_root_dir('rssrai2019'), split="train"):
14 |         self.root = root
15 |         self.split = split
16 |         self.args = args
17 |         self.files = {}
18 | 
19 |         self.images_base = os.path.join(self.root, 'image', self.split+'_mix')
20 |         self.annotations_base = os.path.join(self.root, 'label', self.split+'_mix_id_image')
21 | 
22 |         self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.tif')
23 | 
24 |         self.classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
25 |         self.class_names = ['其他类别', '水	田', '水 浇地', '旱 耕地', '园	地', '乔木林地', '灌木林地', '天然草地', '人工草地',
26 |                             '工业用地', '城市住宅', '村镇住宅', '交通运输', '河	流', '湖	泊', '坑	塘']
27 | 
28 |         self.ignore_index = 255
29 | 
30 |         if not self.files[split]:
31 |             raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
32 | 
33 |         print("Found %d %s images" % (len(self.files[split]), split))
34 | 
35 |     def __len__(self):
36 |         return len(self.files[self.split])
37 | 
38 |     def __getitem__(self, index):
39 |         img_path = self.files[self.split][index].rstrip()
40 |         _img = Image.open(img_path)
41 | 
42 |         if self.split != 'test':
43 |             lbl_path = os.path.join(self.annotations_base, os.path.basename(img_path)[:-4] + '_labelTrainIds.png')
44 |             _target = Image.open(lbl_path)
45 |             sample = {'image': _img, 'label': _target}
46 |         else:
47 |             sample = {'image': _img, 'label': _img, 'img_path': img_path}  # We do not have test labels
48 | 
49 |         if self.split == 'train':
50 |             return self.transform_train(sample)
51 |         elif self.split == 'val':
52 |             return self.transform_val(sample)
53 |         elif self.split == 'test':
54 |             return self.transform_test(sample)
55 | 
56 |     @staticmethod
57 |     def recursive_glob(rootdir='.', suffix=''):
58 |         """Performs a recursive glob with the given suffix and rootdir
59 |         :param rootdir: the root directory
60 |         :param suffix: the suffix to be searched
61 |         """
62 |         return [os.path.join(looproot, filename)
63 |                 for looproot, _, filenames in os.walk(rootdir)
64 |                 for filename in filenames if filename.endswith(suffix)]
65 | 
66 |     def transform_train(self, sample):
67 |         composed_transforms = transforms.Compose([
68 |             tr.RandomHorizontalFlip(),
69 |             tr.RandomVerticalFlip(),
70 |             # tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
71 |             # tr.FixedResize(size=self.args.crop_size),
72 |             tr.RandomRotate(),
73 |             tr.RandomGammaTransform(),
74 |             tr.RandomGaussianBlur(),
75 |             tr.RandomNoise(),
76 |             tr.Normalize(mean=(0.544650, 0.352033, 0.384602, 0.352311), std=(0.249456, 0.241652, 0.228824, 0.227583)),
77 |             tr.ToTensor()])
78 | 
79 |         return composed_transforms(sample)
80 | 
81 |     def transform_val(self, sample):
82 |         composed_transforms = transforms.Compose([
83 |             # tr.FixScaleCrop(crop_size=self.args.crop_size),
84 |             # tr.FixedResize(size=self.args.crop_size),
85 |             tr.Normalize(mean=(0.544650, 0.352033, 0.384602, 0.352311), std=(0.249456, 0.241652, 0.228824, 0.227583)),
86 |             tr.ToTensor()])
87 | 
88 |         return composed_transforms(sample)
89 | 
90 |     def transform_test(self, sample):
91 |         composed_transforms = transforms.Compose([
92 |             # tr.FixedResize(size=self.args.crop_size),
93 |             tr.Normalize(mean=(0.544650, 0.352033, 0.384602, 0.352311), std=(0.249456, 0.241652, 0.228824, 0.227583)),
94 |             tr.ToTensor()])
95 | 
96 |         return composed_transforms(sample)
97 | 
--------------------------------------------------------------------------------
/dataloaders/datasets/utils/RandomCropTiffImage.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function:
4 | @author:HuiYi or 会意
5 | @file: RandomCropTiffImage.py
6 | @time: 2019/7/23 9:52 AM
7 | """
8 | import numpy as np
9 | import os
10 | from PIL import Image, ImageOps
11 | from tqdm import tqdm
12 | 
13 | 
14 | def random_scale_crop(filename, base_size, crop_size, crop_num, path_dict, postfix):
15 |     img_raw = Image.open(os.path.join(path_dict['img_dir'], filename))
16 |     color_img_raw = Image.open(os.path.join(path_dict['color_dir'], filename))
17 |     id_img_raw = Image.open(os.path.join(path_dict['id__dir'], filename[:-3] + 'png'))
18 | 
19 |     for i in tqdm(range(crop_num)):
20 |         img = img_raw.copy()
21 |         color_img = color_img_raw.copy()
22 |         id_img = id_img_raw.copy()
23 | 
24 |         # random scale (short edge)
25 |         short_size = np.random.randint(int(base_size * 0.5), int(base_size * 2.0))
26 |         w, h = img.size
27 |         if h > w:
28 |             ow = short_size
29 |             oh = int(1.0 * h * ow / w)
30 |         else:
31 |             oh = short_size
32 |             ow = int(1.0 * w * oh / h)
33 |         img = img.resize((ow, oh), Image.BILINEAR)
34 |         color_img = color_img.resize((ow, oh), Image.BILINEAR)
35 |         id_img = id_img.resize((ow, oh), Image.NEAREST)
36 |         # pad crop
37 |         if short_size < crop_size:
38 |             padh = crop_size - oh if oh < crop_size else 0
39 |             padw = crop_size - ow if ow < crop_size else 0
40 |             img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
41 |             color_img = ImageOps.expand(color_img, border=(0, 0, padw, padh), fill=0)
42 |             id_img = ImageOps.expand(id_img, border=(0, 0, padw, padh), fill=255)
43 |         # random crop crop_size
44 |         w, h = img.size
45 |         x1 = np.random.randint(0, w - crop_size)
46 |         y1 = np.random.randint(0, h - crop_size)
47 |         img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
48 |         color_img = color_img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
49 |         id_img = id_img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
50 |         # set filename
51 |         img_filename = filename[:-4] + '_%04d.tif' % i
52 |         id_img_filename = filename[:-4] + '_%04d_labelTrainIds.png' % i
53 |         # save file
54 |         img.save(os.path.join(path_dict['img_random_crop_dir'], img_filename))
55 |         color_img.save(os.path.join(path_dict['color_random_crop_dir'], img_filename))
56 |         id_img.save(os.path.join(path_dict['id_random_crop_dir'], id_img_filename))
57 | 
58 | 
59 | if __name__ == "__main__":
60 |     paths = {'img_dir': r'/home/lab/ygy/rssrai2019/datasets/image/train',
61 |              'img_random_crop_dir': r'/home/lab/ygy/rssrai2019/datasets/image/train_random_crop',
62 |              'color_dir': r'/home/lab/ygy/rssrai2019/datasets/label/train',
63 |              'color_random_crop_dir': r'/home/lab/ygy/rssrai2019/datasets/label/train_random_crop',
64 |              'id__dir': r'/home/lab/ygy/rssrai2019/datasets/label/train_id_image',
65 |              'id_random_crop_dir': r'/home/lab/ygy/rssrai2019/datasets/label/train_random_crop_id_image'
66 |              }
67 | 
68 |     postfix = '.tif'
69 |     base_size = 6800
70 |     crop_size = 400
71 | 
72 |     img_crop_num = 1250
73 | 
74 |     n = 1
75 |     for root, dirs, files in os.walk(paths['img_dir']):
76 |         for file in files:
77 |             print('cropping the %d th image, filename=%s' % (n, file))
78 |             random_scale_crop(file, base_size, crop_size, img_crop_num, paths, postfix)
79 |             n += 1
80 |     print('crop finished')
--------------------------------------------------------------------------------
/dataloaders/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function:
4 | @author:HuiYi or 会意
5 | @file: __init__.py
6 | @time: 2019/7/26 2:52 PM
7 | """
8 | 
--------------------------------------------------------------------------------
/dataloaders/datasets/utils/calculateMeanStd.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | 
6 | filepath = '/home/lab/ygy/rssrai2019/datasets/image/temp'  # dataset directory
7 | pathDir = os.listdir(filepath)
8 | 
9 | N_channel = 0
10 | R_channel = 0
11 | G_channel = 0
12 | B_channel = 0
13 | for idx in range(len(pathDir)):
14 |     filename = pathDir[idx]
15 |     img = Image.open(os.path.join(filepath, filename))
16 |     img_np = np.array(img, dtype=np.uint8) / 255
17 |     N_channel = N_channel + np.sum(img_np[:, :, 0])
18 |     R_channel = R_channel + np.sum(img_np[:, :, 1])
19 |     G_channel = G_channel + np.sum(img_np[:, :, 2])
20 |     B_channel = B_channel + np.sum(img_np[:, :, 3])
21 | 
22 | num = len(pathDir) * 400 * 400  # (400, 400) is the size of each image; all images have the same size
23 | N_mean = N_channel / num
24 | R_mean = R_channel / num
25 | G_mean = G_channel / num
26 | B_mean = B_channel / num
27 | 
28 | N_channel = 0
29 | R_channel = 0
30 | G_channel = 0
31 | B_channel = 0
32 | for idx in range(len(pathDir)):
33 |     filename = pathDir[idx]
34 |     img = Image.open(os.path.join(filepath, filename))
35 |     img_np = np.array(img, dtype=np.uint8) / 255
36 |     N_channel = N_channel + np.sum((img_np[:, :, 0] - N_mean) ** 2)
37 |     R_channel = R_channel + np.sum((img_np[:, :, 1] - R_mean) ** 2)
38 |     G_channel = G_channel + np.sum((img_np[:, :, 2] - G_mean) ** 2)
39 |     B_channel = B_channel + np.sum((img_np[:, :, 3] - B_mean) ** 2)
40 | 
41 | N_var = N_channel / num
42 | R_var = R_channel / num
43 | G_var = G_channel / num
44 | B_var = B_channel / num
45 | 
46 | N_std = np.sqrt(N_var)
47 | R_std = np.sqrt(R_var)
48 | G_std = np.sqrt(G_var)
49 | B_std = np.sqrt(B_var)
50 | 
51 | # mean = (0.544650, 0.352033, 0.384602, 0.352311)
52 | print("N_mean is %f, R_mean is %f, G_mean is %f, B_mean is %f" % (N_mean, R_mean, G_mean, B_mean))
53 | # var = (0.062228, 0.058396, 0.052360, 0.051794)
54 | print("N_var is %f, R_var is %f, G_var is %f, B_var is %f" % (N_var, R_var, G_var, B_var))
55 | # std = (0.249456, 0.241652, 0.228824, 0.227583)
56 | print("N_std is %f, R_std is %f, G_std is %f, B_std is %f" % (N_std, R_std, G_std, B_std))
57 | 
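58 | 
59 | # Editor's sketch (an assumption, not part of the original script): the two passes
60 | # above can be fused into one by accumulating per-channel sums and squared sums,
61 | # then using var = E[x^2] - E[x]^2. Same inputs (400x400, 4-channel images), half the I/O.
62 | def mean_std_single_pass(directory, size=400):
63 |     s1 = np.zeros(4)
64 |     s2 = np.zeros(4)
65 |     names = os.listdir(directory)
66 |     for name in names:
67 |         arr = np.array(Image.open(os.path.join(directory, name)), dtype=np.float64) / 255
68 |         s1 += arr.sum(axis=(0, 1))         # per-channel sum
69 |         s2 += (arr ** 2).sum(axis=(0, 1))  # per-channel squared sum
70 |     n = len(names) * size * size
71 |     mean = s1 / n
72 |     std = np.sqrt(s2 / n - mean ** 2)
73 |     return mean, std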
--------------------------------------------------------------------------------
/dataloaders/datasets/utils/createTrainIdLabelImgs.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Converts the color images of the rssrai2019 dataset
3 | # to labelTrainIds images, where pixel values encode ground truth classes.
4 | #
5 | # The rssrai2019 downloads already include such images
6 | #   a) *color.tif : the class is encoded by its color
7 | #
8 | # With this tool, you can generate option
9 | #   b) *labelTrainIds.png : the class is encoded by its training ID
10 | # This encoding might come in handy for training purposes. You can use
11 | # the file labels.py to define the training IDs that suit your needs.
12 | # Note however, that once you submit or evaluate results, the regular
13 | # IDs are needed.
14 | #
15 | # Uses the mapping defined in 'labels.py'
16 | #
17 | import numpy as np
18 | import os
19 | from dataloaders.datasets.utils.labels import color2label
20 | 
21 | from PIL import Image
22 | 
23 | 
24 | def create_train_id_imgs(filename, source_dir, target_dir):
25 |     img_pil = Image.open(os.path.join(source_dir, filename))
26 |     img_np = np.array(img_pil, dtype=np.uint8)
27 | 
28 |     rows, cols = img_np.shape[0], img_np.shape[1]
29 |     train_id_img_np = np.zeros((rows, cols), dtype=np.uint8)
30 | 
31 |     # per-pixel dictionary lookup: simple, but slow on full 7200x6800 tiles
32 |     # (see the vectorized sketch appended at the end of this file)
33 |     for i in range(rows):
34 |         for j in range(cols):
35 |             color = (img_np[i, j, 0], img_np[i, j, 1], img_np[i, j, 2])
36 |             train_id = color2label[color].trainId
37 |             train_id_img_np[i, j] = train_id
38 |     train_id_img = Image.fromarray(train_id_img_np)
39 |     train_id_img_filename = filename[:-4] + '_labelTrainIds.png'
40 |     train_id_img.save(os.path.join(target_dir, train_id_img_filename))
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_crop'
45 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_crop_id_image'
46 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_crop'
47 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_crop_id_image'
48 | 
49 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_mix'
50 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_mix_id_image'
51 |     source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_mix'
52 |     target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_mix_id_image'
53 | 
54 |     n = 1
55 |     for root, _, files in os.walk(source_dir):
56 |         for file in files:
57 |             if os.path.splitext(file)[1] == '.tif':
58 |                 print('creating the %d th labelTrainIds image, filename=%s' % (n, file))
59 |                 create_train_id_imgs(file, source_dir, target_dir)
60 |                 n += 1
61 |     print('create finished')
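62 | 
63 | # Editor's sketch (an assumption, not part of the original script): a vectorized
64 | # variant that builds one boolean mask per label instead of visiting every pixel,
65 | # i.e. ~16 full-image numpy passes instead of rows*cols dict lookups.
66 | def create_train_id_imgs_fast(filename, source_dir, target_dir):
67 |     img_np = np.array(Image.open(os.path.join(source_dir, filename)), dtype=np.uint8)
68 |     train_id_img_np = np.zeros(img_np.shape[:2], dtype=np.uint8)
69 |     for color, label in color2label.items():
70 |         mask = np.all(img_np[:, :, :3] == np.array(color, dtype=np.uint8), axis=-1)
71 |         train_id_img_np[mask] = label.trainId
72 |     Image.fromarray(train_id_img_np).save(
73 |         os.path.join(target_dir, filename[:-4] + '_labelTrainIds.png'))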
--------------------------------------------------------------------------------
/dataloaders/datasets/utils/cropTIFFImg.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import numpy as np
3 | import os
4 | from PIL import Image
5 | 
6 | 
7 | def crop_tiff_image(filename, source_dir, stride, crop_size, target_dir, postfix):
8 |     img_pil = Image.open(os.path.join(source_dir, filename))
9 |     img_np = np.array(img_pil, dtype=np.uint8)
10 | 
11 |     rows, cols = img_np.shape[0], img_np.shape[1]
12 |     crop_rows, crop_cols = crop_size[0], crop_size[1]
13 | 
14 |     if (rows - crop_rows) % stride != 0 or (cols - crop_cols) % stride != 0:
15 |         raise ValueError('Inappropriate crop size {}'.format(crop_size))
16 | 
17 |     rows_num = (rows - crop_rows) // stride + 1
18 |     cols_num = (cols - crop_cols) // stride + 1
19 | 
20 |     n = 0
21 |     for i in range(rows_num):
22 |         for j in range(cols_num):
23 |             crop_img_pil = None
24 |             if len(img_np.shape) == 2:
25 |                 crop_img_np = img_np[i * stride:i * stride + crop_rows, j * stride:j * stride + crop_cols]
26 |                 crop_img_pil = Image.fromarray(crop_img_np)
27 |             else:
28 |                 crop_img_np = img_np[i * stride:i * stride + crop_rows, j * stride:j * stride + crop_cols, :]
29 |                 if img_np.shape[2] == 3:
30 |                     crop_img_pil = Image.fromarray(crop_img_np, mode='RGB')
31 |                 elif img_np.shape[2] == 4:
32 |                     crop_img_pil = Image.fromarray(crop_img_np, mode='CMYK')
33 |             crop_img_filename = os.path.splitext(filename)[0] + '_' + str(n) + postfix
34 |             # shuffle the crops of the original train/val images into new train/val sets
35 |             # if n % 5 == 4:
36 |             #     crop_img_pil.save(os.path.join(val_mix_dir, crop_img_filename))
37 |             # else:
38 |             #     crop_img_pil.save(os.path.join(train_mix_dir, crop_img_filename))
39 |             crop_img_pil.save(os.path.join(target_dir, crop_img_filename))
40 |             n += 1
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_id_image'
45 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_id_image_crop'
46 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_id_image'
47 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_id_image_crop'
48 |     # stride = 400
49 |     # postfix = '.png'
50 | 
51 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/image/train'
52 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/image/train_crop'
53 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/image/val'
54 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/image/val_crop'
55 |     # train_mix_dir = r'/home/lab/ygy/rssrai2019/datasets/image/train_mix'
56 |     # val_mix_dir = r'/home/lab/ygy/rssrai2019/datasets/image/val_mix'
57 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/image/test'
58 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/image/test_crop'
59 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train'
60 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_crop'
61 |     # source_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val'
62 |     # target_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_crop'
63 |     # train_mix_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_mix'
64 |     # val_mix_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_mix'
65 |     # stride = 400
66 |     source_dir = r'/home/lab/ygy/rssrai2019/datasets/image/test'
67 |     target_dir = r'/home/lab/ygy/rssrai2019/datasets/image/test_overlay_crop'
68 |     stride = 200
69 | 
70 |     postfix = '.tif'
71 | 
72 |     crop_size = [400, 400]
73 | 
74 |     n = 1
75 |     for root, dirs, files in os.walk(source_dir):
76 |         for file in files:
77 |             if os.path.splitext(file)[1] == postfix:
78 |                 print('cropping the %d th image, filename=%s' % (n, file))
79 |                 crop_tiff_image(file, root, stride, crop_size, target_dir, postfix)
80 |                 n += 1
81 |     print('crop finished')
--------------------------------------------------------------------------------
/dataloaders/datasets/utils/labels.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | #
3 | # rssrai2019 labels
4 | #
5 | from collections import namedtuple
6 | 
7 | # --------------------------------------------------------------------------------
8 | # Definitions
9 | # --------------------------------------------------------------------------------
10 | 
11 | # a label and all meta information
12 | Label = namedtuple('Label', [
13 |     'name',  # The identifier of this label, e.g. 'car', 'person', ... .
14 |              # We use them to uniquely name a class
15 |     'id',  # An integer ID that is associated with this label.
16 |            # The IDs are used to represent the label in ground truth images
17 |            # An ID of -1 means that this label does not have an ID and thus
18 |            # is ignored when creating ground truth images (e.g. license plate).
19 |            # Do not modify these IDs, since exactly these IDs are expected by the
20 |            # evaluation server.
21 |     'trainId',  # Feel free to modify these IDs as suitable for your method. Then create
22 |                 # ground truth images with train IDs, using the tools provided in the
23 |                 # 'preparation' folder. However, make sure to validate or submit results
24 |                 # to our evaluation server using the regular IDs above!
25 |                 # For trainIds, multiple labels might have the same ID. Then, these labels
26 |                 # are mapped to the same class in the ground truth images. For the inverse
27 |                 # mapping, we use the label that is defined first in the list below.
28 |                 # For example, mapping all void-type classes to the same ID in training,
29 |                 # might make sense for some approaches.
30 |                 # Max value is 255!
31 |     'category',  # The name of the category that this label belongs to
32 |     'categoryId',  # The ID of this category. Used to create ground truth images
33 |                    # on category level.
34 |     'ignoreInEval',  # Whether pixels having this class as ground truth label are ignored
35 |                      # during evaluations or not
36 |     'color',  # The color of this label
37 | ])
38 | 
39 | 
40 | # --------------------------------------------------------------------------------
41 | # A list of all labels
42 | # --------------------------------------------------------------------------------
43 | 
44 | # Please adapt the train IDs as appropriate for your approach.
45 | # Note that you might want to ignore labels with ID 255 during training.
46 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
47 | # Make sure to provide your results using the original IDs and not the training IDs.
48 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
49 | 
50 | labels = [
51 |     # name         id   trainId   category      catId   ignoreInEval   color
52 |     Label('其他类别', 0, 0, 'void', 0, False, (0, 0, 0)),
53 | 
54 |     Label('水	田', 1, 1, 'farmland', 1, False, (0, 200, 0)),
55 |     Label('水 浇地', 2, 2, 'farmland', 1, False, (150, 250, 0)),
56 |     Label('旱 耕地', 3, 3, 'farmland', 1, False, (150, 200, 150)),
57 | 
58 |     Label('园	地', 4, 4, 'woodland', 2, False, (200, 0, 200)),
59 |     Label('乔木林地', 5, 5, 'woodland', 2, False, (150, 0, 250)),
60 |     Label('灌木林地', 6, 6, 'woodland', 2, False, (150, 150, 250)),
61 | 
62 |     Label('天然草地', 7, 7, 'grassland', 3, False, (250, 200, 0)),
63 |     Label('人工草地', 8, 8, 'grassland', 3, False, (200, 200, 0)),
64 | 
65 |     Label('工业用地', 9, 9, 'urbanland', 4, False, (200, 0, 0)),
66 |     Label('城市住宅', 10, 10, 'urbanland', 4, False, (250, 0, 150)),
67 |     Label('村镇住宅', 11, 11, 'urbanland', 4, False, (200, 150, 150)),
68 |     Label('交通运输', 12, 12, 'urbanland', 4, False, (250, 150, 150)),
69 | 
70 |     Label('河	流', 13, 13, 'waterland', 5, False, (0, 0, 200)),
71 |     Label('湖	泊', 14, 14, 'waterland', 5, False, (0, 150, 200)),
72 |     Label('坑	塘', 15, 15, 'waterland', 5, False, (0, 200, 250)),
73 | ]
74 | 
75 | 
76 | # --------------------------------------------------------------------------------
77 | # Create dictionaries for a fast lookup
78 | # --------------------------------------------------------------------------------
79 | 
80 | # Please refer to the main method below for example usages!
81 | 
82 | # name to label object
83 | name2label = {label.name: label for label in labels}
84 | # id to label object
85 | id2label = {label.id: label for label in labels}
86 | # color to label object
87 | color2label = {label.color: label for label in labels}
88 | # trainId to label object
89 | trainId2label = {label.trainId: label for label in reversed(labels)}
90 | # category to list of label objects
91 | category2labels = {}
92 | for label in labels:
93 |     category = label.category
94 |     if category in category2labels:
95 |         category2labels[category].append(label)
96 |     else:
97 |         category2labels[category] = [label]
98 | 
99 | 
100 | # --------------------------------------------------------------------------------
101 | # Main for testing
102 | # --------------------------------------------------------------------------------
103 | 
104 | 
105 | if __name__ == "__main__":
106 |     # Print all the labels
107 |     print("List of rssrai2019 labels:")
108 |     print("")
109 |     print("{:>15} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12}".format('name', 'id', 'trainId', 'category', 'categoryId',
110 |                                                                      'ignoreInEval'))
111 |     print("    " + ('-' * 98))
112 |     for label in labels:
113 |         print("{:>15} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12}".format(label.name, label.id, label.trainId,
114 |                                                                          label.category, label.categoryId,
115 |                                                                          label.ignoreInEval))
116 |     print("")
117 | 
118 |     print("Example usages:")
119 | 
120 |     # Map from name to label
121 |     name = '水	田'
122 |     id = name2label[name].id
123 |     print("ID of label '{name}': {id}".format(name=name, id=id))
124 | 
125 |     # Map from ID to label
126 |     category = id2label[id].category
127 |     print("Category of label with ID '{id}': {category}".format(id=id, category=category))
128 | 
129 |     # Map from trainID to label
130 |     trainId = 0
131 |     name = trainId2label[trainId].name
132 |     print("Name of label with trainID '{id}': {name}".format(id=trainId, name=name))
--------------------------------------------------------------------------------
/dataloaders/datasets/utils/stitchTestImg.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | Stitch the cropped patches back into the full-size test images
4 | """
5 | import os
6 | from PIL import Image
7 | import numpy as np
8 | 
9 | test_img_num = 10
10 | 
11 | rows, cols = 6800, 7200
12 | # stride = 400
13 | stride = 200
14 | crop_rows, crop_cols = 400, 400
15 | rows_num = (rows - crop_rows) // stride + 1
16 | cols_num = (cols - crop_cols) // stride + 1
17 | 
18 | test_img_name = [
19 |     'GF2_PMS1__20150902_L1A0001015646-MSS1',
20 |     'GF2_PMS1__20150902_L1A0001015648-MSS1',
21 |     'GF2_PMS1__20150912_L1A0001037899-MSS1',
22 |     'GF2_PMS1__20150926_L1A0001064469-MSS1',
23 |     'GF2_PMS1__20160327_L1A0001491484-MSS1',
24 |     'GF2_PMS1__20160430_L1A0001553848-MSS1',
25 |     'GF2_PMS1__20160623_L1A0001660727-MSS1',
26 |     'GF2_PMS1__20160627_L1A0001668483-MSS1',
27 |     'GF2_PMS1__20160704_L1A0001680853-MSS1',
28 |     'GF2_PMS1__20160801_L1A0001734328-MSS1'
29 | ]
30 | 
31 | 
32 | def stitch_test_img(color_dir, stitch_dir):
33 |     files = os.listdir(color_dir)
34 |     crop_num = len(files) // len(test_img_name)
35 | 
36 |     for i in range(test_img_num):
37 |         # stitch one full-size image
38 |         test_img_np = np.zeros((rows, cols, 3), dtype=np.uint8)
39 |         row, col = 0, 0
40 |         for j in range(crop_num):
41 |             # read each cropped patch
42 |             crop_img_name = os.path.join(color_dir, test_img_name[i]+'_'+str(j)+'.tif')
43 |             crop_img_pil = Image.open(crop_img_name)
44 |             crop_img_pil = crop_img_pil.resize((crop_rows, crop_cols), Image.NEAREST)
45 |             crop_img_np = np.array(crop_img_pil, dtype=np.uint8)
46 |             # place the patch into the full image
47 |             a0 = row * stride
48 |             a1 = a0 + crop_rows
49 |             b0 = col * stride
50 |             b1 = b0 + crop_cols
51 |             test_img_np[a0:a1, b0:b1, :] = crop_img_np
52 |             # advance row/col
53 |             col += 1
54 |             if j != 0 and (j+1) % cols_num == 0:
55 |                 row += 1
56 |                 col = 0
57 |         # save the stitched image
58 |         save_file_name = os.path.join(stitch_dir, test_img_name[i]+'_label.tif')
59 |         test_img_pil = Image.fromarray(test_img_np, mode='RGB')
60 |         test_img_pil.save(save_file_name)
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     vis_color_dir = '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/combine_net/vis_log/vis_color'
65 |     stitch_img_dir = '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/combine_net/vis_log/stitch_img'
66 |     stitch_test_img(vis_color_dir, stitch_img_dir)
67 | 
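68 | 
69 | # Editor's sketch (an assumption, not part of the original script): the README notes
70 | # that overlapping 400/200 crops still leave border artifacts, because later patches
71 | # simply overwrite earlier ones here. If per-patch class scores were also saved
72 | # (e.g. as .npy arrays named like the .tif patches; such files are a hypothetical
73 | # addition), the overlaps could be averaged before the argmax instead:
74 | def stitch_by_averaging(score_dir, name, n_classes=16):
75 |     acc = np.zeros((rows, cols, n_classes), dtype=np.float32)  # per-class score sums
76 |     cnt = np.zeros((rows, cols, 1), dtype=np.float32)          # patches covering each pixel
77 |     for j in range(rows_num * cols_num):
78 |         score = np.load(os.path.join(score_dir, name + '_' + str(j) + '.npy'))  # (400, 400, 16)
79 |         r = (j // cols_num) * stride
80 |         c = (j % cols_num) * stride
81 |         acc[r:r + crop_rows, c:c + crop_cols] += score
82 |         cnt[r:r + crop_rows, c:c + crop_cols] += 1
83 |     return np.argmax(acc / cnt, axis=-1)  # (rows, cols) trainId map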
--------------------------------------------------------------------------------
/dataloaders/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | def decode_seg_map_sequence(label_masks, dataset='rssrai2019'):
8 |     rgb_masks = []
9 |     for label_mask in label_masks:
10 |         rgb_mask = decode_segmap(label_mask, dataset)
11 |         rgb_masks.append(rgb_mask)
12 |     rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
13 |     return rgb_masks
14 | 
15 | 
16 | def decode_segmap(label_mask, dataset, plot=False):
17 |     """Decode segmentation class labels into a color image
18 |     Args:
19 |         label_mask (np.ndarray): an (M,N) array of integer values denoting
20 |           the class label at each spatial location.
21 |         dataset (str): the dataset whose color map should be used.
22 |         plot (bool, optional): whether to show the resulting color image
23 |           in a figure.
24 |     Returns:
25 |         (np.ndarray, optional): the resulting decoded color image.
26 |     """
27 |     if dataset == 'rssrai2019':
28 |         n_classes = 16
29 |         label_colours = get_rssrai_labels()
30 |     else:
31 |         raise NotImplementedError
32 | 
33 |     r = label_mask.copy()
34 |     g = label_mask.copy()
35 |     b = label_mask.copy()
36 |     for ll in range(0, n_classes):
37 |         r[label_mask == ll] = label_colours[ll, 0]
38 |         g[label_mask == ll] = label_colours[ll, 1]
39 |         b[label_mask == ll] = label_colours[ll, 2]
40 |     rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3), dtype=np.uint8)
41 |     rgb[:, :, 0] = r
42 |     rgb[:, :, 1] = g
43 |     rgb[:, :, 2] = b
44 |     if plot:
45 |         plt.imshow(rgb)
46 |         plt.show()
47 |     else:
48 |         return rgb
49 | 
50 | 
51 | def get_rssrai_labels():
52 |     return np.array([
53 |         [0, 0, 0],
54 |         [0, 200, 0],
55 |         [150, 250, 0],
56 |         [150, 200, 150],
57 |         [200, 0, 200],
58 |         [150, 0, 250],
59 |         [150, 150, 250],
60 |         [250, 200, 0],
61 |         [200, 200, 0],
62 |         [200, 0, 0],
63 |         [250, 0, 150],
64 |         [200, 150, 150],
65 |         [250, 150, 150],
66 |         [0, 0, 200],
67 |         [0, 150, 200],
68 |         [0, 200, 250]])
--------------------------------------------------------------------------------
/models/CombineNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function: Concatenate the multi-angle predictions of a single image from UNet, then apply two 3*3 conv layers to obtain the fused prediction
4 | @author:HuiYi or 会意
5 | @file: CombineNet.py
6 | @time: 2019/8/6 3:09 PM
7 | """
8 | import torch.nn as nn
9 | from models.utils.utils import init_weights
10 | 
11 | 
12 | class CombineNet(nn.Module):
13 |     def __init__(self, in_channels=96, n_classes=16,):
14 |         super(CombineNet, self).__init__()
15 | 
16 |         self.conv1 = nn.Conv2d(in_channels, n_classes * 3, 3, 1, 1)
17 |         self.conv2 = nn.Conv2d(n_classes * 3, n_classes, 3, 1, 1)
18 | 
19 |         # initialise the blocks
20 |         for m in self.modules():
21 |             init_weights(m, init_type='kaiming')
22 | 
23 |     def forward(self, inputs):
24 |         conv1 = self.conv1(inputs)
25 |         conv2 = self.conv2(conv1)
26 |         return conv2
27 | 
28 |     def get_params(self):
29 |         for m in self.modules():
30 |             if isinstance(m, nn.Conv2d):
31 |                 for p in m.parameters():
32 |                     if p.requires_grad:
33 |                         yield p
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
--------------------------------------------------------------------------------
/models/backbone/UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.utils.layers import UnetConv2, UnetUp
4 | from models.utils.utils import init_weights, count_param
5 | from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | 
7 | 
8 | class UNet(nn.Module):
9 |     def __init__(self, in_channels=1, n_classes=2, feature_scale=2, is_deconv=True, sync_bn=False):
10 |         super(UNet, self).__init__()
11 |         self.in_channels = in_channels
12 |         self.feature_scale = feature_scale
13 |         self.is_deconv = is_deconv
14 | 
15 |         if sync_bn:
16 |             self.batchnorm = SynchronizedBatchNorm2d
17 |         else:
18 |             self.batchnorm = nn.BatchNorm2d
19 | 
20 |         filters = [64, 128, 256, 512, 1024]
21 |         filters = [int(i / feature_scale) for i in filters]
22 | 
23 |         # downsampling
24 |         self.maxpool = nn.MaxPool2d(kernel_size=2)
25 |         self.conv1 = UnetConv2(self.in_channels, filters[0], self.batchnorm)
26 |         self.conv2 = UnetConv2(filters[0], filters[1], self.batchnorm)
27 |         self.conv3 = UnetConv2(filters[1], filters[2], self.batchnorm)
28 |         self.conv4 = UnetConv2(filters[2], filters[3], self.batchnorm)
29 |         self.center = UnetConv2(filters[3], filters[4], self.batchnorm)
30 |         self.dropout = nn.Dropout(0.5)
31 |         # upsampling
32 |         self.up_concat4 = UnetUp(filters[4], filters[3], self.is_deconv)
33 |         self.up_concat3 = UnetUp(filters[3], filters[2], self.is_deconv)
34 |         self.up_concat2 = UnetUp(filters[2], filters[1], self.is_deconv)
35 |         self.up_concat1 = UnetUp(filters[1], filters[0], self.is_deconv)
36 |         # final conv (without any concat)
37 |         self.final = nn.Conv2d(filters[0], n_classes, 1)
38 | 
39 |         # initialise weights
40 |         self._init_weight()
41 | 
42 |     def forward(self, inputs):
43 |         conv1 = self.conv1(inputs)      # 16*512*512
44 |         maxpool1 = self.maxpool(conv1)  # 16*256*256
45 | 
46 |         conv2 = self.conv2(maxpool1)    # 32*256*256
47 |         maxpool2 = self.maxpool(conv2)  # 32*128*128
48 | 
49 |         conv3 = self.conv3(maxpool2)    # 64*128*128
50 |         maxpool3 = self.maxpool(conv3)  # 64*64*64
51 | 
52 |         conv4 = self.conv4(maxpool3)    # 128*64*64
53 |         maxpool4 = self.maxpool(conv4)  # 128*32*32
54 | 
55 |         center = self.center(maxpool4)  # 256*32*32
56 |         center = self.dropout(center)
57 | 
58 |         up4 = self.up_concat4(center, conv4)  # 128*64*64
59 |         up3 = self.up_concat3(up4, conv3)     # 64*128*128
60 |         up2 = self.up_concat2(up3, conv2)     # 32*256*256
61 |         up1 = self.up_concat1(up2, conv1)     # 16*512*512
62 | 
63 |         final = self.final(up1)
64 | 
65 |         return final
66 | 
67 |     def _init_weight(self):
68 |         for m in self.modules():
69 |             if isinstance(m, nn.Conv2d):
70 |                 nn.init.kaiming_normal_(m.weight)
71 |             elif isinstance(m, nn.ConvTranspose2d):
72 |                 nn.init.kaiming_normal_(m.weight)
73 |             elif isinstance(m, nn.UpsamplingBilinear2d):
74 |                 nn.init.kaiming_normal_(m.weight)
75 |             elif isinstance(m, SynchronizedBatchNorm2d):
76 |                 m.weight.data.fill_(1)
77 |                 m.bias.data.zero_()
78 |             elif isinstance(m, nn.BatchNorm2d):
79 |                 m.weight.data.fill_(1)
80 |                 m.bias.data.zero_()
81 | 
82 |     def get_params(self):
83 |         for m in self.modules():
84 |             if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.UpsamplingBilinear2d) \
85 |                     or isinstance(m, SynchronizedBatchNorm2d) or isinstance(m, nn.BatchNorm2d):
86 |                 for p in m.parameters():
87 |                     if p.requires_grad:
88 |                         yield p
89 | 
90 | 
91 | if __name__ == '__main__':
92 |     print('#### Test Case ###')
93 |     from torch.autograd import Variable
94 |     x = Variable(torch.rand(2, 1, 64, 64)).cuda()
95 |     model = UNet().cuda()
96 |     param = count_param(model)
97 |     y = model(x)
98 |     print('Output shape:', y.shape)
99 |     print('UNet total parameters: %.2fM (%d)' % (param/1e6, param))
--------------------------------------------------------------------------------
/models/backbone/UNetNested.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.utils.layers import UnetConv2, UnetUp
4 | from models.utils.utils import count_param
5 | from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | 
7 | 
8 | class UNetNested(nn.Module):
9 |     def __init__(self, in_channels=1, n_classes=2, feature_scale=2, is_deconv=True, is_ds=True, sync_bn=False):
10 |         super(UNetNested, self).__init__()
11 |         self.in_channels = in_channels
12 |         self.feature_scale = feature_scale
13 |         self.is_deconv = is_deconv
14 |         self.is_ds = is_ds
15 | 
16 |         if sync_bn:
17 |             self.batchnorm = SynchronizedBatchNorm2d
18 |         else:
19 |             self.batchnorm = nn.BatchNorm2d
20 | 
21 |         filters = [64, 128, 256, 512, 1024]
22 |         filters = [int(i / self.feature_scale) for i in filters]
23 | 
24 |         # downsampling
25 |         self.maxpool = nn.MaxPool2d(kernel_size=2)
26 |         self.conv00 = UnetConv2(self.in_channels, filters[0], self.batchnorm)
27 |         self.conv10 = UnetConv2(filters[0], filters[1], self.batchnorm)
28 |         self.conv20 = UnetConv2(filters[1], filters[2], self.batchnorm)
29 |         self.conv30 = UnetConv2(filters[2], filters[3], self.batchnorm)
30 |         self.conv40 = UnetConv2(filters[3], filters[4], self.batchnorm)
31 | 
32 |         # upsampling
33 |         self.up_concat01 = UnetUp(filters[1], filters[0], self.is_deconv)
34 |         self.up_concat11 = UnetUp(filters[2], filters[1], self.is_deconv)
35 |         self.up_concat21 = UnetUp(filters[3], filters[2], self.is_deconv)
36 |         self.up_concat31 = UnetUp(filters[4], filters[3], self.is_deconv)
37 | 
38 |         self.up_concat02 = UnetUp(filters[1], filters[0], self.is_deconv, 3)
39 |         self.up_concat12 = UnetUp(filters[2], filters[1], self.is_deconv, 3)
40 |         self.up_concat22 = UnetUp(filters[3], filters[2], self.is_deconv, 3)
41 | 
42 |         self.up_concat03 = UnetUp(filters[1], filters[0], self.is_deconv, 4)
43 |         self.up_concat13 = UnetUp(filters[2], filters[1], self.is_deconv, 4)
44 | 
45 |         self.up_concat04 = UnetUp(filters[1], filters[0], self.is_deconv, 5)
46 | 
47 |         # final conv (without any concat)
48 |         self.final_1 = nn.Conv2d(filters[0], n_classes, 1)
49 |         self.final_2 = nn.Conv2d(filters[0], n_classes, 1)
50 |         self.final_3 = nn.Conv2d(filters[0], n_classes, 1)
51 |         self.final_4 = nn.Conv2d(filters[0], n_classes, 1)
52 | 
53 |         # initialise weights
54 |         self._init_weight()
55 | 
56 |     def forward(self, inputs):
57 |         # column : 0
58 |         x_00 = self.conv00(inputs)     # 16*512*512
59 |         maxpool0 = self.maxpool(x_00)  # 16*256*256
60 |         x_10 = self.conv10(maxpool0)   # 32*256*256
61 |         maxpool1 = self.maxpool(x_10)  # 32*128*128
62 |         x_20 = self.conv20(maxpool1)   # 64*128*128
63 |         maxpool2 = self.maxpool(x_20)  # 64*64*64
64 |         x_30 = self.conv30(maxpool2)   # 128*64*64
65 |         maxpool3 = self.maxpool(x_30)  # 128*32*32
66 |         x_40 = self.conv40(maxpool3)   # 256*32*32
67 |         # column : 1
68 |         x_01 = self.up_concat01(x_10, x_00)
69 |         x_11 = self.up_concat11(x_20, x_10)
70 |         x_21 = self.up_concat21(x_30, x_20)
71 |         x_31 = self.up_concat31(x_40, x_30)
72 |         # column : 2
73 |         x_02 = self.up_concat02(x_11, x_00, x_01)
74 |         x_12 = self.up_concat12(x_21, x_10, x_11)
75 |         x_22 = self.up_concat22(x_31, x_20, x_21)
76 |         # column : 3
77 |         x_03 = self.up_concat03(x_12, x_00, x_01, x_02)
78 |         x_13 = self.up_concat13(x_22, x_10, x_11, x_12)
79 |         # column : 4
80 |         x_04 = self.up_concat04(x_13, x_00, x_01, x_02, x_03)
81 | 
82 |         # final layer
83 |         final_1 = self.final_1(x_01)
84 |         final_2 = self.final_2(x_02)
85 |         final_3 = self.final_3(x_03)
86 |         final_4 = self.final_4(x_04)
87 | 
88 |         final = (final_1+final_2+final_3+final_4)/4
89 | 
90 |         if self.is_ds:
91 |             return final
92 |         else:
93 |             return final_4
94 | 
95 |     def _init_weight(self):
96 |         for m in self.modules():
97 |             if isinstance(m, nn.Conv2d):
98 |                 nn.init.kaiming_normal_(m.weight)
99 |             elif isinstance(m, nn.ConvTranspose2d):
100 |                 nn.init.kaiming_normal_(m.weight)
101 |             elif isinstance(m, nn.UpsamplingBilinear2d):
102 |                 nn.init.kaiming_normal_(m.weight)
103 |             elif isinstance(m, SynchronizedBatchNorm2d):
104 |                 m.weight.data.fill_(1)
105 |                 m.bias.data.zero_()
106 |             elif isinstance(m, nn.BatchNorm2d):
107 |                 m.weight.data.fill_(1)
108 |                 m.bias.data.zero_()
109 | 
110 |     def get_params(self):
111 |         for m in self.modules():
112 |             if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.UpsamplingBilinear2d) \
113 |                     or isinstance(m, SynchronizedBatchNorm2d) or isinstance(m, nn.BatchNorm2d):
114 |                 for p in m.parameters():
115 |                     if p.requires_grad:
116 |                         yield p
117 | 
118 | 
119 | if __name__ == '__main__':
120 |     print('#### Test Case ###')
121 |     from torch.autograd import Variable
122 |     x = Variable(torch.rand(2, 1, 64, 64)).cuda()
123 |     model = UNetNested().cuda()
124 |     param = count_param(model)
125 |     y = model(x)
126 |     print('Output shape:', y.shape)
127 |     print('UNet++ total parameters: %.2fM (%d)' % (param/1e6, param))
--------------------------------------------------------------------------------
/models/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function:
4 | @author:HuiYi or 会意
5 | @file: __init__.py
6 | @time: 2019/7/26 1:30 PM
7 | """
8 | 
--------------------------------------------------------------------------------
/models/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/models/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import collections
12 | 
13 | import torch
14 | import torch.nn.functional as F
15 | 
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 | 
19 | from .comm import SyncMaster
20 | 
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 | 
23 | 
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 |     return tensor.sum(dim=0).sum(dim=-1)
27 | 
28 | 
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 |     return tensor.unsqueeze(0).unsqueeze(-1)
32 | 
33 | 
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 | 
37 | 
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 |         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 | 
42 |         self._sync_master = SyncMaster(self._data_parallel_master)
43 | 
44 |         self._is_parallel = False
45 |         self._parallel_id = None
46 |         self._slave_pipe = None
47 | 
48 |     def forward(self, input):
49 |         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 |         if not (self._is_parallel and self.training):
51 |             return F.batch_norm(
52 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
53 |                 self.training, self.momentum, self.eps)
54 | 
55 |         # Resize the input to (B, C, -1).
56 |         input_shape = input.size()
57 |         input = input.view(input.size(0), self.num_features, -1)
58 | 
59 |         # Compute the sum and square-sum.
60 |         sum_size = input.size(0) * input.size(2)
61 |         input_sum = _sum_ft(input)
62 |         input_ssum = _sum_ft(input ** 2)
63 | 
64 |         # Reduce-and-broadcast the statistics.
65 |         if self._parallel_id == 0:
66 |             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 |         else:
68 |             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 | 
70 |         # Compute the output.
71 |         if self.affine:
72 |             # MJY:: Fuse the multiplication for speed.
73 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 |         else:
75 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 | 
77 |         # Reshape it.
78 |         return output.view(input_shape)
79 | 
80 |     def __data_parallel_replicate__(self, ctx, copy_id):
81 |         self._is_parallel = True
82 |         self._parallel_id = copy_id
83 | 
84 |         # parallel_id == 0 means master device.
85 |         if self._parallel_id == 0:
86 |             ctx.sync_master = self._sync_master
87 |         else:
88 |             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 | 
90 |     def _data_parallel_master(self, intermediates):
91 |         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 | 
93 |         # Always using same "device order" makes the ReduceAdd operation faster.
94 |         # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 |         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 | 
97 |         to_reduce = [i[1][:2] for i in intermediates]
98 |         to_reduce = [j for i in to_reduce for j in i]  # flatten
99 |         target_gpus = [i[1].sum.get_device() for i in intermediates]
100 | 
101 |         sum_size = sum([i[1].sum_size for i in intermediates])
102 |         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 |         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 | 
105 |         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 | 
107 |         outputs = []
108 |         for i, rec in enumerate(intermediates):
109 |             outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 | 
111 |         return outputs
112 | 
113 |     def _compute_mean_std(self, sum_, ssum, size):
114 |         """Compute the mean and standard-deviation with sum and square-sum. This method
115 |         also maintains the moving average on the master device."""
116 |         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 |         mean = sum_ / size
118 |         sumvar = ssum - sum_ * mean
119 |         unbias_var = sumvar / (size - 1)
120 |         bias_var = sumvar / size
121 | 
122 |         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 |         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 | 
125 |         return mean, bias_var.clamp(self.eps) ** -0.5
126 | 
127 | 
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 |     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 |     mini-batch.
131 |     .. math::
132 |         y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
133 |     This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 |     standard-deviation are reduced across all devices during training.
135 |     For example, when one uses `nn.DataParallel` to wrap the network during
136 |     training, PyTorch's implementation normalizes the tensor on each device using
137 |     the statistics only on that device, which accelerates the computation and
138 |     is also easy to implement, but the statistics might be inaccurate.
139 |     Instead, in this synchronized version, the statistics will be computed
140 |     over all training samples distributed on multiple devices.
141 | 
142 |     Note that, for one-GPU or CPU-only case, this module behaves exactly the same
143 |     as the built-in PyTorch implementation.
144 |     The mean and standard-deviation are calculated per-dimension over
145 |     the mini-batches and gamma and beta are learnable parameter vectors
146 |     of size C (where C is the input size).
147 |     During training, this layer keeps a running estimate of its computed mean
148 |     and variance. The running sum is kept with a default momentum of 0.1.
149 |     During evaluation, this running mean/variance is used for normalization.
150 |     Because the BatchNorm is done over the `C` dimension, computing statistics
151 |     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
152 |     Args:
153 |         num_features: num_features from an expected input of size
154 |             `batch_size x num_features [x width]`
155 |         eps: a value added to the denominator for numerical stability.
156 |             Default: 1e-5
157 |         momentum: the value used for the running_mean and running_var
158 |             computation. Default: 0.1
159 |         affine: a boolean value that when set to ``True``, gives the layer learnable
160 |             affine parameters. Default: ``True``
161 |     Shape:
162 |         - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 |         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 |     Examples:
165 |         >>> # With Learnable Parameters
166 |         >>> m = SynchronizedBatchNorm1d(100)
167 |         >>> # Without Learnable Parameters
168 |         >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 |         >>> input = torch.autograd.Variable(torch.randn(20, 100))
170 |         >>> output = m(input)
171 |     """
172 | 
173 |     def _check_input_dim(self, input):
174 |         if input.dim() != 2 and input.dim() != 3:
175 |             raise ValueError('expected 2D or 3D input (got {}D input)'
176 |                              .format(input.dim()))
177 |         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 | 
179 | 
196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm. 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalizes the tensor on each device using 241 | the statistics only on that device, which accelerates the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm. 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability.
261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /models/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger a callback of each module, all slave devices should 59 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine the message to be passed 63 | back to each slave device. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /models/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 
34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | the original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /models/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
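# Usage sketch (illustration only; the imports below mirror where the classes
# are defined in this repository) showing how batchnorm.py, comm.py and
# replicate.py fit together:
#
#   import torch
#   import torch.nn as nn
#   from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
#   from models.sync_batchnorm.replicate import DataParallelWithCallback
#
#   net = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1),
#                       SynchronizedBatchNorm2d(16))
#   net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
#   # replicate() fires __data_parallel_replicate__ on every module copy,
#   # wiring each SynchronizedBatchNorm2d replica to a SyncMaster, so the
#   # forward pass below reduces batch statistics across both GPUs
#   out = net(torch.randn(8, 4, 64, 64).cuda())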
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | -------------------------------------------------------------------------------- /models/utils/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.utils.utils import init_weights 4 | 5 | 6 | class UnetConv2(nn.Module): 7 | def __init__(self, in_size, out_size, batchnorm=None, is_batchnorm=True, n=2, kernel_size=3, stride=1, padding=1): 8 | super(UnetConv2, self).__init__() 9 | self.n = n 10 | self.ks = kernel_size 11 | self.stride = stride 12 | self.padding = padding 13 | if is_batchnorm: 14 | for i in range(1, n+1): 15 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size, stride, padding), 16 | batchnorm(out_size), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_size, out_size, kernel_size, stride, padding), 19 | batchnorm(out_size), 20 | nn.ReLU(inplace=True) 21 | ) 22 | setattr(self, 'conv%d' % i, conv) 23 | in_size = out_size 24 | else: 25 | for i in range(1, n + 1): 26 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size, stride, padding), 27 | nn.ReLU(inplace=True)) 28 | setattr(self, 'conv%d' % i, conv) 29 | in_size = out_size 30 | 31 | # initialise the blocks 32 | for m in self.modules(): 33 | init_weights(m, init_type='kaiming') 34 | 35 | def forward(self, inputs): 36 | x = inputs 37 | for i in range(1, self.n+1): 38 | conv = getattr(self, 'conv%d' % i) 39 | x = conv(x) 40 | 41 | return x 42 | 43 | 44 | class UnetUp(nn.Module): 45 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 46 | super(UnetUp, self).__init__() 47 | self.conv = UnetConv2(in_size+(n_concat-2)*out_size, out_size, is_batchnorm=False) 48 | if is_deconv: 49 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 50 | else: 51 | self.up = nn.Sequential( 52 | nn.UpsamplingBilinear2d(scale_factor=2), 53 | nn.Conv2d(in_size, out_size, 1)) 54 | 55 | # initialise the blocks (UnetConv2 has already initialised itself, so skip it) 56 | for m in self.children(): 57 | if m.__class__.__name__.find('UnetConv2') != -1: 58 | continue 59 | init_weights(m, init_type='kaiming') 60 | 61 | def forward(self, high_feature, *low_feature): 62 | outputs0 = self.up(high_feature) 63 | for feature in low_feature: 64 | outputs0 = torch.cat([outputs0, feature], 1) 65 | return self.conv(outputs0) 66 | -------------------------------------------------------------------------------- /models/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 3 | 4 | 5 | # initialize the module 6 | def init_weights(net, init_type='normal'): 7 | # print('initialization method [%s]' % init_type) 8 | if init_type == 'kaiming': 9 |
net.apply(weights_init_kaiming) 10 | else: 11 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 12 | 13 | 14 | def weights_init_kaiming(m): 15 | if isinstance(m, nn.Conv2d): 16 | nn.init.kaiming_normal_(m.weight) 17 | elif isinstance(m, nn.ConvTranspose2d): 18 | nn.init.kaiming_normal_(m.weight) 19 | elif isinstance(m, nn.UpsamplingBilinear2d): 20 | pass  # bilinear upsampling has no learnable weights to initialise 21 | elif isinstance(m, SynchronizedBatchNorm2d): 22 | m.weight.data.fill_(1) 23 | m.bias.data.zero_() 24 | elif isinstance(m, nn.BatchNorm2d): 25 | # nn.init.normal_(m.weight.data, 1.0, 0.02) 26 | m.weight.data.fill_(1) 27 | m.bias.data.zero_() 28 | 29 | 30 | # compute model params 31 | def count_param(model): 32 | param_count = 0 33 | for param in model.parameters(): 34 | param_count += param.view(-1).size()[0] 35 | return param_count 36 | -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | class Path(object): 3 | @staticmethod 4 | def db_root_dir(dataset): 5 | if dataset == "rssrai2019": 6 | return '/home/lab/ygy/rssrai2019/datasets' # folder that contains rssrai2019 7 | else: 8 | print('Dataset {} not available.'.format(dataset)) 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | pillow 4 | py-opencv 5 | tensorboardX 6 | torch 7 | torchvision 8 | tqdm -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | """ 3 | @function: 4 | @author:HuiYi or 会意 5 | @file: train.py 6 | @time: 2019/6/23 7:00 p.m. 7 | """ 8 | import argparse 9 | import os 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | 14 | from mypath import Path 15 | from utils.saver import Saver 16 | from utils.summaries import TensorboardSummary 17 | from dataloaders import make_data_loader 18 | from models.backbone.UNet import UNet 19 | from models.backbone.UNetNested import UNetNested 20 | from utils.calculate_weights import calculate_weigths_labels 21 | from utils.loss import SegmentationLosses 22 | from utils.metrics import Evaluator 23 | # from utils.lr_scheduler import LR_Scheduler 24 | from models.sync_batchnorm.replicate import patch_replication_callback 25 | 26 | 27 | class Trainer(object): 28 | def __init__(self, args): 29 | self.args = args 30 | 31 | # Define Saver 32 | self.saver = Saver(args) 33 | self.saver.save_experiment_config() 34 | # Define Tensorboard Summary 35 | self.summary = TensorboardSummary(self.saver.experiment_dir) 36 | self.writer = self.summary.create_summary() 37 | 38 | # Define Dataloader 39 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 40 | self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs) 41 | 42 | model = None 43 | # Define network 44 | if self.args.backbone == 'unet': 45 | model = UNet(in_channels=4, n_classes=self.nclass, sync_bn=args.sync_bn) 46 | print("using UNet") 47 | if self.args.backbone == 'unetNested': 48 | model = UNetNested(in_channels=4, n_classes=self.nclass, sync_bn=args.sync_bn) 49 | print("using UNetNested") 50 | 51 | # train_params = [{'params': model.get_params(), 'lr': args.lr}] 52 |
train_params = [{'params': model.get_params()}] 53 | 54 | # Define Optimizer 55 | # optimizer = torch.optim.SGD(train_params, momentum=args.momentum, 56 | # weight_decay=args.weight_decay, nesterov=args.nesterov) 57 | optimizer = torch.optim.Adam(train_params, self.args.learn_rate, weight_decay=args.weight_decay, amsgrad=True) 58 | 59 | # Define Criterion 60 | # whether to use class balanced weights 61 | if args.use_balanced_weights: 62 | classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy') 63 | if os.path.isfile(classes_weights_path): 64 | weight = np.load(classes_weights_path) 65 | else: 66 | weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass) 67 | weight = torch.from_numpy(weight.astype(np.float32)) 68 | else: 69 | weight = None 70 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type) 71 | self.model, self.optimizer = model, optimizer 72 | 73 | # Define Evaluator 74 | self.evaluator = Evaluator(self.nclass) 75 | # Define lr scheduler 76 | # self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)) 77 | 78 | # Using cuda 79 | if args.cuda: 80 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 81 | patch_replication_callback(self.model) 82 | self.model = self.model.cuda() 83 | 84 | # Resuming checkpoint 85 | self.best_pred = 0.0 86 | if args.resume is not None: 87 | if not os.path.isfile(args.resume): 88 | raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) 89 | checkpoint = torch.load(args.resume) 90 | args.start_epoch = checkpoint['epoch'] 91 | if args.cuda: 92 | self.model.module.load_state_dict(checkpoint['state_dict']) 93 | else: 94 | self.model.load_state_dict(checkpoint['state_dict']) 95 | if not args.ft: 96 | self.optimizer.load_state_dict(checkpoint['optimizer']) 97 | self.best_pred = checkpoint['best_pred'] 98 | print("=> loaded checkpoint '{}' (epoch {})" 99 | .format(args.resume, checkpoint['epoch'])) 100 | 101 | # Clear start epoch if fine-tuning 102 | if args.ft: 103 | args.start_epoch = 0 104 | 105 | def training(self, epoch): 106 | print('[Epoch: %d, learning rate: %.6f, previous best = %.4f]' % (epoch, self.args.learn_rate, self.best_pred)) 107 | train_loss = 0.0 108 | self.model.train() 109 | self.evaluator.reset() 110 | tbar = tqdm(self.train_loader) 111 | num_img_tr = len(self.train_loader) 112 | 113 | for i, sample in enumerate(tbar): 114 | image, target = sample['image'], sample['label'] 115 | if self.args.cuda: 116 | image, target = image.cuda(), target.cuda() 117 | # self.scheduler(self.optimizer, i, epoch, self.best_pred) 118 | self.optimizer.zero_grad() 119 | output = self.model(image) 120 | loss = self.criterion(output, target) 121 | loss.backward() 122 | self.optimizer.step() 123 | train_loss += loss.item() 124 | tbar.set_description('Train loss: %.5f' % (train_loss / (i + 1))) 125 | self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) 126 | 127 | pred = output.data.cpu().numpy() 128 | target = target.cpu().numpy() 129 | pred = np.argmax(pred, axis=1) 130 | # Add batch sample into evaluator 131 | self.evaluator.add_batch(target, pred) 132 | 133 | # Fast test during the training 134 | Acc = self.evaluator.Pixel_Accuracy() 135 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 136 | mIoU = self.evaluator.Mean_Intersection_over_Union() 137 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 138 | 
self.writer.add_scalar('train/mIoU', mIoU, epoch) 139 | self.writer.add_scalar('train/Acc', Acc, epoch) 140 | self.writer.add_scalar('train/Acc_class', Acc_class, epoch) 141 | self.writer.add_scalar('train/fwIoU', FWIoU, epoch) 142 | self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) 143 | 144 | print('train validation:') 145 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 146 | print('Loss: %.3f' % train_loss) 147 | print('---------------------------------') 148 | 149 | def validation(self, epoch): 150 | test_loss = 0.0 151 | self.model.eval() 152 | self.evaluator.reset() 153 | tbar = tqdm(self.val_loader, desc='\r') 154 | num_img_val = len(self.val_loader) 155 | 156 | for i, sample in enumerate(tbar): 157 | image, target = sample['image'], sample['label'] 158 | if self.args.cuda: 159 | image, target = image.cuda(), target.cuda() 160 | with torch.no_grad(): 161 | output = self.model(image) 162 | loss = self.criterion(output, target) 163 | test_loss += loss.item() 164 | tbar.set_description('Test loss: %.5f' % (test_loss / (i + 1))) 165 | self.writer.add_scalar('val/total_loss_iter', loss.item(), i + num_img_val * epoch) 166 | pred = output.data.cpu().numpy() 167 | target = target.cpu().numpy() 168 | pred = np.argmax(pred, axis=1) 169 | # Add batch sample into evaluator 170 | self.evaluator.add_batch(target, pred) 171 | 172 | # Fast test during the training 173 | Acc = self.evaluator.Pixel_Accuracy() 174 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 175 | mIoU = self.evaluator.Mean_Intersection_over_Union() 176 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 177 | self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch) 178 | self.writer.add_scalar('val/mIoU', mIoU, epoch) 179 | self.writer.add_scalar('val/Acc', Acc, epoch) 180 | self.writer.add_scalar('val/Acc_class', Acc_class, epoch) 181 | self.writer.add_scalar('val/fwIoU', FWIoU, epoch) 182 | print('test validation:') 183 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 184 | print('Loss: %.3f' % test_loss) 185 | print('====================================') 186 | 187 | new_pred = mIoU 188 | if new_pred > self.best_pred: 189 | is_best = True 190 | self.best_pred = new_pred 191 | self.saver.save_checkpoint({ 192 | 'epoch': epoch + 1, 193 | 'state_dict': self.model.module.state_dict(), 194 | 'optimizer': self.optimizer.state_dict(), 195 | 'best_pred': self.best_pred, 196 | }, is_best) 197 | 198 | 199 | def main(): 200 | parser = argparse.ArgumentParser(description="PyTorch Unet Training") 201 | parser.add_argument('--backbone', type=str, default='unet', 202 | choices=['unet', 'unetNested'], 203 | help='backbone name (default: unet)') 204 | parser.add_argument('--dataset', type=str, default='rssrai2019', 205 | choices=['rssrai2019'], 206 | help='dataset name (default: rssrai2019)') 207 | parser.add_argument('--workers', type=int, default=4, 208 | metavar='N', help='dataloader threads') 209 | parser.add_argument('--base-size', type=int, default=400, 210 | help='base image size') 211 | parser.add_argument('--crop-size', type=int, default=400, 212 | help='crop image size') 213 | parser.add_argument('--sync-bn', type=bool, default=None, 214 | help='whether to use sync bn (default: auto)') 215 | parser.add_argument('--freeze-bn', type=bool, default=False, 216 | help='whether to freeze bn parameters (default: False)') 217 | parser.add_argument('--loss-type', type=str, default='ce', 218 | choices=['ce', 'focal'], 
219 | help='loss func type (default: ce)') 220 | # training hyper params 221 | parser.add_argument('--epochs', type=int, default=None, metavar='N', 222 | help='number of epochs to train (default: auto)') 223 | parser.add_argument('--start_epoch', type=int, default=0, metavar='N', 224 | help='start epochs (default:0)') 225 | parser.add_argument('--batch-size', type=int, default=None, metavar='N', 226 | help='input batch size for training (default: auto)') 227 | parser.add_argument('--test-batch-size', type=int, default=None, metavar='N', 228 | help='input batch size for testing (default: auto)') 229 | parser.add_argument('--use-balanced-weights', action='store_true', default=False, 230 | help='whether to use balanced weights (default: False)') 231 | # optimizer params 232 | parser.add_argument('--learn-rate', type=float, default=None, metavar='LR', 233 | help='learning rate (default: auto)') 234 | parser.add_argument('--lr-scheduler', type=str, default='poly', 235 | choices=['poly', 'step', 'cos'], 236 | help='lr scheduler mode: (default: poly)') 237 | parser.add_argument('--momentum', type=float, default=0.9, 238 | metavar='M', help='momentum (default: 0.9)') 239 | parser.add_argument('--weight-decay', type=float, default=5e-4, 240 | metavar='M', help='w-decay (default: 5e-4)') 241 | parser.add_argument('--nesterov', action='store_true', default=True, 242 | help='whether to use nesterov (default: True)') 243 | # cuda, seed and logging 244 | parser.add_argument('--no-cuda', action='store_true', default=False, 245 | help='disables CUDA training') 246 | parser.add_argument('--gpu-ids', type=str, default='0', 247 | help='use which gpu to train, must be a comma-separated list of integers only (default=0)') 248 | parser.add_argument('--seed', type=int, default=1, metavar='S', 249 | help='random seed (default: 1)') 250 | # checkpoint 251 | parser.add_argument('--resume', type=str, default=None, 252 | help='put the path to resuming file if needed') 253 | parser.add_argument('--checkname', type=str, default=None, 254 | help='set the checkpoint name') 255 | # finetuning pre-trained models 256 | parser.add_argument('--ft', action='store_true', default=False, 257 | help='finetuning on a different dataset') 258 | # evaluation option 259 | parser.add_argument('--eval-interval', type=int, default=1, 260 | help='evaluation interval (default: 1)') 261 | parser.add_argument('--no-val', action='store_true', default=False, 262 | help='skip validation during training') 263 | 264 | args = parser.parse_args() 265 | args.cuda = not args.no_cuda and torch.cuda.is_available() 266 | if args.cuda: 267 | try: 268 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] 269 | except ValueError: 270 | raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') 271 | 272 | if args.sync_bn is None: 273 | if args.cuda and len(args.gpu_ids) > 1: 274 | args.sync_bn = True 275 | else: 276 | args.sync_bn = False 277 | 278 | # default settings for epochs, batch_size and lr 279 | if args.epochs is None: 280 | epoches = {'rssrai2019': 100} 281 | args.epochs = epoches[args.dataset.lower()] 282 | 283 | if args.batch_size is None: 284 | args.batch_size = 4 * len(args.gpu_ids) 285 | 286 | if args.test_batch_size is None: 287 | args.test_batch_size = args.batch_size 288 | 289 | if args.learn_rate is None: 290 | lrs = {'rssrai2019': 0.01} 291 | args.learn_rate = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size 292 | 293 | if args.checkname is None: 294 | args.checkname =
str(args.backbone) 295 | 296 | print(args) 297 | torch.manual_seed(args.seed) 298 | trainer = Trainer(args) 299 | print('Starting Epoch:', trainer.args.start_epoch) 300 | print('Total Epoches:', trainer.args.epochs) 301 | print('====================================') 302 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 303 | trainer.training(epoch) 304 | if epoch % args.eval_interval == (args.eval_interval - 1): 305 | trainer.validation(epoch) 306 | 307 | trainer.writer.close() 308 | 309 | 310 | if __name__ == "__main__": 311 | main() 312 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # experiment 0 4 | # CUDA_VISIBLE_DEVICES=0,1 python train.py --lr 0.01 --weight-decay 0.001 --epochs 200 --batch-size 32 --test-batch-size 32 --base-size 400 --crop-size 400 --gpu-ids 0,1 --checkname unet --eval-interval 1 --dataset rssrai2019 5 | 6 | # experiment 1 using Adam no weight-decay 7 | # CUDA_VISIBLE_DEVICES=0,1 python train.py --learn-rate 0.001 --weight-decay 0 --epochs 1000 --batch-size 32 --test-batch-size 32 --base-size 400 --crop-size 400 --gpu-ids 0,1 --checkname unet --eval-interval 1 --dataset rssrai2019 8 | 9 | # experiment 2: add one more conv layer to every encoder block, add dropout=0.5 after the last downsampling layer, plus RandomGammaTransform, RandomBilateralFilter and RandomNoise 10 | # CUDA_VISIBLE_DEVICES=0,1 python train.py --learn-rate 0.001 --weight-decay 0 --epochs 1000 --batch-size 20 --test-batch-size 20 --base-size 400 --crop-size 400 --gpu-ids 0,1 --checkname unet --eval-interval 1 --dataset rssrai2019 11 | 12 | # experiment 3: on top of experiment 2, add weight-decay and amsgrad; divide the learning rate by 10 at 50% and 80% of the total epochs 13 | # CUDA_VISIBLE_DEVICES=0,1 python train.py --learn-rate 0.001 --weight-decay 1e-4 --epochs 1000 --batch-size 20 --test-batch-size 20 --base-size 400 --crop-size 400 --gpu-ids 0,1 --checkname unet --eval-interval 1 --dataset rssrai2019 14 | 15 | # experiment 4: the L2 coefficient in experiment 3 was too large and hurt accuracy, so shrink it by another factor of 10; change each LR update to 0.3x and increase the number of epochs 16 | # CUDA_VISIBLE_DEVICES=0,1 python train.py --learn-rate 0.001 --weight-decay 1e-5 --epochs 1500 --batch-size 20 --test-batch-size 20 --base-size 400 --crop-size 400 --gpu-ids 0,1 --checkname unet --eval-interval 1 --dataset rssrai2019 17 | 18 | # experiment 5: train with the UNetNested network 19 | CUDA_VISIBLE_DEVICES=0,1 python train.py --learn-rate 0.001 --weight-decay 0 --epochs 1000 --batch-size 12 --test-batch-size 12 --base-size 400 --crop-size 400 --gpu-ids 0,1 --backbone unetNested --checkname unetNested --eval-interval 1 --dataset rssrai2019 20 | 21 | -------------------------------------------------------------------------------- /train_combine_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | """ 3 | @function: 4 | @author:HuiYi or 会意 5 | @file: train_combine_net.py 6 | @time: 2019/8/6 3:20 p.m. 7 | """ 8 | import argparse 9 | import os 10 | import numpy as np 11 | import torch 12 | import cv2 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | from utils.saver import Saver 17 | from utils.summaries import TensorboardSummary 18 | from models.backbone.UNet import UNet 19 | from models.backbone.UNetNested import UNetNested 20 | from models.CombineNet import CombineNet 21 | from utils.loss import SegmentationLosses 22 | from utils.metrics import Evaluator 23 | 24 | 25 | class Trainer(object): 26 | def __init__(self, args): 27 | self.args = args 28 | 29 | # Define Saver 30 | self.saver =
Saver(args) 31 | self.saver.save_experiment_config() 32 | # Define Tensorboard Summary 33 | self.summary = TensorboardSummary(self.saver.experiment_dir) 34 | self.writer = self.summary.create_summary() 35 | 36 | self.nclass = 16 37 | # Define network 38 | self.unet_model = UNet(in_channels=4, n_classes=self.nclass) 39 | self.unetNested_model = UNetNested(in_channels=4, n_classes=self.nclass) 40 | self.combine_net_model = CombineNet(in_channels=192, n_classes=self.nclass)  # 2 backbones x 6 TTA views x 16 classes = 192 channels 41 | 42 | train_params = [{'params': self.combine_net_model.get_params()}] 43 | # Define Optimizer 44 | self.optimizer = torch.optim.Adam(train_params, self.args.learn_rate, weight_decay=args.weight_decay, amsgrad=True) 45 | 46 | self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type) 47 | 48 | # Define Evaluator 49 | self.evaluator = Evaluator(self.nclass) 50 | 51 | # Using cuda 52 | if args.cuda: 53 | self.unet_model = self.unet_model.cuda() 54 | self.unetNested_model = self.unetNested_model.cuda() 55 | self.combine_net_model = self.combine_net_model.cuda() 56 | 57 | # Load Unet checkpoint 58 | if not os.path.isfile(args.unet_checkpoint_file): 59 | raise RuntimeError("=> no Unet checkpoint found at '{}'".format(args.unet_checkpoint_file)) 60 | checkpoint = torch.load(args.unet_checkpoint_file) 61 | self.unet_model.load_state_dict(checkpoint['state_dict']) 62 | print("=> loaded Unet checkpoint '{}'".format(args.unet_checkpoint_file)) 63 | 64 | # Load UNetNested checkpoint 65 | if not os.path.isfile(args.unetNested_checkpoint_file): 66 | raise RuntimeError("=> no UNetNested checkpoint found at '{}'".format(args.unetNested_checkpoint_file)) 67 | checkpoint = torch.load(args.unetNested_checkpoint_file) 68 | self.unetNested_model.load_state_dict(checkpoint['state_dict']) 69 | print("=> loaded UNetNested checkpoint '{}'".format(args.unetNested_checkpoint_file)) 70 | 71 | # Resuming combineNet checkpoint 72 | self.best_pred = 0.0 73 | if args.resume is not None: 74 | if not os.path.isfile(args.resume): 75 | raise RuntimeError("=> no combineNet checkpoint found at '{}'" .format(args.resume)) 76 | checkpoint = torch.load(args.resume) 77 | args.start_epoch = checkpoint['epoch'] 78 | # the CombineNet is never wrapped in DataParallel in this script, 79 | # so the state dict is loaded directly (there is no .module 80 | # attribute), regardless of whether CUDA is enabled 81 | self.combine_net_model.load_state_dict(checkpoint['state_dict']) 82 | if not args.ft: 83 | self.optimizer.load_state_dict(checkpoint['optimizer']) 84 | self.best_pred = checkpoint['best_pred'] 85 | print("=> loaded combineNet checkpoint '{}' (epoch {})" 86 | .format(args.resume, checkpoint['epoch'])) 87 | 88 | # Clear start epoch if fine-tuning 89 | if args.ft: 90 | args.start_epoch = 0 91 | 92 | def training(self, epoch): 93 | print('[Epoch: %d, previous best = %.4f]' % (epoch, self.best_pred)) 94 | train_loss = 0.0 95 | self.combine_net_model.train() 96 | self.evaluator.reset() 97 | num_img_tr = len(train_files) 98 | tbar = tqdm(train_files, desc='\r') 99 | 100 | for i, filename in enumerate(tbar): 101 | image = Image.open(os.path.join(train_dir, filename)) 102 | label = Image.open(os.path.join(train_label_dir, os.path.basename(filename)[:-4] + '_labelTrainIds.png')) 103 | label = np.array(label).astype(np.float32) 104 | label = label.reshape((1, 400, 400)) 105 | label = torch.from_numpy(label).float() 106 | label = label.cuda() 107 | 108 | # UNet_multi_scale_predict 109 | unt_pred = self.unet_multi_scale_predict(image) 110 | 111 | # UNetNested_multi_scale_predict 112 | unetnested_pred =
self.unetnested_multi_scale_predict(image) 113 | 114 | net_input = torch.cat([unt_pred, unetnested_pred], 1) 115 | 116 | self.optimizer.zero_grad() 117 | output = self.combine_net_model(net_input) 118 | loss = self.criterion(output, label) 119 | loss.backward() 120 | self.optimizer.step() 121 | train_loss += loss.item() 122 | tbar.set_description('Train loss: %.5f' % (train_loss / (i + 1))) 123 | self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) 124 | 125 | pred = output.data.cpu().numpy() 126 | label = label.cpu().numpy() 127 | pred = np.argmax(pred, axis=1) 128 | # Add batch sample into evaluator 129 | self.evaluator.add_batch(label, pred) 130 | 131 | # Fast test during the training 132 | Acc = self.evaluator.Pixel_Accuracy() 133 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 134 | mIoU = self.evaluator.Mean_Intersection_over_Union() 135 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 136 | self.writer.add_scalar('train/mIoU', mIoU, epoch) 137 | self.writer.add_scalar('train/Acc', Acc, epoch) 138 | self.writer.add_scalar('train/Acc_class', Acc_class, epoch) 139 | self.writer.add_scalar('train/fwIoU', FWIoU, epoch) 140 | self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) 141 | 142 | print('train validation:') 143 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 144 | print('Loss: %.3f' % train_loss) 145 | print('---------------------------------') 146 | 147 | def validation(self, epoch): 148 | test_loss = 0.0 149 | self.combine_net_model.eval() 150 | self.evaluator.reset() 151 | tbar = tqdm(val_files, desc='\r') 152 | num_img_val = len(val_files) 153 | 154 | for i, filename in enumerate(tbar): 155 | image = Image.open(os.path.join(val_dir, filename)) 156 | label = Image.open(os.path.join(val_label_dir, os.path.basename(filename)[:-4] + '_labelTrainIds.png')) 157 | label = np.array(label).astype(np.float32) 158 | label = label.reshape((1, 400, 400)) 159 | label = torch.from_numpy(label).float() 160 | label = label.cuda() 161 | 162 | # UNet_multi_scale_predict 163 | unt_pred = self.unet_multi_scale_predict(image) 164 | 165 | # UNetNested_multi_scale_predict 166 | unetnested_pred = self.unetnested_multi_scale_predict(image) 167 | 168 | net_input = torch.cat([unt_pred, unetnested_pred], 1) 169 | 170 | with torch.no_grad(): 171 | output = self.combine_net_model(net_input) 172 | loss = self.criterion(output, label) 173 | test_loss += loss.item() 174 | tbar.set_description('Test loss: %.5f' % (test_loss / (i + 1))) 175 | self.writer.add_scalar('val/total_loss_iter', loss.item(), i + num_img_val * epoch) 176 | pred = output.data.cpu().numpy() 177 | label = label.cpu().numpy() 178 | pred = np.argmax(pred, axis=1) 179 | # Add batch sample into evaluator 180 | self.evaluator.add_batch(label, pred) 181 | 182 | # Fast test during the training 183 | Acc = self.evaluator.Pixel_Accuracy() 184 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 185 | mIoU = self.evaluator.Mean_Intersection_over_Union() 186 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 187 | self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch) 188 | self.writer.add_scalar('val/mIoU', mIoU, epoch) 189 | self.writer.add_scalar('val/Acc', Acc, epoch) 190 | self.writer.add_scalar('val/Acc_class', Acc_class, epoch) 191 | self.writer.add_scalar('val/fwIoU', FWIoU, epoch) 192 | print('test validation:') 193 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 
194 | print('Loss: %.3f' % test_loss) 195 | print('====================================') 196 | 197 | new_pred = mIoU 198 | if new_pred > self.best_pred: 199 | is_best = True 200 | self.best_pred = new_pred 201 | self.saver.save_checkpoint({ 202 | 'epoch': epoch + 1, 203 | 'state_dict': self.combine_net_model.state_dict(), 204 | 'optimizer': self.optimizer.state_dict(), 205 | 'best_pred': self.best_pred, 206 | }, is_best) 207 | 208 | def unet_multi_scale_predict(self, image_ori: Image): 209 | self.unet_model.eval() 210 | 211 | # predict the original image 212 | sample_ori = image_ori.copy() 213 | output_ori = self.unet_predict(sample_ori) 214 | 215 | # predict three rotated copies 216 | angle_list = [90, 180, 270] 217 | for angle in angle_list: 218 | img_rotate = image_ori.rotate(angle, Image.BILINEAR) 219 | output = self.unet_predict(img_rotate) 220 | pred = output.data.cpu().numpy()[0] 221 | pred = pred.transpose((1, 2, 0)) 222 | m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1) 223 | pred = cv2.warpAffine(pred, m_rotate, (400, 400)) 224 | pred = pred.transpose((2, 0, 1)) 225 | output = torch.from_numpy(np.array([pred, ])).float() 226 | output_ori = torch.cat([output_ori, output.cuda()], 1) 227 | 228 | # predict the vertically flipped image 229 | img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM) 230 | output = self.unet_predict(img_flip) 231 | pred = output.data.cpu().numpy()[0] 232 | pred = pred.transpose((1, 2, 0)) 233 | pred = cv2.flip(pred, 0) 234 | pred = pred.transpose((2, 0, 1)) 235 | output = torch.from_numpy(np.array([pred, ])).float() 236 | output_ori = torch.cat([output_ori, output.cuda()], 1) 237 | 238 | # predict the horizontally flipped image 239 | img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT) 240 | output = self.unet_predict(img_flip) 241 | pred = output.data.cpu().numpy()[0] 242 | pred = pred.transpose((1, 2, 0)) 243 | pred = cv2.flip(pred, 1) 244 | pred = pred.transpose((2, 0, 1)) 245 | output = torch.from_numpy(np.array([pred, ])).float() 246 | output_ori = torch.cat([output_ori, output.cuda()], 1) 247 | 248 | return output_ori 249 | 250 | def unet_predict(self, img: Image) -> torch.Tensor: 251 | img = self.transform_test(img) 252 | if self.args.cuda: 253 | img = img.cuda() 254 | with torch.no_grad(): 255 | output = self.unet_model(img) 256 | return output 257 | 258 | def unetnested_predict(self, img: Image) -> torch.Tensor: 259 | img = self.transform_test(img) 260 | if self.args.cuda: 261 | img = img.cuda() 262 | with torch.no_grad(): 263 | output = self.unetNested_model(img) 264 | return output 265 | 266 | def unetnested_multi_scale_predict(self, image_ori: Image): 267 | self.unetNested_model.eval() 268 | 269 | # predict the original image 270 | sample_ori = image_ori.copy() 271 | output_ori = self.unetnested_predict(sample_ori) 272 | 273 | # predict three rotated copies 274 | angle_list = [90, 180, 270] 275 | for angle in angle_list: 276 | img_rotate = image_ori.rotate(angle, Image.BILINEAR) 277 | output = self.unetnested_predict(img_rotate) 278 | pred = output.data.cpu().numpy()[0] 279 | pred = pred.transpose((1, 2, 0)) 280 | m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1) 281 | pred = cv2.warpAffine(pred, m_rotate, (400, 400)) 282 | pred = pred.transpose((2, 0, 1)) 283 | output = torch.from_numpy(np.array([pred, ])).float() 284 | output_ori = torch.cat([output_ori, output.cuda()], 1) 285 | 286 | # predict the vertically flipped image 287 | img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM) 288 | output = self.unetnested_predict(img_flip) 289 | pred = output.data.cpu().numpy()[0] 290 | pred = pred.transpose((1, 2, 0)) 291 | pred = cv2.flip(pred, 0) 292 | pred = pred.transpose((2, 0, 1)) 293 |
output = torch.from_numpy(np.array([pred, ])).float() 294 | output_ori = torch.cat([output_ori, output.cuda()], 1) 295 | 296 | # predict the horizontally flipped image 297 | img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT) 298 | output = self.unetnested_predict(img_flip) 299 | pred = output.data.cpu().numpy()[0] 300 | pred = pred.transpose((1, 2, 0)) 301 | pred = cv2.flip(pred, 1) 302 | pred = pred.transpose((2, 0, 1)) 303 | output = torch.from_numpy(np.array([pred, ])).float() 304 | output_ori = torch.cat([output_ori, output.cuda()], 1) 305 | 306 | return output_ori 307 | 308 | @staticmethod 309 | def transform_test(img): 310 | # Normalize 311 | mean = (0.544650, 0.352033, 0.384602, 0.352311) 312 | std = (0.249456, 0.241652, 0.228824, 0.227583) 313 | img = np.array(img).astype(np.float32) 314 | img /= 255.0 315 | img -= mean 316 | img /= std 317 | # ToTensor 318 | img = img.transpose((2, 0, 1)) 319 | img = np.array([img, ]) 320 | img = torch.from_numpy(img).float() 321 | return img 322 | 323 | 324 | def main(): 325 | parser = argparse.ArgumentParser(description="PyTorch CombineNet Training") 326 | parser.add_argument('--backbone', type=str, default='combine_net', 327 | choices=['combine_net'], 328 | help='backbone name (default: combine_net)') 329 | parser.add_argument('--dataset', type=str, default='rssrai2019', 330 | choices=['rssrai2019'], 331 | help='dataset name (default: rssrai2019)') 332 | parser.add_argument('--workers', type=int, default=2, 333 | metavar='N', help='dataloader threads') 334 | parser.add_argument('--base-size', type=int, default=400, 335 | help='base image size') 336 | parser.add_argument('--crop-size', type=int, default=400, 337 | help='crop image size') 338 | parser.add_argument('--loss-type', type=str, default='ce', 339 | choices=['ce', 'focal'], 340 | help='loss func type (default: ce)') 341 | # training hyper params 342 | parser.add_argument('--epochs', type=int, default=None, metavar='N', 343 | help='number of epochs to train (default: auto)') 344 | parser.add_argument('--start_epoch', type=int, default=0, metavar='N', 345 | help='start epochs (default:0)') 346 | 347 | # optimizer params 348 | parser.add_argument('--learn-rate', type=float, default=None, metavar='LR', 349 | help='learning rate (default: auto)') 350 | parser.add_argument('--weight-decay', type=float, default=5e-4, 351 | metavar='M', help='w-decay (default: 5e-4)') 352 | # cuda, seed and logging 353 | parser.add_argument('--no-cuda', action='store_true', default=False, 354 | help='disables CUDA training') 355 | parser.add_argument('--gpu-ids', type=str, default='0', 356 | help='use which gpu to train, must be a comma-separated list of integers only (default=0)') 357 | parser.add_argument('--seed', type=int, default=1, metavar='S', 358 | help='random seed (default: 1)') 359 | # checkpoint 360 | parser.add_argument('--unet_checkpoint_file', type=str, default=None, 361 | help='put the path to Unet checkpoint file') 362 | parser.add_argument('--unetNested_checkpoint_file', type=str, default=None, 363 | help='put the path to UNetNested checkpoint file') 364 | parser.add_argument('--resume', type=str, default=None, 365 | help='put the path to combineNet resuming file if needed') 366 | parser.add_argument('--checkname', type=str, default=None, 367 | help='set the checkpoint name') 368 | # finetuning pre-trained models 369 | parser.add_argument('--ft', action='store_true', default=False, 370 | help='finetuning on a different dataset') 371 | # evaluation option 372 | parser.add_argument('--eval-interval', type=int, default=1,
373 | help='evaluation interval (default: 1)') 374 | parser.add_argument('--no-val', action='store_true', default=False, 375 | help='skip validation during training') 376 | 377 | args = parser.parse_args() 378 | args.cuda = not args.no_cuda and torch.cuda.is_available() 379 | if args.cuda: 380 | try: 381 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] 382 | except ValueError: 383 | raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') 384 | 385 | # default settings for epochs and lr 386 | if args.epochs is None: 387 | epoches = {'rssrai2019': 100} 388 | args.epochs = epoches[args.dataset.lower()] 389 | 390 | if args.learn_rate is None: 391 | lrs = {'rssrai2019': 0.001} 392 | args.learn_rate = lrs[args.dataset.lower()]  # this script defines no --batch-size argument, so the lr is not scaled by batch size 393 | 394 | print(args) 395 | torch.manual_seed(args.seed) 396 | trainer = Trainer(args) 397 | print('Starting Epoch:', trainer.args.start_epoch) 398 | print('Total Epoches:', trainer.args.epochs) 399 | print('====================================') 400 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 401 | trainer.training(epoch) 402 | if epoch % args.eval_interval == (args.eval_interval - 1): 403 | trainer.validation(epoch) 404 | 405 | trainer.writer.close() 406 | 407 | 408 | if __name__ == "__main__": 409 | train_dir = r'/home/lab/ygy/rssrai2019/datasets/image/train_mix' 410 | train_label_dir = r'/home/lab/ygy/rssrai2019/datasets/label/train_mix_id_image' 411 | val_dir = r'/home/lab/ygy/rssrai2019/datasets/image/val_mix' 412 | val_label_dir = r'/home/lab/ygy/rssrai2019/datasets/label/val_mix_id_image' 413 | 414 | train_files = os.listdir(train_dir) 415 | val_files = os.listdir(val_dir) 416 | main() 417 | -------------------------------------------------------------------------------- /train_combine_net.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python train_combine_net.py --learn-rate 0.001 --weight-decay 0 --epochs 100 --base-size 400 --crop-size 400 --gpu-ids 0 --checkname combine_net --eval-interval 1 --dataset rssrai2019 --unet_checkpoint_file /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_2/checkpoint.pth.tar --unetNested_checkpoint_file /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unetNested/experiment_0/checkpoint.pth.tar -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | -------------------------------------------------------------------------------- /utils/calculate_weights.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from tqdm import tqdm 4 | import numpy as np 5 | from mypath import Path 6 | 7 | 8 | def calculate_weigths_labels(dataset, dataloader, num_classes): 9 | # Create an instance from the data loader 10 | z = np.zeros((num_classes,)) 11 | # Initialize tqdm 12 | tqdm_batch = tqdm(dataloader) 13 | print('Calculating classes weights') 14 | for sample in tqdm_batch: 15 | y = sample['label'] 16 | y = y.detach().cpu().numpy() 17 | mask = (y >= 0) & (y < num_classes) 18 | labels = y[mask].astype(np.uint8) 19 | count_l = np.bincount(labels, minlength=num_classes) 20 | z += count_l 21 | tqdm_batch.close() 22 | total_frequency = np.sum(z) 23 | class_weights = []
24 | for frequency in z: 25 | class_weight = 1 / (np.log(1.02 + (frequency / total_frequency))) 26 | class_weights.append(class_weight) 27 | ret = np.array(class_weights) 28 | classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy') 29 | np.save(classes_weights_path, ret) 30 | 31 | return ret 32 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SegmentationLosses(object): 7 | def __init__(self, weight=None, batch_average=True, ignore_index=255, cuda=False): 8 | self.ignore_index = ignore_index 9 | self.weight = weight 10 | self.batch_average = batch_average 11 | self.cuda = cuda 12 | 13 | def build_loss(self, mode='ce'): 14 | """Choices: ['ce' or 'focal']""" 15 | if mode == 'ce': 16 | return self.cross_entropy_loss 17 | elif mode == 'focal': 18 | return self.focal_loss 19 | else: 20 | raise NotImplementedError 21 | 22 | def cross_entropy_loss(self, logit, target): 23 | n, c, h, w = logit.size() 24 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 25 | reduction='mean') 26 | if self.cuda: 27 | criterion = criterion.cuda() 28 | 29 | loss = criterion(logit, target.long()) 30 | 31 | if self.batch_average: 32 | loss /= n 33 | 34 | return loss 35 | 36 | def focal_loss(self, logit, target, gamma=2, alpha=0.5): 37 | n, c, h, w = logit.size() 38 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 39 | reduction='mean') 40 | if self.cuda: 41 | criterion = criterion.cuda() 42 | 43 | logpt = -criterion(logit, target.long()) 44 | pt = torch.exp(logpt) 45 | if alpha is not None: 46 | logpt *= alpha 47 | loss = -((1 - pt) ** gamma) * logpt 48 | 49 | if self.batch_average: 50 | loss /= n 51 | 52 | return loss 53 | 54 | 55 | if __name__ == "__main__": 56 | loss = SegmentationLosses(cuda=True) 57 | a = torch.rand(1, 3, 7, 7).cuda() 58 | b = torch.rand(1, 7, 7).cuda() 59 | print(loss.cross_entropy_loss(a, b).item()) 60 | print(loss.focal_loss(a, b, gamma=0, alpha=None).item()) 61 | print(loss.focal_loss(a, b, gamma=2, alpha=0.5).item()) 62 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 | # Created by: Hang Zhang 5 | # ECE Department, Rutgers University 6 | # Email: zhang.hang@rutgers.edu 7 | # Copyright (c) 2017 8 | # 9 | # This source code is licensed under the MIT-style license found in the 10 | # LICENSE file in the root directory of this source tree 11 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 12 | 13 | import math 14 | 15 | 16 | class LR_Scheduler(object): 17 | """Learning Rate Scheduler 18 | 19 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 20 | 21 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter * pi))`` 22 | 23 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 24 | 25 | Args: 26 | args: 27 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 28 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 29 | :attr:`args.lr_step` 30 | 31 | iters_per_epoch: number of iterations per epoch 32 | """ 33 | def __init__(self,
mode, base_lr, num_epochs, iters_per_epoch=0, 34 | lr_step=0, warmup_epochs=0): 35 | self.mode = mode 36 | print('Using {} LR Scheduler!'.format(self.mode)) 37 | self.lr = base_lr 38 | if mode == 'step': 39 | assert lr_step 40 | self.lr_step = lr_step 41 | self.iters_per_epoch = iters_per_epoch 42 | self.N = num_epochs * iters_per_epoch 43 | self.epoch = -1 44 | self.warmup_iters = warmup_epochs * iters_per_epoch 45 | 46 | def __call__(self, optimizer, i, epoch, best_pred): 47 | T = epoch * self.iters_per_epoch + i 48 | if self.mode == 'cos': 49 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 50 | elif self.mode == 'poly': 51 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 52 | elif self.mode == 'step': 53 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 54 | else: 55 | raise NotImplementedError 56 | # warm up lr schedule 57 | if self.warmup_iters > 0 and T < self.warmup_iters: 58 | lr = lr * 1.0 * T / self.warmup_iters 59 | if epoch > self.epoch: 60 | print('\n=>Epoches %i, learning rate = %.7f, previous best = %.4f' % (epoch, lr, best_pred)) 61 | self.epoch = epoch 62 | assert lr >= 0 63 | self._adjust_learning_rate(optimizer, lr) 64 | 65 | @staticmethod 66 | def _adjust_learning_rate(optimizer, lr): 67 | if len(optimizer.param_groups) == 1: 68 | optimizer.param_groups[0]['lr'] = lr 69 | else: 70 | # enlarge the lr at the head 71 | optimizer.param_groups[0]['lr'] = lr 72 | for i in range(1, len(optimizer.param_groups)): 73 | optimizer.param_groups[i]['lr'] = lr * 10 74 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | 5 | class Evaluator(object): 6 | def __init__(self, num_class): 7 | self.num_class = num_class 8 | self.confusion_matrix = np.zeros((self.num_class,)*2) 9 | 10 | def Pixel_Accuracy(self): 11 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 12 | return Acc 13 | 14 | def Pixel_Accuracy_Class(self): 15 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 16 | Acc = np.nanmean(Acc) 17 | return Acc 18 | 19 | def Mean_Intersection_over_Union(self): 20 | MIoU = np.diag(self.confusion_matrix) / ( 21 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 22 | np.diag(self.confusion_matrix)) 23 | MIoU = np.nanmean(MIoU) 24 | return MIoU 25 | 26 | def Frequency_Weighted_Intersection_over_Union(self): 27 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 28 | iu = np.diag(self.confusion_matrix) / ( 29 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 30 | np.diag(self.confusion_matrix)) 31 | 32 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 33 | return FWIoU 34 | 35 | def _generate_matrix(self, gt_image, pre_image): 36 | mask = (gt_image >= 0) & (gt_image < self.num_class) 37 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 38 | count = np.bincount(label, minlength=self.num_class**2) 39 | confusion_matrix = count.reshape(self.num_class, self.num_class) 40 | return confusion_matrix 41 | 42 | def add_batch(self, gt_image, pre_image): 43 | assert gt_image.shape == pre_image.shape 44 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 45 | 46 | def reset(self): 47 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 48 | --------------------------------------------------------------------------------
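The Evaluator above accumulates its confusion matrix without any Python-level loop: `_generate_matrix` encodes each (ground truth, prediction) pair as `num_class * gt + pred`, so a single `np.bincount` counts every cell at once. A minimal, self-contained sketch of that trick and of the mIoU it feeds (toy labels chosen for illustration, not competition data):

import numpy as np

num_class = 3
gt = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([0, 1, 1, 1, 2, 0])

# encode each (gt, pred) pair as a unique integer in [0, num_class**2),
# then one bincount yields the full confusion matrix
label = num_class * gt + pred
confusion = np.bincount(label, minlength=num_class ** 2).reshape(num_class, num_class)

# per-class IoU = diag / (row_sum + col_sum - diag); nanmean skips absent classes
diag = np.diag(confusion)
iou = diag / (confusion.sum(axis=1) + confusion.sum(axis=0) - diag)
print(confusion)        # [[1 1 0], [0 2 0], [1 0 1]]
print(np.nanmean(iou))  # 0.5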
/utils/save_model_and_params.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function: Load separately saved model weights and re-save them together with the model structure
4 | @author:HuiYi or 会意
5 | @file: save_model_and_params.py
6 | @time: 2019/7/30 7:00 PM
7 | """
8 | import torch
9 | from models.backbone.UNet import UNet
10 | 
11 | model_path_list = [
12 |     '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/checkpoint.pth.tar',
13 |     '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_1/checkpoint.pth.tar',
14 |     '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_2/checkpoint.pth.tar'
15 | ]
16 | 
17 | if __name__ == '__main__':
18 |     model = UNet(in_channels=4, n_classes=16, sync_bn=False)
19 |     model = model.cuda()
20 |     param = '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/checkpoint.pth.tar'
21 |     checkpoint = torch.load(param)
22 |     model.load_state_dict(checkpoint['state_dict'])
23 |     torch.save(model, '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/model_and_param.pth.tar')
24 |     print('save finished')
25 | 
26 |     # load
27 |     # model = torch.load('/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_1/model_and_param.pth.tar')
28 |     # params = model.state_dict()
29 |     # print('load')
30 | 
--------------------------------------------------------------------------------
/utils/saver.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import shutil
4 | import torch
5 | from collections import OrderedDict
6 | import glob
7 | 
8 | 
9 | class Saver(object):
10 |     def __init__(self, args):
11 |         self.args = args
12 |         self.directory = os.path.join('run', args.dataset, args.checkname)
13 |         self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))
14 |         run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0
15 | 
16 |         self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
17 |         if not os.path.exists(self.experiment_dir):
18 |             os.makedirs(self.experiment_dir)
19 | 
20 |     def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
21 |         """Saves checkpoint to disk"""
22 |         filename = os.path.join(self.experiment_dir, filename)
23 |         torch.save(state, filename)
24 |         if is_best:
25 |             best_pred = state['best_pred']
26 |             with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f:
27 |                 f.write(str(best_pred))
28 |             if self.runs:
29 |                 previous_miou = [0.0]
30 |                 for run in self.runs:
31 |                     run_id = run.split('_')[-1]
32 |                     path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt')
33 |                     if os.path.exists(path):
34 |                         with open(path, 'r') as f:
35 |                             miou = float(f.readline())
36 |                         previous_miou.append(miou)
37 |                     else:
38 |                         continue
39 |                 max_miou = max(previous_miou)
40 |                 if best_pred > max_miou:
41 |                     shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))
42 |             else:
43 |                 shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))
44 | 
45 |     def save_experiment_config(self):
46 |         logfile = os.path.join(self.experiment_dir, 'parameters.txt')
47 |         log_file = open(logfile, 'w')
48 |         p = OrderedDict()
49 |         p['dataset'] = self.args.dataset
50 |         p['backbone'] = self.args.backbone
51 |         p['learn_rate'] = self.args.learn_rate
52 |         # p['lr_scheduler'] = self.args.lr_scheduler
53 |         p['weight_decay'] = self.args.weight_decay
54 |         p['loss_type'] = self.args.loss_type
55 |         p['epoch'] = self.args.epochs
56 |         p['base_size'] = self.args.base_size
57 |         p['crop_size'] = self.args.crop_size
58 | 
59 |         for key, val in p.items():
60 |             log_file.write(key + ':' + str(val) + '\n')
61 |         log_file.close()
62 | 
--------------------------------------------------------------------------------
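A minimal sketch of how `Saver` is driven from a training script, assuming an argparse-style `args` carrying exactly the fields read above (all values hypothetical, and a one-layer stand-in network instead of the real UNet):

import argparse
import torch
from utils.saver import Saver

args = argparse.Namespace(dataset='rssrai2019', checkname='unet', backbone='unet',
                          learn_rate=0.001, weight_decay=0.0, loss_type='ce',
                          epochs=1000, base_size=400, crop_size=400)
model = torch.nn.Conv2d(4, 16, kernel_size=1)    # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=args.learn_rate)

saver = Saver(args)                  # allocates run/rssrai2019/unet/experiment_N
saver.save_experiment_config()       # writes parameters.txt for this run
saver.save_checkpoint({'epoch': 1,
                       'state_dict': model.state_dict(),
                       'optimizer': optimizer.state_dict(),
                       'best_pred': 0.0}, is_best=True)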
/utils/summaries.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import torch
4 | from torchvision.utils import make_grid
5 | from tensorboardX import SummaryWriter
6 | from dataloaders.utils import decode_seg_map_sequence
7 | 
8 | 
9 | class TensorboardSummary(object):
10 |     def __init__(self, directory):
11 |         self.directory = directory
12 | 
13 |     def create_summary(self):
14 |         writer = SummaryWriter(log_dir=os.path.join(self.directory))
15 |         return writer
16 | 
17 |     @staticmethod
18 |     def visualize_image(writer, dataset, image, target, output, global_step):
19 |         # grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
20 |         # writer.add_image('Image', grid_image, global_step)
21 |         grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(),
22 |                                                        dataset=dataset), 3, normalize=False, range=(0, 255))
23 |         writer.add_image('Predicted label', grid_image, global_step)
24 |         grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(),
25 |                                                        dataset=dataset), 3, normalize=False, range=(0, 255))
26 |         writer.add_image('Ground truth label', grid_image, global_step)
27 | 
--------------------------------------------------------------------------------
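A minimal sketch of the summary helpers in a training loop (dummy tensors, hypothetical run directory; note `visualize_image` only renders the first three samples of a batch):

import torch
from utils.summaries import TensorboardSummary

summary = TensorboardSummary('run/rssrai2019/unet/experiment_0')
writer = summary.create_summary()
writer.add_scalar('train/total_loss_iter', 0.42, 100)   # dummy loss at step 100

image = torch.rand(4, 4, 400, 400)                  # (N, C, H, W) 4-band inputs
target = torch.randint(0, 16, (4, 400, 400)).float()
output = torch.rand(4, 16, 400, 400)                # raw class scores
TensorboardSummary.visualize_image(writer, 'rssrai2019', image, target, output, 100)
writer.close()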
/vis.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function: Visualize model predictions on the rssrai2019 test set
4 | @author:HuiYi or 会意
5 | @file: vis.py
6 | @time: 2019/6/23 7:09 PM
7 | """
8 | import argparse
9 | import os
10 | import numpy as np
11 | from tqdm import tqdm
12 | import torch
13 | from PIL import Image
14 | 
15 | from dataloaders import make_data_loader
16 | from models.backbone.UNet import UNet
17 | from models.backbone.UNetNested import UNetNested
18 | 
19 | from dataloaders.utils import decode_segmap
20 | 
21 | 
22 | class Visualization(object):
23 |     def __init__(self, args):
24 |         self.args = args
25 | 
26 |         # Define Dataloader
27 |         kwargs = {'num_workers': args.workers, 'pin_memory': True}
28 |         _, _, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
29 | 
30 |         self.model = None
31 |         # Define network
32 |         if self.args.backbone == 'unet':
33 |             self.model = UNet(in_channels=4, n_classes=self.nclass)
34 |             print("using UNet")
35 |         if self.args.backbone == 'unetNested':
36 |             self.model = UNetNested(in_channels=4, n_classes=self.nclass)
37 |             print("using UNetNested")
38 | 
39 |         # Using cuda
40 |         if args.cuda:
41 |             self.model = self.model.cuda()
42 | 
43 |         if not os.path.isfile(args.checkpoint_file):
44 |             raise RuntimeError("=> no checkpoint found at '{}'".format(args.checkpoint_file))
45 |         checkpoint = torch.load(args.checkpoint_file)
46 | 
47 |         self.model.load_state_dict(checkpoint['state_dict'])
48 |         print("=> loaded checkpoint '{}'".format(args.checkpoint_file))
49 | 
50 |     def visualization(self):
51 |         self.model.eval()
52 |         tbar = tqdm(self.test_loader, desc='\r')
53 |         for i, sample in enumerate(tbar):
54 |             image = sample['image']
55 |             img_path = sample['img_path']
56 |             if self.args.cuda:
57 |                 image = image.cuda()
58 |             with torch.no_grad():
59 |                 output = self.model(image)
60 |             tbar.set_description('Vis image:')
61 |             pred = output.data.cpu().numpy()
62 |             pred = np.argmax(pred, axis=1)[0]
63 | 
64 |             rgb = decode_segmap(pred, self.args.dataset)
65 |             pred_img = Image.fromarray(pred.astype(np.uint8), mode='L')
66 |             rgb_img = Image.fromarray(rgb, mode='RGB')
67 |             filename = os.path.basename(img_path[0])
68 |             pred_img.save(os.path.join(self.args.vis_logdir, 'raw_train_id', filename))
69 |             rgb_img.save(os.path.join(self.args.vis_logdir, 'vis_color', filename))
70 | 
71 |     @staticmethod
72 |     def transform_test(img):
73 |         # Normalize
74 |         mean = (0.544650, 0.352033, 0.384602, 0.352311)
75 |         std = (0.249456, 0.241652, 0.228824, 0.227583)
76 |         img = np.array(img).astype(np.float32)
77 |         img /= 255.0
78 |         img -= mean
79 |         img /= std
80 |         # ToTensor
81 |         img = np.array(img).astype(np.float32).transpose((2, 0, 1))
82 |         img = np.array([img, ])
83 |         img = torch.from_numpy(img).float()
84 |         return img
85 | 
86 |     def predict(self, img: np.ndarray) -> np.ndarray:
87 |         img = self.transform_test(img)
88 |         if self.args.cuda:
89 |             img = img.cuda()
90 |         with torch.no_grad():
91 |             output = self.model(img)
92 |         pred = output.data.cpu().numpy()
93 |         return pred
94 | 
95 |     def multi_scale_predict(self):
96 |         import cv2
97 |         test_dir = r'/home/lab/ygy/rssrai2019/datasets/image/test_crop'
98 |         files = os.listdir(test_dir)
99 |         self.model.eval()
100 |         tbar = tqdm(files, desc='\r')
101 | 
102 |         for i, filename in enumerate(tbar):
103 |             image_predict_prob_list = []
104 |             image_ori = Image.open(os.path.join(test_dir, filename))
105 | 
106 |             # Predict the original image
107 |             sample_ori = image_ori.copy()
108 |             pred = self.predict(sample_ori)[0]
109 |             ori_pred = np.argmax(pred, axis=0)
110 |             image_predict_prob_list.append(pred)
111 | 
112 |             # Predict the three rotated copies
113 |             angle_list = [90, 180, 270]
114 |             for angle in angle_list:
115 |                 img_rotate = image_ori.rotate(angle, Image.BILINEAR)
116 |                 pred = self.predict(img_rotate)[0]
117 |                 pred = pred.transpose((1, 2, 0))
118 |                 m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
119 |                 pred = cv2.warpAffine(pred, m_rotate, (400, 400))
120 |                 pred = pred.transpose((2, 0, 1))
121 |                 image_predict_prob_list.append(pred)
122 | 
123 |             # Predict the vertically flipped copy (flip in HWC layout, then back to CHW)
124 |             img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
125 |             pred = self.predict(img_flip)[0]
126 |             pred = cv2.flip(pred.transpose((1, 2, 0)), 0).transpose((2, 0, 1))
127 |             image_predict_prob_list.append(pred)
128 | 
129 |             # Predict the horizontally flipped copy
130 |             img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
131 |             pred = self.predict(img_flip)[0]
132 |             pred = cv2.flip(pred.transpose((1, 2, 0)), 1).transpose((2, 0, 1))
133 |             image_predict_prob_list.append(pred)
134 | 
135 |             # Average the probability maps
136 |             final_predict_prob = sum(image_predict_prob_list) / len(image_predict_prob_list)
137 |             final_pred = np.argmax(final_predict_prob, axis=0)
138 | 
139 |             rgb_ori = decode_segmap(ori_pred, self.args.dataset)
140 |             rgb = decode_segmap(final_pred, self.args.dataset)
141 |             pred_img = Image.fromarray(final_pred.astype(np.uint8), mode='L')
142 |             rgb_ori_img = Image.fromarray(rgb_ori, mode='RGB')
143 |             rgb_img = Image.fromarray(rgb, mode='RGB')
144 |             pred_img.save(os.path.join(self.args.vis_logdir, 'raw_train_id', filename))
145 |             rgb_ori_img.save(os.path.join(self.args.vis_logdir, 'vis_color_ori', filename))
146 |             rgb_img.save(os.path.join(self.args.vis_logdir, 'vis_color', filename))
147 | 
148 | 
149 | def main():
150 |     parser = argparse.ArgumentParser(description="PyTorch UNet/UNetNested visualization")
151 |     parser.add_argument('--backbone', type=str, default='unet',
152 |                         choices=['unet', 'unetNested'],
153 |                         help='backbone name (default: unet)')
154 |     parser.add_argument('--dataset', type=str, default='rssrai2019',
155 |                         choices=['rssrai2019'],
156 |                         help='dataset name (default: rssrai2019)')
157 |     parser.add_argument('--workers', type=int, default=4,
158 |                         metavar='N', help='dataloader threads')
159 |     parser.add_argument('--batch-size', type=int, default=1,
160 |                         metavar='N', help='input batch size for testing (default: 1)')
161 |     parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
162 |                         help='input batch size for testing (default: 1)')
163 | 
164 |     parser.add_argument('--crop-size', type=int, default=400,
165 |                         help='crop image size')
166 |     parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA')
167 | 
168 |     parser.add_argument('--checkpoint_file', type=str, default=None,
169 |                         help='put the path to checkpoint file')
170 |     parser.add_argument('--vis_logdir', type=str, default=None,
171 |                         help='store the vis image result dir')
172 | 
173 |     args = parser.parse_args()
174 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
175 | 
176 |     visual = Visualization(args)
177 |     visual.visualization()
178 | 
179 | 
180 | if __name__ == "__main__":
181 |     main()
182 | 
--------------------------------------------------------------------------------
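The trickiest step in `multi_scale_predict` is undoing each test-time rotation so the six probability maps align before averaging. A standalone sanity check of that round trip on a dummy array (same 400-pixel crop size and `(200, 200)` pivot as the code above):

import cv2
import numpy as np

h = w = 400
dummy = np.zeros((h, w), dtype=np.float32)
dummy[100:150, 200:260] = 1.0                      # a marker block

for angle in (90, 180, 270):
    m_fwd = cv2.getRotationMatrix2D((200, 200), angle, 1)
    rotated = cv2.warpAffine(dummy, m_fwd, (w, h))
    m_back = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
    restored = cv2.warpAffine(rotated, m_back, (w, h))
    assert np.abs(restored - dummy).mean() < 1e-2  # aligned up to interpolation blur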
/vis.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # Unet
4 | # python vis.py --checkpoint_file /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_4/checkpoint.pth.tar --vis_logdir /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/vis_log
5 | 
6 | # UNetNested
7 | python vis.py --backbone unetNested --checkpoint_file /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unetNested/experiment_0/checkpoint.pth.tar --vis_logdir /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unetNested/vis_log
--------------------------------------------------------------------------------
/vis_combine_net.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | """
3 | @function: Visualize predictions from the UNet/UNetNested/CombineNet ensemble
4 | @author:HuiYi or 会意
5 | @file: vis_combine_net.py
6 | @time: 2019/8/7 11:25 AM
7 | """
8 | import argparse
9 | import os
10 | import numpy as np
11 | from tqdm import tqdm
12 | import torch
13 | import cv2
14 | from PIL import Image
15 | 
16 | from models.backbone.UNet import UNet
17 | from models.backbone.UNetNested import UNetNested
18 | from models.CombineNet import CombineNet
19 | 
20 | from dataloaders.utils import decode_segmap
21 | 
22 | 
23 | class Visualization:
24 |     def __init__(self, args):
25 |         self.args = args
26 | 
27 |         self.nclass = 16
28 |         # Define network
29 |         self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
30 |         self.unetNested_model = UNetNested(in_channels=4, n_classes=self.nclass)
31 |         self.combine_net_model = CombineNet(in_channels=192, n_classes=self.nclass)
32 | 
33 |         # Using cuda
34 |         if args.cuda:
35 |             self.unet_model = self.unet_model.cuda()
36 |             self.unetNested_model = self.unetNested_model.cuda()
37 |             self.combine_net_model = self.combine_net_model.cuda()
38 | 
39 |         # Load Unet model
40 |         if not os.path.isfile(args.unet_checkpoint_file):
41 |             raise RuntimeError("=> no unet checkpoint found at '{}'".format(args.unet_checkpoint_file))
42 |         checkpoint = torch.load(args.unet_checkpoint_file)
43 |         self.unet_model.load_state_dict(checkpoint['state_dict'])
44 |         print("=> loaded unet checkpoint '{}'".format(args.unet_checkpoint_file))
45 | 
46 |         # Load UNetNested model
47 |         if not os.path.isfile(args.unetNested_checkpoint_file):
48 |             raise RuntimeError("=> no UNetNested checkpoint found at '{}'".format(args.unetNested_checkpoint_file))
49 |         checkpoint = torch.load(args.unetNested_checkpoint_file)
50 |         self.unetNested_model.load_state_dict(checkpoint['state_dict'])
51 |         print("=> loaded UNetNested checkpoint '{}'".format(args.unetNested_checkpoint_file))
52 | 
53 |         # Load Combine Net
54 |         if not os.path.isfile(args.combine_net_checkpoint_file):
55 |             raise RuntimeError("=> no combine net checkpoint found at '{}'".format(args.combine_net_checkpoint_file))
56 |         checkpoint = torch.load(args.combine_net_checkpoint_file)
57 |         self.combine_net_model.load_state_dict(checkpoint['state_dict'])
58 |         print("=> loaded combine net checkpoint '{}'".format(args.combine_net_checkpoint_file))
59 | 
60 |     def visualization(self):
61 |         self.combine_net_model.eval()
62 |         tbar = tqdm(test_files, desc='\r')
63 | 
64 |         for i, filename in enumerate(tbar):
65 |             image = Image.open(os.path.join(test_dir, filename))
66 | 
67 |             # UNet_multi_scale_predict
68 |             unet_pred = self.unet_multi_scale_predict(image)
69 | 
70 |             # UNetNested_multi_scale_predict
71 |             unetnested_pred = self.unetnested_multi_scale_predict(image)
72 | 
73 |             net_input = torch.cat([unet_pred, unetnested_pred], 1)
74 | 
75 |             with torch.no_grad():
76 |                 output = self.combine_net_model(net_input)
77 |             pred = output.data.cpu().numpy()[0]
78 |             pred = np.argmax(pred, axis=0)
79 | 
80 |             rgb = decode_segmap(pred, self.args.dataset)
81 | 
82 |             pred_img = Image.fromarray(pred.astype(np.uint8), mode='L')
83 |             rgb_img = Image.fromarray(rgb, mode='RGB')
84 | 
85 |             pred_img.save(os.path.join(self.args.vis_logdir, 'raw_train_id', filename))
86 |             rgb_img.save(os.path.join(self.args.vis_logdir, 'vis_color', filename))
87 | 
88 |     def unet_multi_scale_predict(self, image_ori: Image.Image):
89 |         self.unet_model.eval()
90 | 
91 |         # Predict the original image
92 |         sample_ori = image_ori.copy()
93 |         output_ori = self.unet_predict(sample_ori)
94 | 
95 |         # Predict the three rotated copies
96 |         angle_list = [90, 180, 270]
97 |         for angle in angle_list:
98 |             img_rotate = image_ori.rotate(angle, Image.BILINEAR)
99 |             output = self.unet_predict(img_rotate)
100 |             pred = output.data.cpu().numpy()[0]
101 |             pred = pred.transpose((1, 2, 0))
102 |             m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
103 |             pred = cv2.warpAffine(pred, m_rotate, (400, 400))
104 |             pred = pred.transpose((2, 0, 1))
105 |             output = torch.from_numpy(np.array([pred, ])).float()
106 |             output_ori = torch.cat([output_ori, output.cuda()], 1)
107 | 
108 |         # Predict the vertically flipped copy
109 |         img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
110 |         output = self.unet_predict(img_flip)
111 |         pred = output.data.cpu().numpy()[0]
112 |         pred = pred.transpose((1, 2, 0))
113 |         pred = cv2.flip(pred, 0)
114 |         pred = pred.transpose((2, 0, 1))
115 |         output = torch.from_numpy(np.array([pred, ])).float()
116 |         output_ori = torch.cat([output_ori, output.cuda()], 1)
117 | 
118 |         # Predict the horizontally flipped copy
119 |         img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
120 |         output = self.unet_predict(img_flip)
121 |         pred = output.data.cpu().numpy()[0]
122 |         pred = pred.transpose((1, 2, 0))
123 |         pred = cv2.flip(pred, 1)
124 |         pred = pred.transpose((2, 0, 1))
125 |         output = torch.from_numpy(np.array([pred, ])).float()
126 |         output_ori = torch.cat([output_ori, output.cuda()], 1)
127 | 
128 |         return output_ori
129 | 
130 |     def unet_predict(self, img: Image.Image) -> torch.Tensor:
131 |         img = self.transform_test(img)
132 |         if self.args.cuda:
133 |             img = img.cuda()
134 |         with torch.no_grad():
135 |             output = self.unet_model(img)
136 |         return output
137 | 
138 |     def unetnested_predict(self, img: Image.Image) -> torch.Tensor:
139 |         img = self.transform_test(img)
140 |         if self.args.cuda:
141 |             img = img.cuda()
142 |         with torch.no_grad():
143 |             output = self.unetNested_model(img)
144 |         return output
145 | 
146 |     def unetnested_multi_scale_predict(self, image_ori: Image.Image):
147 |         self.unetNested_model.eval()
148 | 
149 |         # Predict the original image
150 |         sample_ori = image_ori.copy()
151 |         output_ori = self.unetnested_predict(sample_ori)
152 | 
153 |         # Predict the three rotated copies
154 |         angle_list = [90, 180, 270]
155 |         for angle in angle_list:
156 |             img_rotate = image_ori.rotate(angle, Image.BILINEAR)
157 |             output = self.unetnested_predict(img_rotate)
158 |             pred = output.data.cpu().numpy()[0]
159 |             pred = pred.transpose((1, 2, 0))
160 |             m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
161 |             pred = cv2.warpAffine(pred, m_rotate, (400, 400))
162 |             pred = pred.transpose((2, 0, 1))
163 |             output = torch.from_numpy(np.array([pred, ])).float()
164 |             output_ori = torch.cat([output_ori, output.cuda()], 1)
165 | 
166 |         # Predict the vertically flipped copy
167 |         img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
168 |         output = self.unetnested_predict(img_flip)
169 |         pred = output.data.cpu().numpy()[0]
170 |         pred = pred.transpose((1, 2, 0))
171 |         pred = cv2.flip(pred, 0)
172 |         pred = pred.transpose((2, 0, 1))
173 |         output = torch.from_numpy(np.array([pred, ])).float()
174 |         output_ori = torch.cat([output_ori, output.cuda()], 1)
175 | 
176 |         # Predict the horizontally flipped copy
177 |         img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
178 |         output = self.unetnested_predict(img_flip)
179 |         pred = output.data.cpu().numpy()[0]
180 |         pred = pred.transpose((1, 2, 0))
181 |         pred = cv2.flip(pred, 1)
182 |         pred = pred.transpose((2, 0, 1))
183 |         output = torch.from_numpy(np.array([pred, ])).float()
184 |         output_ori = torch.cat([output_ori, output.cuda()], 1)
185 | 
186 |         return output_ori
187 | 
188 |     @staticmethod
189 |     def transform_test(img):
190 |         # Normalize
191 |         mean = (0.544650, 0.352033, 0.384602, 0.352311)
192 |         std = (0.249456, 0.241652, 0.228824, 0.227583)
193 |         img = np.array(img).astype(np.float32)
194 |         img /= 255.0
195 |         img -= mean
196 |         img /= std
197 |         # ToTensor
198 |         img = img.transpose((2, 0, 1))
199 |         img = np.array([img, ])
200 |         img = torch.from_numpy(img).float()
201 |         return img
202 | 
203 | 
204 | def main():
205 |     parser = argparse.ArgumentParser(description="PyTorch CombineNet visualization")
206 |     parser.add_argument('--backbone', type=str, default='combine_net',
207 |                         choices=['combine_net'],
208 |                         help='backbone name (default: combine_net)')
209 |     parser.add_argument('--dataset', type=str, default='rssrai2019',
210 |                         choices=['rssrai2019'],
211 |                         help='dataset name (default: rssrai2019)')
212 | 
213 |     parser.add_argument('--unet_checkpoint_file', type=str, default=None,
214 |                         help='put the path to UNet checkpoint file')
215 |     parser.add_argument('--unetNested_checkpoint_file', type=str, default=None,
216 |                         help='put the path to UNetNested checkpoint file')
217 |     parser.add_argument('--combine_net_checkpoint_file', type=str, default=None,
218 |                         help='put the path to combineNet checkpoint file')
219 |     parser.add_argument('--vis_logdir', type=str, default=None,
220 |                         help='store the vis image result dir')
221 | 
222 |     args = parser.parse_args()
223 |     args.cuda = torch.cuda.is_available()
224 | 
225 |     visual = Visualization(args)
226 |     visual.visualization()
227 | 
228 | 
229 | if __name__ == "__main__":
230 |     test_dir = r'/home/lab/ygy/rssrai2019/datasets/image/test_overlay_crop'
231 | 
232 |     test_files = os.listdir(test_dir)
233 | 
234 |     main()
235 | 
--------------------------------------------------------------------------------
/vis_combine_net.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CUDA_VISIBLE_DEVICES=1 python vis_combine_net.py --combine_net_checkpoint_file /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/combine_net/experiment_1/checkpoint.pth.tar --unet_checkpoint_file /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_2/checkpoint.pth.tar --unetNested_checkpoint_file /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unetNested/experiment_0/checkpoint.pth.tar --vis_logdir /home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/combine_net/vis_log
--------------------------------------------------------------------------------
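CombineNet's `in_channels=192` in vis_combine_net.py follows directly from the ensembling scheme: each backbone contributes six probability maps (original, three rotations, two flips) of 16 classes each, and the two backbones are concatenated, giving 2 x 6 x 16 = 192 channels. A minimal shape check with dummy tensors (hypothetical 400-pixel crops):

import torch

n_classes, n_views, size = 16, 6, 400
unet_pred = torch.rand(1, n_classes * n_views, size, size)        # 96 channels
unetnested_pred = torch.rand(1, n_classes * n_views, size, size)  # 96 channels
net_input = torch.cat([unet_pred, unetnested_pred], 1)
assert net_input.shape == (1, 192, size, size)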