├── README.md
├── configs
│   └── parameter.yaml
├── data
│   ├── images
│   │   ├── aaa.jpg
│   │   ├── bbb.jpg
│   │   └── ccc.jpg
│   ├── labels
│   │   ├── aaa.jpg
│   │   ├── bbb.jpg
│   │   └── ccc.jpg
│   ├── test.txt
│   └── train.txt
├── data_augmentation.py
├── demo.py
├── main.py
├── models
│   ├── deeplab_v3_plus.py
│   ├── hed_series
│   │   ├── hed_res.py
│   │   ├── hed_vgg16.py
│   │   ├── hf_fcn_res.py
│   │   └── hf_fcn_vgg16.py
│   ├── models.py
│   ├── pspnet.py
│   ├── spp.py
│   └── unet.py
├── readmes
│   ├── data_aug.md
│   ├── main_modify.jpg
│   ├── param_modify.jpg
│   └── train_cusom.md
└── utils
    ├── aug_GDAL.py
    ├── aug_PIL.py
    ├── dataset.py
    ├── metrics.py
    ├── plots.py
    ├── trainval.py
    └── util.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | ## Semantic Segmentation Pytorch
  2 | ```
  3 | author is leilei
  4 | Restarted this project on 2017-10-01.
  5 | 
  6 | The 1.alpha.0 version is now basically complete and awaiting testing.
  7 | TODO:
  8 | Add distributed training and optimize the code.
  9 | ```
 10 | 
 11 | ### Environment
 12 | ```
 13 | python: 3.6+
 14 | ubuntu16.04 or 18.04
 15 | pytorch 1.6 (cuda10.2 docker)
 16 | tensorboard 2.0
 17 | scikit-learn 0.24.1
 18 | ```
 19 | 
 20 | ### **Note**
 21 | + If a black border is introduced (e.g. by rotation), it is treated as its own class, with value 0 by default!
 22 | + Label values are in [1, N]; 0 is the black-border class!
 23 | + Distributed training (NCCL) is not supported; only DataParallel is.
 24 | 
 25 | ### Getting Started
 26 | + [How to Use](./readmes/train_cusom.md)
 27 | 
 28 | ### Demo
 29 | + Just see [demo.py](./demo.py)
 30 | 
 31 | ### Evaluation Metrics
 32 | + Just see [metrics.py](./utils/metrics.py)
 33 | + Supports acc, mean_precision, mean_recall, mean_iou
 34 | 
 35 | ### Supported Networks
 36 | - [x] [deeplab_v3_plus](models/deeplab_v3_plus.py)
 37 | - [x] [pspnet](models/pspnet.py)
 38 | - [x] [unet](models/unet.py)
 39 | - [x] [spp-net](models/spp.py)
 40 | - [x] [HF_FCN](models/hed_series/hf_fcn_vgg16.py)
 41 | - [ ] [deeplab_v3](https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/deeplabv3.py)
 42 | - [ ] [HRNet](https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/pytorch-v1.1)
 43 | - [ ] [U^2Net](https://github.com/NathanUA/U-2-Net)
 44 | - [ ] ...
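Note: `utils/metrics.py` is linked above but not included in this excerpt. Below is a minimal sketch — an illustration, not necessarily the repo's exact implementation — of how acc, mean_precision, mean_recall and mean_iou are typically derived from a pixel-level confusion matrix:

```python
import numpy as np

def metrics_from_confusion(hist: np.ndarray):
    # hist[i, j] counts pixels of true class i predicted as class j
    eps = 1e-12
    tp = np.diag(hist).astype(np.float64)
    acc = tp.sum() / (hist.sum() + eps)
    mean_precision = (tp / (hist.sum(axis=0) + eps)).mean()
    mean_recall = (tp / (hist.sum(axis=1) + eps)).mean()
    mean_iou = (tp / (hist.sum(axis=0) + hist.sum(axis=1) - tp + eps)).mean()
    return acc, mean_precision, mean_recall, mean_iou
```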
 45 | 
 46 | ### Data Aug
 47 | + [**data-augmentations**](./readmes/data_aug.md)
 48 | ```
 49 | support
 50 | random zoom-in/out, random noise,
 51 | random blur, random color-jitter (brightness-contrast-saturation-hue),
 52 | random affine, random rotate, random flip
 53 | ```
 54 | 
 55 | ### Others
 56 | * [building-segmentation-dataset](https://github.com/gengyanlei/build_segmentation_dataset)
 57 | * [reflective-clothes-detect-dataset](https://github.com/gengyanlei/reflective-clothes-detect)
 58 | 
--------------------------------------------------------------------------------
/configs/parameter.yaml:
--------------------------------------------------------------------------------
  1 | input_hw: (256, 256) # model input height-width
  2 | mean: [0.485, 0.456, 0.406]
  3 | std: [0.229, 0.224, 0.225]
  4 | is_gdal: False
  5 | value_scale: None # if is_gdal is True, value_scale must be an int
  6 | 
  7 | device: '1, 0' # gpu ids
  8 | batch_size: 16
  9 | num_workers: 8
 10 | lr0: 1e-3 # initial lr
 11 | lrf: 0.1 # final OneCycleLR learning rate (lr0 * lrf)
 12 | momentum: 0.9 # adam betas[0], SGD momentum
 13 | train_txt_path: './data/train.txt'
 14 | test_txt_path: './data/test.txt'
 15 | 
 16 | class_number: 5+1 # _background_ must be 0
 17 | class_names: ['_background_', '1', '2', '3', '4', '5']
 18 | test_interval: 1 # run test every test_interval epochs
 19 | epoches: 100
 20 | 
 21 | 
 22 | 
 23 | 
 24 | 
--------------------------------------------------------------------------------
/data/images/aaa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/data/images/aaa.jpg
--------------------------------------------------------------------------------
/data/images/bbb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/data/images/bbb.jpg
--------------------------------------------------------------------------------
/data/images/ccc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/data/images/ccc.jpg
--------------------------------------------------------------------------------
/data/labels/aaa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/data/labels/aaa.jpg
--------------------------------------------------------------------------------
/data/labels/bbb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/data/labels/bbb.jpg
--------------------------------------------------------------------------------
/data/labels/ccc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/data/labels/ccc.jpg
--------------------------------------------------------------------------------
/data/test.txt:
--------------------------------------------------------------------------------
  1 | /home/gengyanlei/data/images/ccc.jpg
--------------------------------------------------------------------------------
/data/train.txt:
--------------------------------------------------------------------------------
  1 | /home/gengyanlei/data/images/aaa.jpg
  2 | /home/gengyanlei/data/images/bbb.jpg
--------------------------------------------------------------------------------
/data_augmentation.py:
--------------------------------------------------------------------------------
  1 | '''
  2 | Process the images first and save them to one folder (do the augmentation offline, so the data loader does not have to augment while reading).
  3 | ADE dataset: 151 classes.
  4 | '''
  5 | import os
  6 | import cv2
  7 | import random
  8 | import numpy as np
  9 | 
 10 | train_img_path=r'/home/*/*/Dataset/ADE/images/training'
 11 | train_lab_path=r'/home/*/*/Dataset/ADE/annotations/training'
 12 | #val_img_path=r'/home/*/ADE/images/validation'
 13 | #val_lab_path=r'/home/*/ADE/annotations/validation'
 14 | save_path=r'/home/*/*/Dataset/ADE/HDF5'
 15 | 
 16 | names=sorted(os.listdir(train_img_path))
 17 | 
 18 | N1=20206*5 # N1 images per epoch in total
 19 | num=20206 # num source images; each one is processed 5 times and saved
 20 | 
 21 | s_img_path=r'/home/*/*/Dataset/ADE/HDF5/image' # save image dir
 22 | s_lab_path=r'/home/*/*/Dataset/ADE/HDF5/label' # save label dir
 23 | 
 24 | def first_data_augmen(image,label): # first stage
 25 |     randint=random.randint(0,4)
 26 |     if randint==2:
 27 |         f_scale=0.5+random.randint(0,10)/10
 28 |         image=cv2.resize(image,(0,0),fx=f_scale,fy=f_scale)
 29 |         label=cv2.resize(label,(0,0),fx=f_scale,fy=f_scale,interpolation=cv2.INTER_NEAREST)
 30 |     else :
 31 |         image=image
 32 |         label=label
 33 |     return image,label
 34 | 
 35 | def final_data_augmen(image,label): # final stage
 36 |     randint=random.randint(1,8)
 37 |     if randint==1: # left-right flip
 38 |         image=cv2.flip(image,1)
 39 |         label=cv2.flip(label,1)
 40 |     elif randint==2: # up-down flip
 41 |         image=cv2.flip(image,0)
 42 |         label=cv2.flip(label,0)
 43 |     elif randint==3: # rotate 90; width first, then height
 44 |         M=cv2.getRotationMatrix2D((image.shape[1]//2,image.shape[0]//2),90,1.0)
 45 |         image=cv2.warpAffine(image,M,(image.shape[1],image.shape[0]),flags=cv2.INTER_NEAREST)
 46 |         label=cv2.warpAffine(label,M,(image.shape[1],image.shape[0]),flags=cv2.INTER_NEAREST)
 47 |     elif randint==4: # rotate 270
 48 |         M=cv2.getRotationMatrix2D((image.shape[1]//2,image.shape[0]//2),270,1.0)
 49 |         image=cv2.warpAffine(image,M,(image.shape[1],image.shape[0]),flags=cv2.INTER_NEAREST)
 50 |         label=cv2.warpAffine(label,M,(image.shape[1],image.shape[0]),flags=cv2.INTER_NEAREST)
 51 |     return image,label
 52 | 
 53 | def middle_data_augmen(image,label): # middle stage
 54 |     H,W=label.shape
 55 |     if H>=256 and W>=256:
 56 |         # random crop
 57 |         h=random.randint(0,H-256)
 58 |         w=random.randint(0,W-256)
 59 |         img=image[h:h+256,w:w+256,:]
 60 |         lab=label[h:h+256,w:w+256]
 61 |     else :
 62 |         # smaller than 256: scale the shorter side to 256, the other side to int(256/min*max)
 63 |         if H<W : # H is the shorter side => 256
 64 |             image=cv2.resize(image,(int(W*256/H),256)) # default INTER_LINEAR
 65 |             label=cv2.resize(label,(int(W*256/H),256),interpolation=cv2.INTER_NEAREST)
 66 |         else : # W<=H, W is the shorter side => 256
 67 |             image=cv2.resize(image,(256,int(H*256/W)))
 68 |             label=cv2.resize(label,(256,int(H*256/W)),interpolation=cv2.INTER_NEAREST)
 69 |         H,W=label.shape
 70 |         h=random.randint(0,H-256)
 71 |         w=random.randint(0,W-256)
 72 |         img=image[h:h+256,w:w+256,:]
 73 |         lab=label[h:h+256,w:w+256]
 74 |     return img,lab
 75 | 
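Whatever the input size, `middle_data_augmen` should always hand back a 256×256 crop. A quick sanity check (illustrative; assumes the function above is importable from data_augmentation.py):

```python
import numpy as np
# from data_augmentation import middle_data_augmen  # assumed importable
for H, W in [(300, 500), (200, 240), (256, 256), (100, 400)]:
    img = np.zeros((H, W, 3), dtype=np.uint8)
    lab = np.zeros((H, W), dtype=np.uint8)
    img_c, lab_c = middle_data_augmen(img, lab)
    assert img_c.shape == (256, 256, 3) and lab_c.shape == (256, 256)
```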
 76 | # process data
 77 | num_i=0
 78 | for i in range(num):
 79 |     image=cv2.imread(os.path.join(train_img_path,names[i]),-1)
 80 |     name=names[i].split('.')[0]+'.png'
 81 |     print(name)
 82 |     label=cv2.imread(os.path.join(train_lab_path,name),-1)
 83 |     for j in range(5):
 84 |         image_f,label_f=first_data_augmen(image,label) # random scale
 85 |         image_m,label_m=middle_data_augmen(image_f,label_f) # random crop
 86 |         img,lab=final_data_augmen(image_m,label_m) # random flip / rotation, or unchanged
 87 |         if lab.max()>150: # label id out of range -> broken annotation, skip this image
 88 |             print('########################')
 89 |             break
 90 |         cv2.imwrite(os.path.join(s_img_path,str(num_i)+'.jpg'),img)
 91 |         cv2.imwrite(os.path.join(s_lab_path,str(num_i)+'.png'),lab)
 92 | 
 93 |         num_i+=1
 94 | print(num_i==N1) # True only if no image was skipped
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
  1 | '''
  2 | demo: predict one image and visualize it.
  3 | '''
  4 | import os
  5 | import cv2
  6 | import yaml
  7 | import torch
  8 | import random
  9 | import argparse
 10 | import numpy as np
 11 | from PIL import Image
 12 | from utils.util import check_path
 13 | import torchvision.transforms.functional as tf
 14 | 
 15 | img_formats = ['jpg', 'png', 'tif', 'jpeg']
 16 | 
 17 | def predict(args, save_img=True):
 18 |     # load parameter.yaml
 19 |     with open(args.cfg_path, 'r', encoding='utf-8') as f:
 20 |         param_dict = yaml.load(f, Loader=yaml.FullLoader)
 21 |     mean = param_dict['mean']
 22 |     std = param_dict['std']
 23 |     class_names = param_dict['class_names']
 24 | 
 25 |     # make sure the save dir exists
 26 |     check_path(args.output)
 27 | 
 28 |     # load model
 29 |     device = torch.device("cuda:{}".format(args.device))
 30 |     model = torch.load(args.weights, map_location=device)['model'].eval()  # if this fails, add .to(device) before eval()
 31 | 
 32 |     # label colors, _background_->(0,0,0)
 33 |     colors = [[random.randint(0, 255) for _ in range(3)] for _ in class_names[1:]]
 34 |     colors.insert(0, [0,0,0])
 35 | 
 36 |     # read images
 37 |     names = os.listdir(args.source)
 38 |     img_names = [name for name in names if name.split('.')[-1].lower() in img_formats]
 39 | 
 40 |     with torch.no_grad():
 41 |         for img_name in img_names:
 42 |             img = np.ascontiguousarray(cv2.imread(os.path.join(args.source, img_name), cv2.IMREAD_COLOR)[..., ::-1])  # bgr->rgb; contiguous copy so PIL can read it
 43 |             img_ = Image.fromarray(img, mode="RGB")
 44 |             # ToTensor -> Normalize -> add batch dim -> gpu [1,3,H,W]
 45 |             img_norm = tf.normalize(tf.to_tensor(img_), mean, std).unsqueeze(0).to(device)
 46 | 
 47 |             pred = model(img_norm)  # [1,Class_num,H,W]
 48 |             pred_ = pred.argmax(dim=1)[0].cpu().numpy()
 49 |             pred3 = np.zeros(pred_.shape+(3,), dtype=np.uint8)
 50 |             if save_img:
 51 |                 for i in range(len(class_names)):
 52 |                     pred3[pred_==i] = colors[i]
 53 |                 cv2.imwrite(os.path.join(args.output, img_name), pred3)
 54 | 
 55 |     return
 56 | 
 57 | 
 58 | if __name__ == '__main__':
 59 |     parser = argparse.ArgumentParser("Semantic segmentation Predict")
 60 |     parser.add_argument("--weights", type=str, default=r'./runs/exp/weights/best.pt', help="weights path")
 61 |     parser.add_argument("--cfg_path", type=str, default=r'./configs/parameter.yaml', help="config param")
 62 |     parser.add_argument("--source", type=str, default=r'./data/images', help="input source")
 63 |     parser.add_argument("--device", type=str, default='0', help="gpu id, 1 gpu suggested")
 64 |     parser.add_argument("--output", type=str, default=r'./outputs', help="save dir")
 65 |     # parser.add_argument("--img_size", type=int, default=512, help="input image size")  # any size
 66 |     args = parser.parse_args()
 67 | 
 68 |     predict(args, save_img=True)
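demo.py writes the raw color mask to disk. If you also want the mask blended over the input image, a minimal sketch (the `overlay` helper is hypothetical, not part of the repo):

```python
import cv2

def overlay(image_bgr, mask_bgr, alpha=0.6):
    # alpha weights the original image; both arrays must share height/width
    return cv2.addWeighted(image_bgr, alpha, mask_bgr, 1 - alpha, 0)
```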
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
  1 | '''
  2 | author is leilei
  3 | '''
  4 | import os
  5 | import torch
  6 | import yaml
  7 | import math
  8 | import argparse
  9 | from torch import nn
 10 | from torch import optim
 11 | from models.models import *
 12 | from utils.dataset import load_data
 13 | from utils.util import init_seeds, check_path, increment_path
 14 | from utils.trainval import train
 15 | from torch.utils.tensorboard import SummaryWriter
 16 | 
 17 | 
 18 | def main(args):
 19 |     # read hyper-parameters
 20 |     with open(args.cfg_path, 'r', encoding='utf-8') as f:
 21 |         param_dict = yaml.load(f, Loader=yaml.FullLoader)
 22 | 
 23 |     # create the save folder
 24 |     save_dir = increment_path(args.project)  # str
 25 |     check_path(save_dir)
 26 |     param_dict['save_dir'] = save_dir  # update param_dict
 27 |     param_dict['model_name'] = args.model_name  # update param_dict
 28 |     # tensorboard
 29 |     tb_writer = SummaryWriter(save_dir)
 30 | 
 31 |     # set gpu
 32 |     os.environ['CUDA_VISIBLE_DEVICES'] = param_dict['device']
 33 | 
 34 |     # data loader
 35 |     data_loader = load_data(params=param_dict)
 36 | 
 37 |     init_seeds(seed=1)  # fix seeds / configure cudnn
 38 |     model = deeplab_v3_plus(class_number=param_dict['class_number'], fine_tune=True, backbone='resnet50').cuda()
 39 |     continue_epoch = 0
 40 |     if args.resume:
 41 |         model_dict = model.state_dict()
 42 |         pretrained_file = torch.load(args.resume)
 43 |         pretrained_dict = pretrained_file['model'].float().state_dict()
 44 |         continue_epoch = pretrained_file['epoch'] if 'epoch' in pretrained_file else 0
 45 |         pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict and v.size()==model_dict[k[7:]].size()}  # k[7:] strips DataParallel's 'module.' prefix
 46 |         assert len(pretrained_dict) == len(model_dict), "Unsuccessful import of weights"
 47 |         model_dict.update(pretrained_dict)
 48 |         model.load_state_dict(model_dict)
 49 |     model = nn.DataParallel(model)  # keys gain a '.module' prefix; the wrapper has a .module attribute
 50 | 
 51 |     # TODO a lot still to add; make it work first, then improve! v1 first, then v2
 52 |     if args.adam:
 53 |         optimizer = optim.Adam(model.module.parameters(), lr=param_dict['lr0'], betas=(param_dict['momentum'], 0.999), weight_decay=5e-4)
 54 |     else:
 55 |         optimizer = optim.SGD(model.module.parameters(), lr=param_dict['lr0'], momentum=param_dict['momentum'], weight_decay=5e-4)
 56 | 
 57 |     # lr_scheduler: cosine annealing
 58 |     lf = lambda x: ((1 + math.cos(x * math.pi / param_dict['epoches'])) / 2) * (1 - param_dict['lrf']) + param_dict['lrf']  # cosine
 59 |     scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # scheduler params are not saved
 60 | 
 61 |     # train stage
 62 |     train(data_loader, model, optimizer, scheduler, tb_writer, param_dict, continue_epoch)
 63 | 
 64 |     return
 65 | 
 66 | 
 67 | if __name__ == '__main__':
 68 |     parser = argparse.ArgumentParser()
 69 |     parser.add_argument('--model_name', type=str, default='deeplab_v3_plus-resnet50', help='model name')
 70 |     parser.add_argument('--project', type=str, default='./runs/exp', help='weight and summary... folder')
 71 |     parser.add_argument('--cfg_path', type=str, default='./configs/parameter.yaml', help='parameter config file')
 72 |     parser.add_argument('--resume', type=str, default='', help='resume most recent training')
 73 |     parser.add_argument('--adam', action='store_true', help='use Adam optimizer instead of SGD')
 74 | 
 75 |     args = parser.parse_args()
 76 |     main(args)
 77 | 
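The `lf` lambda above is the whole schedule; evaluating it at a few epochs makes its behavior concrete (using epoches=100, lr0=1e-3, lrf=0.1 from configs/parameter.yaml):

```python
import math

epoches, lr0, lrf = 100, 1e-3, 0.1
lf = lambda x: ((1 + math.cos(x * math.pi / epoches)) / 2) * (1 - lrf) + lrf
for e in [0, 25, 50, 75, 100]:
    print(e, lr0 * lf(e))  # starts at 1e-3 (lf(0)=1) and decays to lr0*lrf = 1e-4 at the last epoch
```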
--------------------------------------------------------------------------------
/models/deeplab_v3_plus.py:
--------------------------------------------------------------------------------
  1 | """
  2 | code's author: leilei
  3 | """
  4 | 
  5 | import torch
  6 | import torchvision
  7 | from torch import nn
  8 | import torch.nn.functional as F
  9 | 
 10 | '''
 11 | also fine-tunable
 12 | deeplab_v3+ : pytorch resnet 18/34 Basicblock
 13 |               resnet 50/101/152 Bottleneck
 14 | this is not the original deeplab_v3+; it is built on pytorch's resnet, so there are many differences.
 15 | '''
 16 | class ASPP(nn.Module):
 17 |     # has bias and relu, no bn
 18 |     def __init__(self,in_channel=512, depth=256):
 19 |         super().__init__()
 20 |         # global average pooling: nn.AdaptiveAvgPool2d here; torch.mean(..., keepdim=True) in forward also works
 21 |         self.mean = nn.AdaptiveAvgPool2d((1,1))
 22 |         self.conv = nn.Sequential(nn.Conv2d(in_channel,depth,1,1), nn.ReLU(inplace=True))
 23 | 
 24 |         self.atrous_block1 = nn.Sequential(nn.Conv2d(in_channel,depth,1,1),
 25 |                                            nn.ReLU(inplace=True))
 26 |         self.atrous_block6 = nn.Sequential(nn.Conv2d(in_channel,depth,3,1,padding=6,dilation=6),
 27 |                                            nn.ReLU(inplace=True))
 28 |         self.atrous_block12 = nn.Sequential(nn.Conv2d(in_channel,depth,3,1,padding=12,dilation=12),
 29 |                                            nn.ReLU(inplace=True))
 30 |         self.atrous_block18 = nn.Sequential(nn.Conv2d(in_channel,depth,3,1,padding=18,dilation=18),
 31 |                                            nn.ReLU(inplace=True))
 32 | 
 33 |         self.conv_1x1_output= nn.Sequential(nn.Conv2d(depth*5,depth,1,1), nn.ReLU(inplace=True))
 34 | 
 35 |     def forward(self,x):
 36 |         size = x.shape[2:]
 37 | 
 38 |         image_features = self.mean(x)
 39 |         image_features = self.conv(image_features)
 40 |         image_features = F.upsample(image_features, size=size, mode='bilinear', align_corners=True)
 41 | 
 42 |         atrous_block1 = self.atrous_block1(x)
 43 | 
 44 |         atrous_block6 = self.atrous_block6(x)
 45 | 
 46 |         atrous_block12 = self.atrous_block12(x)
 47 | 
 48 |         atrous_block18 = self.atrous_block18(x)
 49 | 
 50 |         net = self.conv_1x1_output(torch.cat([image_features,atrous_block1,atrous_block6,
 51 |                                               atrous_block12,atrous_block18],dim=1))
 52 |         return net
 53 | 
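ASPP keeps the spatial size and maps any `in_channel` down to `depth` channels (the five 256-channel branches are concatenated, then fused by a 1x1 conv). A quick shape check, run inside this module's context (illustrative):

```python
import torch
# ASPP comes from the class defined above; 2048 matches the resnet50 layer4 output
aspp = ASPP(in_channel=2048, depth=256)
y = aspp(torch.randn(1, 2048, 16, 16))
print(y.shape)  # torch.Size([1, 256, 16, 16]) -- spatial size preserved, channels -> depth
```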
 54 | class Deeplab_v3_plus(nn.Module):
 55 |     # in_channel = 3, fine-tunable
 56 |     def __init__(self, class_number=5, fine_tune=True, backbone='resnet50'):
 57 |         super().__init__()
 58 |         # any member of the resnet family can be chosen as the encoder
 59 |         encoder = getattr(torchvision.models, backbone)(pretrained=fine_tune)
 60 |         self.start = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
 61 | 
 62 |         self.maxpool = encoder.maxpool
 63 |         self.low_feature = nn.Sequential(nn.Conv2d(64,48,1,1),nn.ReLU(inplace=True))  # no bn, has bias and relu
 64 | 
 65 |         self.layer1 = encoder.layer1  # 256
 66 |         self.layer2 = encoder.layer2  # 512
 67 |         self.layer3 = encoder.layer3  # 1024
 68 |         self.layer4 = encoder.layer4  # 2048
 69 | 
 70 |         self.aspp = ASPP(in_channel=self.layer4[-1].conv1.in_channels, depth=256)
 71 | 
 72 |         self.conv_cat = nn.Sequential(nn.Conv2d(256+48,256,3,1,padding=1),nn.ReLU(inplace=True))
 73 |         self.conv_cat1 = nn.Sequential(nn.Conv2d(256,256,3,1,padding=1),nn.ReLU(inplace=True))
 74 |         self.conv_cat2 = nn.Sequential(nn.Conv2d(256,256,3,1,padding=1),nn.ReLU(inplace=True))
 75 |         self.score = nn.Conv2d(256,class_number,1,1)  # no relu; conv first, then upsample, to save memory
 76 | 
 77 |     def forward(self,x):
 78 |         size1 = x.shape[2:]  # input size, needed for the final upsample
 79 |         x = self.start(x)
 80 |         xm = self.maxpool(x)
 81 | 
 82 |         x = self.layer1(xm)
 83 |         x = self.layer2(x)
 84 |         x = self.layer3(x)
 85 |         x = self.layer4(x)
 86 |         x = self.aspp(x)
 87 | 
 88 |         low_feature = self.low_feature(xm)
 89 |         size2 = low_feature.shape[2:]
 90 |         decoder_feature = F.upsample(x,size=size2,mode='bilinear',align_corners=True)
 91 | 
 92 |         conv_cat = self.conv_cat( torch.cat([low_feature,decoder_feature],dim=1) )
 93 |         conv_cat1 = self.conv_cat1(conv_cat)
 94 |         conv_cat2 = self.conv_cat2(conv_cat1)
 95 |         score_small = self.score(conv_cat2)
 96 |         score = F.upsample(score_small,size=size1,mode='bilinear',align_corners=True)
 97 | 
 98 |         return score
 99 | 
100 | 
--------------------------------------------------------------------------------
/models/hed_series/hed_res.py:
--------------------------------------------------------------------------------
  1 | """
  2 | author: LeiLei
  3 | """
  4 | 
  5 | '''
  6 | HED was built on VGG16;
  7 | here a HED-like structure is built on VGG16 or the resnet34 family.
  8 | Core idea: 3/4/5 scale changes; pytorch also exposes each residual block as a class attribute, so they can be called directly.
  9 | Similar to tensorflow slim's output_collections.
 10 | '''
 11 | import torch
 12 | import torchvision
 13 | from torch import nn
 14 | 
 15 | class HED_res34(nn.Module):
 16 |     def __init__(self, num_filters=32, pretrained=False, class_number=2):
 17 |         super().__init__()
 18 |         encoder = torchvision.models.resnet34(pretrained=pretrained)
 19 | 
 20 |         self.pool = nn.MaxPool2d(3, 2, 1)
 21 | 
 22 |         # start
 23 |         self.start = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)  # 128*128
 24 |         self.d_convs = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))
 25 |         self.scores = nn.UpsamplingBilinear2d(scale_factor=2)  # 256*256
 26 | 
 27 |         self.layer1 = encoder.layer1  # 64*64
 28 |         self.d_conv1 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))
 29 |         self.score1 = nn.UpsamplingBilinear2d(scale_factor=4)  # 256*256
 30 | 
 31 |         self.layer2 = encoder.layer2  # 32*32
 32 |         self.d_conv2 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))
 33 |         self.score2 = nn.UpsamplingBilinear2d(scale_factor=8)  # 256*256
 34 | 
 35 |         self.layer3 = encoder.layer3  # 16*16
 36 |         self.d_conv3 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))
 37 |         self.score3 = nn.UpsamplingBilinear2d(scale_factor=16)  # 256*256
 38 | 
 39 |         self.layer4 = encoder.layer4  # 8*8
 40 |         self.d_conv4 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True))
 41 |         self.score4 = nn.UpsamplingBilinear2d(scale_factor=32)  # 256*256
 42 | 
 43 |         self.score = nn.Conv2d(5, class_number, 1, 1)  # No relu; loss_func has softmax
 44 | 
 45 |     def forward(self, x):
 46 |         x = self.start(x)
 47 |         s_x = self.d_convs(x)
 48 |         ss = self.scores(s_x)
 49 |         x = self.pool(x)
 50 | 
 51 |         x = self.layer1(x)
 52 |         s_x = self.d_conv1(x)
 53 |         s1 = self.score1(s_x)
 54 | 
 55 |         x = self.layer2(x)
 56 |         s_x = self.d_conv2(x)
 57 |         s2 = self.score2(s_x)
 58 | 
 59 |         x = self.layer3(x)
 60 |         s_x = self.d_conv3(x)
 61 |         s3 = self.score3(s_x)
 62 | 
 63 |         x = self.layer4(x)
 64 |         s_x = self.d_conv4(x)
 65 |         s4 = self.score4(s_x)
 66 | 
 67 |         score = self.score(torch.cat([s1, s2, s3, s4, ss], dim=1))
 68 | 
 69 |         return score
 70 | 
 71 | 
 72 | # hed2 = HED_res34()
 73 | # print(hed2)
 74 | # print(hed2.state_dict().keys())
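A quick smoke test for the class above, run inside this module's context (illustrative; the input side length should be a multiple of 32 so the fixed upsampling factors land back on the input size):

```python
import torch
model = HED_res34(pretrained=False, class_number=2).eval()  # pretrained=False avoids the weight download
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 2, 256, 256]) -- five side outputs fused into class scores
```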
--------------------------------------------------------------------------------
/models/hed_series/hed_vgg16.py:
--------------------------------------------------------------------------------
  1 | """
  2 | author: LeiLei
  3 | """
  4 | 
  5 | '''
  6 | HED was built on VGG16;
  7 | here a HED-like structure is built on VGG16 or the resnet34 family.
  8 | Core idea: 3/4/5 scale changes; pytorch also exposes each residual block as a class attribute, so they can be called directly.
  9 | Similar to tensorflow slim's output_collections.
 10 | 
 11 | VGG16 network
 12 | Sequential(
 13 |   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 14 |   (1): ReLU(inplace)
 15 |   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 16 |   (3): ReLU(inplace)
 17 |   (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 18 |   (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 19 |   (6): ReLU(inplace)
 20 |   (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 21 |   (8): ReLU(inplace)
 22 |   (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 23 |   (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 24 |   (11): ReLU(inplace)
 25 |   (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 26 |   (13): ReLU(inplace)
 27 |   (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 28 |   (15): ReLU(inplace)
 29 |   (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 30 |   (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 31 |   (18): ReLU(inplace)
 32 |   (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 33 |   (20): ReLU(inplace)
 34 |   (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 35 |   (22): ReLU(inplace)
 36 |   (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 37 |   (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 38 |   (25): ReLU(inplace)
 39 |   (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 40 |   (27): ReLU(inplace)
 41 |   (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 42 |   (29): ReLU(inplace)
 43 |   (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 44 | )
 45 | '''
 46 | import torch
 47 | import torchvision
 48 | from torch import nn
 49 | 
 50 | # input size [256,256] or [512,512]
 51 | # HED based on vgg16
 52 | class HED_vgg16(nn.Module):
 53 |     def __init__(self, num_filters=32, pretrained=False, class_number=2):
 54 |         # module definitions (layers are assigned here, not called)
 55 |         super().__init__()
 56 |         encoder = torchvision.models.vgg16(pretrained=pretrained).features
 57 | 
 58 |         self.pool = nn.MaxPool2d(2, 2)
 59 | 
 60 |         self.conv1 = encoder[0:4]
 61 |         self.score1 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))  # 256*256
 62 | 
 63 |         self.conv2 = encoder[5:9]
 64 |         self.d_conv2 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))  # 128*128
 65 |         self.score2 = nn.UpsamplingBilinear2d(scale_factor=2)  # 256*256
 66 | 
 67 |         self.conv3 = encoder[10:16]
 68 |         self.d_conv3 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # 64*64
 69 |         self.score3 = nn.UpsamplingBilinear2d(scale_factor=4)  # 256*256
 70 | 
 71 |         self.conv4 = encoder[17:23]
 72 |         self.d_conv4 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True))  # 32*32
 73 |         self.score4 = nn.UpsamplingBilinear2d(scale_factor=8)  # 256*256
 74 | 
 75 |         self.conv5 = encoder[24:30]
 76 |         self.d_conv5 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True))  # 16*16
 77 |         self.score5 = nn.UpsamplingBilinear2d(scale_factor=16)  # 256*256
 78 | 
 79 |         self.score = nn.Conv2d(5, class_number, 1, 1)  # No relu
 80 | 
 81 |     def forward(self, x):
 82 |         # forward pass computing the fused score
 83 |         x = self.conv1(x)
 84 |         s1 = self.score1(x)
 85 |         x = self.pool(x)
 86 | 
 87 |         x = self.conv2(x)
 88 |         s_x = self.d_conv2(x)
 89 |         s2 = self.score2(s_x)
 90 |         x = self.pool(x)
 91 | 
 92 |         x = self.conv3(x)
 93 |         s_x = self.d_conv3(x)
 94 |         s3 = self.score3(s_x)
 95 |         x = self.pool(x)
 96 | 
 97 |         x = self.conv4(x)
 98 |         s_x = self.d_conv4(x)
 99 |         s4 = self.score4(s_x)
100 |         x = self.pool(x)
101 | 
102 |         x = self.conv5(x)
103 |         s_x = self.d_conv5(x)
104 |         s5 = self.score5(s_x)
105 | 
106 |         score = self.score(torch.cat([s1, s2, s3, s4, s5], dim=1))
107 | 
108 |         return score
109 | 
110 | 
111 | ''' apply softmax to the model output yourself when predicting '''
112 | # hed1 = HED_vgg16()
113 | # print(hed1)
114 | # print(hed1.state_dict().keys())
--------------------------------------------------------------------------------
/models/hed_series/hf_fcn_res.py:
--------------------------------------------------------------------------------
  1 | """
  2 | author: LeiLei
  3 | """
  4 | 
  5 | import torch
  6 | import torchvision
  7 | from torch import nn
  8 | import torch.nn.functional as F
  9 | 
 10 | 
 11 | # hf_fcn based on resnet34; better than hed
 12 | class HF_res34(nn.Module):
 13 |     def __init__(self, class_number=2, pretrained=True, num_filters=32):
 14 |         super().__init__()
 15 |         encoder = torchvision.models.resnet34(pretrained=pretrained)
 16 | 
 17 |         # start
 18 |         self.start = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)  # 128*128
 19 |         self.d_convs = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 20 | 
 21 |         self.pool = nn.MaxPool2d(3, 2, 1)
 22 | 
 23 |         # layer1
 24 |         self.layer10 = encoder.layer1[0]
 25 |         self.layer11 = encoder.layer1[1]
 26 |         self.layer12 = encoder.layer1[2]
 27 | 
 28 |         self.d_conv10 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 29 |         self.d_conv11 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 30 |         self.d_conv12 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 31 | 
 32 |         # layer2
 33 |         self.layer20 = encoder.layer2[0]
 34 |         self.layer21 = encoder.layer2[1]
 35 |         self.layer22 = encoder.layer2[2]
 36 |         self.layer23 = encoder.layer2[3]
 37 | 
 38 |         self.d_conv20 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 39 |         self.d_conv21 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 40 |         self.d_conv22 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 41 |         self.d_conv23 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 42 | 
 43 |         # layer3
 44 |         self.layer30 = encoder.layer3[0]
 45 |         self.layer31 = encoder.layer3[1]
 46 |         self.layer32 = encoder.layer3[2]
 47 |         self.layer33 = encoder.layer3[3]
 48 |         self.layer34 = encoder.layer3[4]
 49 |         self.layer35 = encoder.layer3[5]
 50 | 
 51 |         self.d_conv30 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 52 |         self.d_conv31 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 53 |         self.d_conv32 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 54 |         self.d_conv33 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 55 |         self.d_conv34 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True))  # no bn
 56 |         self.d_conv35 = nn.Sequential(nn.Conv2d(num_filters * 
8, 1, 1, 1), nn.ReLU(inplace=True)) # no bn 57 | 58 | # layer4 59 | self.layer40 = encoder.layer4[0] 60 | self.layer41 = encoder.layer4[1] 61 | self.layer42 = encoder.layer4[2] 62 | 63 | self.d_conv40 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) # no bn 64 | self.d_conv41 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) # no bn 65 | self.d_conv42 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) # no bn 66 | 67 | self.score = nn.Conv2d(17, class_number, 1, 1) # No relu loss_func has softmax 68 | 69 | def forward(self, x): 70 | input_size = x.shape[2:] # 可以获取 x的形状 那么采用 F.upsample 71 | 72 | x = self.start(x) 73 | s_x = self.d_convs(x) 74 | ss = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 75 | ''' why no relu after upsample ? because before it , d_conv has relu ''' 76 | x = self.pool(x) 77 | # layer1 78 | x = self.layer10(x) 79 | s_x = self.d_conv10(x) 80 | s10 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 81 | x = self.layer11(x) 82 | s_x = self.d_conv11(x) 83 | s11 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 84 | x = self.layer12(x) 85 | s_x = self.d_conv12(x) 86 | s12 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 87 | 88 | # layer2 89 | x = self.layer20(x) 90 | s_x = self.d_conv20(x) 91 | s20 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 92 | x = self.layer21(x) 93 | s_x = self.d_conv21(x) 94 | s21 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 95 | x = self.layer22(x) 96 | s_x = self.d_conv22(x) 97 | s22 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 98 | x = self.layer23(x) 99 | s_x = self.d_conv23(x) 100 | s23 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 101 | 102 | # layer3 103 | x = self.layer30(x) 104 | s_x = self.d_conv30(x) 105 | s30 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 106 | x = self.layer31(x) 107 | s_x = self.d_conv31(x) 108 | s31 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 109 | x = self.layer32(x) 110 | s_x = self.d_conv32(x) 111 | s32 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 112 | x = self.layer33(x) 113 | s_x = self.d_conv33(x) 114 | s33 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 115 | x = self.layer34(x) 116 | s_x = self.d_conv34(x) 117 | s34 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 118 | x = self.layer35(x) 119 | s_x = self.d_conv35(x) 120 | s35 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 121 | 122 | # layer4 123 | x = self.layer40(x) 124 | s_x = self.d_conv40(x) 125 | s40 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 126 | x = self.layer41(x) 127 | s_x = self.d_conv41(x) 128 | s41 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 129 | x = self.layer42(x) 130 | s_x = self.d_conv42(x) 131 | s42 = F.upsample(s_x, size=input_size, mode='bilinear', align_corners=True) 132 | 133 | cat = [ss, 134 | s10, s11, s12, 135 | s20, s21, s22, s23, 136 | s30, s31, s32, s33, s34, s35, 137 | s40, s41, s42] 138 | # score 139 | score = self.score(torch.cat(cat, dim=1)) 140 | 141 | return score 142 | 143 | 144 | def hf_res34(class_number=5, fine_tune=True): 145 | model = HF_res34(class_number=class_number, pretrained=fine_tune) 146 | return model 
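HF_res34 taps every residual block (17 side outputs in total, each reduced to one channel and upsampled to the input size before fusion), so it accepts arbitrary input sizes. A quick smoke test for the `hf_res34` factory above (illustrative; `fine_tune=False` avoids the weight download):

```python
import torch
model = hf_res34(class_number=5, fine_tune=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 5, 256, 256]) -- every side output is upsampled to the input size
```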
-------------------------------------------------------------------------------- /models/hed_series/hf_fcn_vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class HF_FCN(nn.Module): 8 | def __init__(self, class_number=2, pretrained=True, num_filters=32): 9 | super().__init__() 10 | encoder = torchvision.models.vgg16(pretrained=pretrained).features 11 | self.maxpool = encoder[4] 12 | 13 | self.conv1_1 = encoder[0:2] 14 | self.dconv1_1 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True)) 15 | self.conv1_2 = encoder[2:4] 16 | self.dconv1_2 = nn.Sequential(nn.Conv2d(num_filters * 2, 1, 1, 1), nn.ReLU(inplace=True)) 17 | # 1/2 18 | self.conv2_1 = encoder[5:7] 19 | self.dconv2_1 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True)) 20 | self.conv2_2 = encoder[7:9] 21 | self.dconv2_2 = nn.Sequential(nn.Conv2d(num_filters * 4, 1, 1, 1), nn.ReLU(inplace=True)) 22 | # 1/4 23 | self.conv3_1 = encoder[10:12] 24 | self.dconv3_1 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True)) 25 | self.conv3_2 = encoder[12:14] 26 | self.dconv3_2 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True)) 27 | self.conv3_3 = encoder[14:16] 28 | self.dconv3_3 = nn.Sequential(nn.Conv2d(num_filters * 8, 1, 1, 1), nn.ReLU(inplace=True)) 29 | # 1/8 30 | self.conv4_1 = encoder[17:19] 31 | self.dconv4_1 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) 32 | self.conv4_2 = encoder[19:21] 33 | self.dconv4_2 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) 34 | self.conv4_3 = encoder[21:23] 35 | self.dconv4_3 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) 36 | # 1/16 37 | self.conv5_1 = encoder[24:26] 38 | self.dconv5_1 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) 39 | self.conv5_2 = encoder[26:28] 40 | self.dconv5_2 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) 41 | self.conv5_3 = encoder[28:30] 42 | self.dconv5_3 = nn.Sequential(nn.Conv2d(num_filters * 16, 1, 1, 1), nn.ReLU(inplace=True)) 43 | 44 | self.score = nn.Conv2d(13, class_number, 1, 1) 45 | 46 | def forward(self, x): 47 | size = x.shape[2:] 48 | 49 | x = self.conv1_1(x) 50 | s1_1 = self.dconv1_1(x) 51 | x = self.conv1_2(x) 52 | s1_2 = self.dconv1_2(x) 53 | x = self.maxpool(x) 54 | 55 | x = self.conv2_1(x) 56 | s = self.dconv2_1(x) # first reduce out_channels then upsample 57 | s2_1 = F.upsample(s, size=size, mode='bilinear', align_corners=True) 58 | x = self.conv2_2(x) 59 | s = self.dconv2_2(x) 60 | s2_2 = F.upsample(s, size=size, mode='bilinear', align_corners=True) 61 | x = self.maxpool(x) 62 | 63 | x = self.conv3_1(x) 64 | s = self.dconv3_1(x) 65 | s3_1 = F.upsample(s, size=size, mode='bilinear', align_corners=True) 66 | x = self.conv3_2(x) 67 | s = self.dconv3_2(x) 68 | s3_2 = F.upsample(s, size=size, mode='bilinear', align_corners=True) 69 | x = self.conv3_3(x) 70 | s = self.dconv3_3(x) 71 | s3_3 = F.upsample(s, size=size, mode='bilinear', align_corners=True) 72 | x = self.maxpool(x) 73 | 74 | x = self.conv4_1(x) 75 | s = self.dconv4_1(x) 76 | s4_1 = F.upsample(s, size=size, mode='bilinear', align_corners=True) 77 | x = self.conv4_2(x) 78 | s = self.dconv4_2(x) 79 | s4_2 = F.upsample(s, size=size, mode='bilinear', align_corners=True) 80 | x = self.conv4_3(x) 81 | s = 
self.dconv4_3(x)
 82 |         s4_3 = F.upsample(s, size=size, mode='bilinear', align_corners=True)
 83 |         x = self.maxpool(x)
 84 | 
 85 |         x = self.conv5_1(x)
 86 |         s = self.dconv5_1(x)
 87 |         s5_1 = F.upsample(s, size=size, mode='bilinear', align_corners=True)
 88 |         x = self.conv5_2(x)
 89 |         s = self.dconv5_2(x)
 90 |         s5_2 = F.upsample(s, size=size, mode='bilinear', align_corners=True)
 91 |         x = self.conv5_3(x)
 92 |         s = self.dconv5_3(x)
 93 |         s5_3 = F.upsample(s, size=size, mode='bilinear', align_corners=True)
 94 | 
 95 |         score = self.score(torch.cat([s1_1, s1_2,
 96 |                                       s2_1, s2_2,
 97 |                                       s3_1, s3_2, s3_3,
 98 |                                       s4_1, s4_2, s4_3,
 99 |                                       s5_1, s5_2, s5_3], dim=1))  # no relu
100 |         return score
101 | 
102 | 
103 | def hf_fcn(class_number=5, fine_tune=True):
104 |     model = HF_FCN(class_number=class_number, pretrained=fine_tune)
105 |     return model
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
  1 | '''
  2 | get all kinds of models
  3 | all networks are obtained from here
  4 | '''
  5 | from .deeplab_v3_plus import Deeplab_v3_plus
  6 | 
  7 | 
  8 | def deeplab_v3_plus(class_number=5, fine_tune=True, backbone='resnet50'):
  9 |     model = Deeplab_v3_plus(class_number=class_number, fine_tune=fine_tune, backbone=backbone)
 10 |     return model
--------------------------------------------------------------------------------
/models/pspnet.py:
--------------------------------------------------------------------------------
  1 | """
  2 | @author: leilei
  3 | """
  4 | 
  5 | import torch
  6 | from torch import nn
  7 | import torchvision
  8 | import torch.nn.functional as F
  9 | 
 10 | '''
 11 | Note:
 12 |     PSPNet: the first conv7k_2s is replaced by conv3k_2s/conv3k_1s/conv3k_1s (3 layers)
 13 |     each downsample block: the first conv1k_1s becomes conv1k_2s; the second conv3k_2s becomes conv3k_1s
 14 |     layer1: no downsample
 15 |     layer2: downsample
 16 |     layer3: no downsample; the second conv3x3 of each block becomes atrous_conv3k_2r
 17 |     layer4: no downsample; the second conv3x3 of each block becomes atrous_conv3k_4r
 18 |     Note: ResNet uses no conv bias, so bias=False
 19 | 
 20 | '''
 21 | 
 22 | def conv3x3_bn_relu(in_planes, out_planes, stride=1):
 23 |     """3x3 convolution with padding, bn, relu and no bias"""
 24 |     return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1,
 25 |                                    bias=False),
 26 |                          nn.BatchNorm2d(out_planes),
 27 |                          nn.ReLU(inplace=True))
 28 | 
 29 | def conv3x3(in_planes, out_planes, stride=1):
 30 |     """3x3 convolution with padding and no bias"""
 31 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1,
 32 |                      bias=False)
 33 | 
 34 | 
 35 | def conv1x1(in_planes, out_planes, stride=1):
 36 |     """1x1 convolution and no bias; downsamples by 1/stride"""
 37 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 38 | 
 39 | 
 40 | def atrous_conv3x3(in_planes, out_planes, rate=1, padding=1, stride=1):
 41 |     """3x3 atrous convolution and no bias"""
 42 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
 43 |                      dilation=rate, bias=False)
 44 | 
 45 | 
 46 | class Bottleneck(nn.Module):
 47 |     expansion = 4
 48 | 
 49 |     def __init__(self, first_inplanes, inplanes, planes, rate=1, padding=1, stride=1, downsample=None):
 50 |         '''
 51 |         pspnet conv1_3's num_output is 128, not 64, so some code is modified
 52 |         first_inplanes: only layer1 differs: (conv1_3) 128 != (layer1-block1-conv1k_1s) 64
 53 |         '''
 54 |         super().__init__()
 55 |         self.conv1 = conv1x1(inplanes, planes, stride)  ####
 56 |         self.bn1 = nn.BatchNorm2d(planes)
 57 |         self.conv2 = atrous_conv3x3(planes, planes, rate, padding)  ####
 58 |         self.bn2 = nn.BatchNorm2d(planes)
 59 |         self.conv3 = conv1x1(planes, planes * self.expansion)
 60 |         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
 61 |         self.relu = nn.ReLU(inplace=True)
 62 |         self.downsample = downsample
 63 |         self.stride = stride
 64 | 
 65 |         # only the first block of layer1 has a different in_channel
 66 |         if (first_inplanes != inplanes) and (downsample is not None):
 67 |             self.conv1 = conv1x1(first_inplanes, planes, stride)
 68 |             self.downsample = nn.Sequential(conv1x1(first_inplanes, planes * self.expansion, stride),
 69 |                                             nn.BatchNorm2d(planes * self.expansion))
 70 | 
 71 |     def forward(self, x):
 72 |         residual = x
 73 | 
 74 |         out = self.conv1(x)
 75 |         out = self.bn1(out)
 76 |         out = self.relu(out)
 77 | 
 78 |         out = self.conv2(out)
 79 |         out = self.bn2(out)
 80 |         out = self.relu(out)
 81 | 
 82 |         out = self.conv3(out)
 83 |         out = self.bn3(out)
 84 | 
 85 |         if self.downsample is not None:
 86 |             residual = self.downsample(x)
 87 | 
 88 |         out += residual
 89 |         out = self.relu(out)
 90 | 
 91 |         return out
 92 | 
 93 | 
 94 | class SppBlock(nn.Module):
 95 |     # no bias
 96 |     def __init__(self, level, in_channel=2048, out_channel=512):
 97 |         super().__init__()
 98 |         self.level = level
 99 |         self.convblock = nn.Sequential(conv1x1(in_channel, out_channel),
100 |                                        nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
101 | 
102 |     def forward(self, x):
103 |         size = x.shape[2:]
104 |         x = F.adaptive_avg_pool2d(x, output_size=(self.level, self.level))  # average pool
105 |         x = self.convblock(x)
106 |         x = F.upsample(x, size=size, mode='bilinear', align_corners=True)
107 | 
108 |         return x
109 | 
110 | 
111 | class SppBlock1(nn.Module):
112 |     # no bias; k=10/20/30/60
113 |     def __init__(self, level, k, s, in_channel=2048, out_channel=512):
114 |         super().__init__()
115 |         self.level = level
116 |         self.avgpool = nn.AvgPool2d(k, s)
117 |         self.convblock = nn.Sequential(conv1x1(in_channel, out_channel),
118 |                                        nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
119 | 
120 |     def forward(self, x):
121 |         size = x.shape[2:]
122 |         x = self.avgpool(x)
123 |         x = self.convblock(x)
124 |         x = F.upsample(x, size=size, mode='bilinear', align_corners=True)
125 | 
126 |         return x
127 | 
128 | 
129 | class SPP(nn.Module):
130 |     def __init__(self, in_channel=2048):
131 |         super().__init__()
132 |         self.spp1 = SppBlock(level=1, in_channel=in_channel)
133 |         self.spp2 = SppBlock(level=2, in_channel=in_channel)
134 |         self.spp3 = SppBlock(level=3, in_channel=in_channel)
135 |         self.spp6 = SppBlock(level=6, in_channel=in_channel)
136 | 
137 |     def forward(self, x):
138 |         # x has 2048 channels
139 |         x1 = self.spp1(x)
140 |         x2 = self.spp2(x)
141 |         x3 = self.spp3(x)
142 |         x6 = self.spp6(x)
143 |         out = torch.cat([x, x1, x2, x3, x6], dim=1)
144 | 
145 |         return out
146 | 
147 | 
148 | class PSPNet(nn.Module):
149 |     def __init__(self, block, layers, class_number, dropout_rate=0.2, in_channel=3):
150 |         super().__init__()
151 |         self.inplanes = 64
152 |         self.conv1_1 = conv3x3_bn_relu(in_channel, 64, stride=2)
153 |         self.conv1_2 = conv3x3_bn_relu(64, 64)
154 |         self.conv1_3 = conv3x3_bn_relu(64, 128)
155 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
156 | 
157 |         self.layer1 = self._make_layer(block, 128, 64, layers[0])  # 64 / 256
158 |         self.layer2 = self._make_layer(block, 256, 128, layers[1], stride=2)  # 128 / 512
159 |         self.layer3 = self._make_layer(block, 512, 256, layers[2], rate=2, padding=2)  # 256 / 1024
160 |         self.layer4 = self._make_layer(block, 1024, 512, layers[3], rate=4, padding=4)  # 512 / 2048
161 | 
162 |         self.spp = SPP(in_channel=2048)
163 | 
164 |         self.conv5_4 = conv3x3_bn_relu(2048 + 512 * 4, 512)  ## if you change in_channel, adjust this accordingly ##
165 | 
166 |         self.dropout = nn.Dropout2d(p=dropout_rate)
167 |         self.conv6 = nn.Conv2d(512, class_number, 1, 1)
168 | 
169 |         ''' init weights '''
170 |         print('## init weight ##')
171 |         for m in self.modules():
172 |             if isinstance(m, nn.Conv2d):
173 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
174 |                 if m.bias is not None:
175 |                     nn.init.constant_(m.bias, 0)
176 |             if isinstance(m, nn.BatchNorm2d):
177 |                 nn.init.constant_(m.weight, 1)
178 |                 nn.init.constant_(m.bias, 0)
179 |         # no convtranspose / linear layers to init
180 | 
181 |     def forward(self, x):
182 |         size = x.shape[2:]
183 |         x = self.conv1_1(x)
184 |         x = self.conv1_2(x)
185 |         x = self.conv1_3(x)
186 |         x = self.maxpool(x)
187 |         x = self.layer1(x)
188 |         x = self.layer2(x)
189 |         x = self.layer3(x)
190 |         x = self.layer4(x)
191 |         x = self.spp(x)
192 |         x = self.conv5_4(x)
193 |         x = self.dropout(x)
194 |         x = self.conv6(x)
195 |         x = F.upsample(x, size, mode='bilinear', align_corners=True)
196 | 
197 |         return x
198 | 
199 |     '''first_inplanes, inplanes, planes, rate=1, padding=1, stride=1, downsample=None'''
200 | 
201 |     def _make_layer(self, block, first_inplanes, planes, blocks, rate=1, padding=1, stride=1):
202 |         downsample = None
203 |         if stride != 1 or self.inplanes != planes * block.expansion:
204 |             downsample = nn.Sequential(
205 |                 conv1x1(self.inplanes, planes * block.expansion, stride),  # downsample with the same stride
206 |                 nn.BatchNorm2d(planes * block.expansion))
207 | 
208 |         layers = []
209 |         layers.append(block(first_inplanes, self.inplanes, planes, rate, padding, stride, downsample))
210 |         self.inplanes = planes * block.expansion
211 |         for _ in range(1, blocks):
212 |             layers.append(block(self.inplanes, self.inplanes, planes, rate, padding))
213 | 
214 |         return nn.Sequential(*layers)
215 | 
216 | 
217 | def pspnet(class_number, dropout_rate=0.2):
218 |     model = PSPNet(Bottleneck, layers=[3, 4, 6, 3], class_number=class_number, dropout_rate=dropout_rate)
219 |     return model
220 | 
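This PSPNet downsamples by 8 overall (stride-2 stem, maxpool, stride-2 layer2; layer3/layer4 use dilation instead of striding) and upsamples back at the end, so output spatial size matches the input. A quick shape check for the factory above (illustrative):

```python
import torch
model = pspnet(class_number=5, dropout_rate=0.2).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 5, 256, 256])
```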
--------------------------------------------------------------------------------
/models/spp.py:
--------------------------------------------------------------------------------
  1 | '''
  2 | sppnet based on vgg16
  3 | input size is arbitrary; but batch size must be set to 1, so 'batch_size' is effectively unused.
  4 | '''
  5 | 
  6 | import torch
  7 | import torchvision
  8 | from torch import nn
  9 | import torch.nn.functional as F
 10 | 
 11 | class SppNet(nn.Module):
 12 |     def __init__(self, batch_size=1, out_pool_size=[1, 2, 4], class_number=2):
 13 |         super().__init__()
 14 |         # reuse an existing network, e.g. vgg16
 15 |         vgg = torchvision.models.vgg16(pretrained=False).features[:-1]
 16 |         self.out_pool_size = out_pool_size
 17 |         self.batch_size = batch_size
 18 |         # encoder
 19 |         self.encoder = vgg
 20 |         # spp: registered as submodules, so it shows up as (spp) when printing the network
 21 |         self.spp = self.make_spp(batch_size=batch_size, out_pool_size=out_pool_size)
 22 |         # FC
 23 |         sum0 = 0
 24 |         for i in out_pool_size:
 25 |             sum0 += i ** 2
 26 |         self.fc = nn.Sequential(nn.Linear(512 * sum0, 1024), nn.ReLU(inplace=True))
 27 |         self.score = nn.Linear(1024, class_number)
 28 | 
 29 |     def make_spp(self, batch_size=1, out_pool_size=[1, 2, 4]):
 30 |         func = []
 31 |         for i in range(len(out_pool_size)):
 32 |             func.append(nn.AdaptiveAvgPool2d(output_size=(out_pool_size[i], out_pool_size[i])))
 33 |         return nn.ModuleList(func)  # ModuleList, so the pooling layers are registered as submodules
 34 | 
 35 |     def forward(self, x):
 36 |         assert x.shape[0] == 1, 'batch size needs to be 1'
 37 |         encoder = self.encoder(x)
 38 |         spp = []
 39 |         for i in range(len(self.out_pool_size)):
 40 |             spp.append(self.spp[i](encoder).view(self.batch_size, -1))
 41 |         fc = self.fc(torch.cat(spp, dim=1))
 42 |         score = self.score(fc)
 43 |         return score
 44 | 
 45 | 
 46 | ''' or, equivalently '''
 47 | 
 48 | class SppNet1(nn.Module):
 49 |     def __init__(self, batch_size=1, out_pool_size=[1, 2, 4], class_number=2):
 50 |         super().__init__()
 51 |         # reuse an existing network, e.g. vgg16
 52 |         vgg = torchvision.models.vgg16(pretrained=False).features[:-1]
 53 |         self.out_pool_size = out_pool_size
 54 |         self.batch_size = batch_size
 55 |         # encoder
 56 |         self.encoder = vgg
 57 |         # FC
 58 |         sum0 = 0
 59 |         for i in out_pool_size:
 60 |             sum0 += i ** 2
 61 |         self.fc = nn.Sequential(nn.Linear(512 * sum0, 1024), nn.ReLU(inplace=True))
 62 |         self.score = nn.Linear(1024, class_number)
 63 | 
 64 |     def forward(self, x):
 65 |         assert x.shape[0] == 1, 'batch size needs to be 1'
 66 |         encoder = self.encoder(x)
 67 |         spp = []
 68 |         for i in range(len(self.out_pool_size)):
 69 |             spp.append(F.adaptive_avg_pool2d(encoder, output_size=(self.out_pool_size[i], self.out_pool_size[i])).view(
 70 |                 self.batch_size, -1))
 71 |         fc = self.fc(torch.cat(spp, dim=1))
 72 |         score = self.score(fc)
 73 |         return score
 74 | 
 75 | 
 76 | # spp = SppNet(class_number=2)
 77 | # print(spp)
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
  1 | '''
  2 | code's author is leilei
  3 | '''
  4 | 
  5 | import torch
  6 | import torchvision
  7 | import numpy as np
  8 | from torch import nn
  9 | import torch.nn.functional as F
 10 | 
 11 | '''
 12 | U_Net: the original, not based on vgg11 or vgg16
 13 | only resnet uses bias=False, so remember bias=False if you write resnet yourself
 14 | batch_norm: in pytorch, is_training is toggled with model.eval(); in tf it is a placeholder
 15 | '''
 16 | 
 17 | def conv1x1_bn_relu(in_planes, out_planes, stride=1):
 18 |     return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride),
 19 |                          nn.BatchNorm2d(out_planes),
 20 |                          nn.ReLU(inplace=True))
 21 | 
 22 | def conv3x3_bn_relu(in_planes, out_planes, stride=1):
 23 |     return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1),
 24 |                          nn.BatchNorm2d(out_planes),
 25 |                          nn.ReLU(inplace=True))
 26 | 
 27 | def upsample(in_features, out_features):
 28 | 
shape = out_features.shape[2:] # h w 29 | return F.upsample(in_features, size=shape, mode='bilinear', align_corners=True) 30 | 31 | def concat(in_features1, in_features2): 32 | return torch.cat([in_features1, in_features2], dim=1) 33 | 34 | class U_Net(nn.Module): 35 | def __init__(self, class_number=5, in_channels=3): 36 | super().__init__() 37 | # encoder 38 | self.conv1_1 = conv3x3_bn_relu(in_channels, 64) 39 | self.conv1_2 = conv3x3_bn_relu(64, 64) 40 | 41 | self.maxpool = nn.MaxPool2d(2, 2) # only one for all 42 | 43 | self.conv2_1 = conv3x3_bn_relu(64, 128) 44 | self.conv2_2 = conv3x3_bn_relu(128, 128) 45 | 46 | self.conv3_1 = conv3x3_bn_relu(128, 256) 47 | self.conv3_2 = conv3x3_bn_relu(256, 256) 48 | 49 | self.conv4_1 = conv3x3_bn_relu(256, 512) 50 | self.conv4_2 = conv3x3_bn_relu(512, 512) 51 | 52 | self.conv5_1 = conv3x3_bn_relu(512, 1024) 53 | self.conv5_2 = conv3x3_bn_relu(1024, 1024) 54 | 55 | # decoder 56 | self.conv6 = conv3x3_bn_relu(1024, 512) 57 | self.conv6_1 = conv3x3_bn_relu(1024, 512) ## 58 | self.conv6_2 = conv3x3_bn_relu(512, 512) 59 | 60 | self.conv7 = conv3x3_bn_relu(512, 256) 61 | self.conv7_1 = conv3x3_bn_relu(512, 256) ## 62 | self.conv7_2 = conv3x3_bn_relu(256, 256) 63 | 64 | self.conv8 = conv3x3_bn_relu(256, 128) 65 | self.conv8_1 = conv3x3_bn_relu(256, 128) ## 66 | self.conv8_2 = conv3x3_bn_relu(128, 128) 67 | 68 | self.conv9 = conv3x3_bn_relu(128, 64) 69 | self.conv9_1 = conv3x3_bn_relu(128, 64) ## 70 | self.conv9_2 = conv3x3_bn_relu(64, 64) 71 | 72 | self.score = nn.Conv2d(64, class_number, 1, 1) 73 | 74 | def forward(self, x): 75 | # encoder 76 | conv1_1 = self.conv1_1(x) 77 | conv1_2 = self.conv1_2(conv1_1) 78 | pool1 = self.maxpool(conv1_2) 79 | 80 | conv2_1 = self.conv2_1(pool1) 81 | conv2_2 = self.conv2_2(conv2_1) 82 | pool2 = self.maxpool(conv2_2) 83 | 84 | conv3_1 = self.conv3_1(pool2) 85 | conv3_2 = self.conv3_2(conv3_1) 86 | pool3 = self.maxpool(conv3_2) 87 | 88 | conv4_1 = self.conv4_1(pool3) 89 | conv4_2 = self.conv4_2(conv4_1) 90 | pool4 = self.maxpool(conv4_2) 91 | 92 | conv5_1 = self.conv5_1(pool4) 93 | conv5_2 = self.conv5_2(conv5_1) 94 | 95 | # decoder 96 | up6 = upsample(conv5_2, conv4_2) 97 | conv6 = self.conv6(up6) 98 | merge6 = concat(conv6, conv4_2) 99 | conv6_1 = self.conv6_1(merge6) 100 | conv6_2 = self.conv6_2(conv6_1) 101 | 102 | up7 = upsample(conv6_2, conv3_2) 103 | conv7 = self.conv7(up7) 104 | merge7 = concat(conv7, conv3_2) 105 | conv7_1 = self.conv7_1(merge7) 106 | conv7_2 = self.conv7_2(conv7_1) 107 | 108 | up8 = upsample(conv7_2, conv2_2) 109 | conv8 = self.conv8(up8) 110 | merge8 = concat(conv8, conv2_2) 111 | conv8_1 = self.conv8_1(merge8) 112 | conv8_2 = self.conv8_2(conv8_1) 113 | 114 | up9 = upsample(conv8_2, conv1_2) 115 | conv9 = self.conv9(up9) 116 | merge9 = concat(conv9, conv1_2) 117 | conv9_1 = self.conv9_1(merge9) 118 | conv9_2 = self.conv9_2(conv9_1) 119 | 120 | score = self.score(conv9_2) 121 | 122 | return score 123 | 124 | 125 | def unet_orig(class_number, in_channels=3): 126 | model = U_Net(class_number, in_channels) 127 | return model -------------------------------------------------------------------------------- /readmes/data_aug.md: -------------------------------------------------------------------------------- 1 | ### Dataset Augumentation 2 | 3 | #### data augs 4 | - [x] random flip 5 | - [x] random rotate 6 | - [x] random crop 7 | - [x] random noise 8 | - [x] hue-brightness-contrast-saturation 9 | - [x] zoom(in out) 10 | - [ ] ~~copy-paste?~~ 11 | - [ ] mosaic? 
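The PIL and GDAL implementations are linked just below. As a flavor of the approach (a minimal sketch, not the repo's exact code), every geometric transform must hit image and label in lockstep so the pixels stay aligned:

```python
import random
from PIL import Image

def random_flip_pair(image: Image.Image, label: Image.Image):
    # flip image and label together; labels would use NEAREST for any resampling op
    if random.random() > 0.5:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
        label = label.transpose(Image.FLIP_LEFT_RIGHT)
    return image, label
```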
 12 | 
 13 | #### data aug use PIL (3 channel RGB)
 14 | + [dataset-PIL-Augmentations](../utils/aug_PIL.py)
 15 | 
 16 | #### data aug use GDAL (>3 channel)
 17 | + [dataset-GDAL-cv2-Augmentations](../utils/aug_GDAL.py)
 18 | 
--------------------------------------------------------------------------------
/readmes/main_modify.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/readmes/main_modify.jpg
--------------------------------------------------------------------------------
/readmes/param_modify.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengyanlei/segmentation_pytorch/5681dc088d0f9bbde461ee018ad7c36b6b16733c/readmes/param_modify.jpg
--------------------------------------------------------------------------------
/readmes/train_cusom.md:
--------------------------------------------------------------------------------
  1 | ### Getting Started
  2 | 
  3 | #### Train
  4 | + [Note](https://github.com/gengyanlei/segmentation_pytorch#note): 0 is the background, and the object categories start from 1!
  5 | + You need to prepare the data in the format of [Dataset Details](https://github.com/gengyanlei/segmentation_pytorch/blob/master/readmes/train_cusom.md#dataset-details).
  6 | + Modify the configuration file parameters [parameter.yaml](../configs/parameter.yaml).
  7 | + Modify [main.py](../main.py)'s args and the model/network code.
  8 | 
10 | Figure Notes (click to expand) 11 |

12 |

13 |
14 | 15 | #### Test 16 | + Just to see [trainval.py](../utils/trainval.py) 17 | 18 | #### Dataset Details 19 | ``` 20 | root: 21 | images: 22 | labels: 23 | train.txt: 24 | /home/dataset/seg/images/train/aaa.jpg 25 | /home/dataset/seg/images/train/bbb.jpg 26 | test.txt: 27 | /home/dataset/seg/images/test/ccc.jpg 28 | 29 | how to match images and labels? 30 | '/home/dataset/seg/images/train/aaa.jpg'.replace('images', 'labels') 31 | or 32 | '/home/dataset/seg/labels/train/aaa.jpg'.replace('.jpg', '.png') 33 | 34 | data enhancement: 35 | random flip, rotate, crop, noise, 36 | hue-brightness-contrast-saturation, zoom(in out), copy-paste?, mosaic? 37 | ``` -------------------------------------------------------------------------------- /utils/aug_GDAL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import math 4 | import random 5 | import torch 6 | import numpy as np 7 | # from osgeo import gdal 8 | from torchvision.transforms import transforms 9 | import torchvision.transforms.functional as tf 10 | from utils.util import RandomResizedCrop_get_params 11 | 12 | ''' 13 | author is leilei 14 | 语义分割数据增强时,需将图像和标签图同时操作,对于旋转,偏移等操作,会引入黑边(均为0值), 15 | 将引入的黑边 视为1类,标签值默认为0,真实标签从1开始。 16 | 图像采用BILINEAR,标签图采用NEAREST 17 | 采用GDAL库,进行读取任意通道(尤其是>=4通道的影像),并结合cv2进行处理 18 | 由于GDAL数据增强操作很麻烦,虽然有重采样等操作,但是接口文档不太友好,而且cv2对于float32也支持仿射变换 19 | ''' 20 | 21 | class Gdal_Read: 22 | # 采用GDAL读取任意通道的影像(图像) 23 | def __init__(self): 24 | pass 25 | def read_img(self, filename, only_data=True): 26 | dataset = gdal.Open(filename) # 打开文件 27 | 28 | im_width = dataset.RasterXSize # 栅格矩阵的列数 29 | im_height = dataset.RasterYSize # 栅格矩阵的行数 30 | 31 | im_geotrans = dataset.GetGeoTransform() # 仿射矩阵 32 | im_proj = dataset.GetProjection() # 地图投影信息 33 | im_data = dataset.ReadAsArray(0, 0, im_width, im_height) # 将数据写成数组,对应栅格矩阵 [channels, height, width] RGB的顺序 34 | 35 | im_data = im_data.transpose((1, 2, 0)) # [H,W,C] # RGB顺序 36 | del dataset 37 | 38 | if only_data: 39 | return im_data 40 | return im_proj, im_geotrans, im_data 41 | 42 | # 写文件,以写成tif为例 43 | def write_img(self, filename, im_proj, im_geotrans, im_data): 44 | # gdal数据类型包括 45 | # gdal.GDT_Byte, 46 | # gdal .GDT_UInt16, gdal.GDT_Int16, gdal.GDT_UInt32, gdal.GDT_Int32, 47 | # gdal.GDT_Float32, gdal.GDT_Float64 48 | # cv2 对于int32-64报错,但是对于float32可以 49 | 50 | # 判断栅格数据的数据类型 51 | if 'int8' in im_data.dtype.name: 52 | datatype = gdal.GDT_Byte 53 | elif 'int16' in im_data.dtype.name: 54 | datatype = gdal.GDT_UInt16 55 | else: 56 | datatype = gdal.GDT_Float32 57 | 58 | # 判读数组维数 59 | if len(im_data.shape) == 3: 60 | im_bands, im_height, im_width = im_data.shape 61 | else: 62 | im_bands, (im_height, im_width) = 1, im_data.shape 63 | 64 | # 创建文件 65 | driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间 66 | dataset = driver.Create(filename, im_width, im_height, im_bands, datatype) 67 | 68 | dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 69 | dataset.SetProjection(im_proj) # 写入投影 70 | 71 | if im_bands == 1: 72 | dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据 73 | else: 74 | for i in range(im_bands): 75 | dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) 76 | 77 | del dataset 78 | return 79 | 80 | class Augmentations_GDAL: 81 | def __init__(self, input_hw=(256, 256)): 82 | self.input_hw = input_hw 83 | self.image_fill = 0 # image fill=0,0对应黑边 84 | self.label_fill = 0 # label fill=0,0对应黑边 85 | ''' 86 | 以下操作,均为单操作,不可组合!,所有的操作输出均需要resize至input_hw 87 | 且 image为多通道,label为1通道 88 | 
--------------------------------------------------------------------------------
/utils/aug_GDAL.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import math
4 | import random
5 | import torch
6 | import numpy as np
7 | from osgeo import gdal  # required by Gdal_Read below
8 | from torchvision.transforms import transforms
9 | import torchvision.transforms.functional as tf
10 | from utils.util import RandomResizedCrop_get_params
11 | 
12 | '''
13 | author is leilei
14 | For segmentation augmentation the image and the label map must be transformed together; rotation, translation, etc. introduce black borders (all 0 values),
15 | and the introduced border is treated as its own class: label value 0 by default, real labels start from 1.
16 | Images use BILINEAR, label maps use NEAREST.
17 | GDAL is used to read imagery with an arbitrary number of bands (especially >=4-band imagery), combined with cv2 for processing,
18 | because GDAL's own augmentation (resampling etc.) is cumbersome and its API docs are unfriendly, while cv2 also supports affine transforms on float32.
19 | '''
20 | 
21 | class Gdal_Read:
22 |     # read imagery with an arbitrary number of bands via GDAL
23 |     def __init__(self):
24 |         pass
25 |     def read_img(self, filename, only_data=True):
26 |         dataset = gdal.Open(filename)  # open the raster file
27 | 
28 |         im_width = dataset.RasterXSize  # number of raster columns
29 |         im_height = dataset.RasterYSize  # number of raster rows
30 | 
31 |         im_geotrans = dataset.GetGeoTransform()  # geotransform (affine) parameters
32 |         im_proj = dataset.GetProjection()  # map projection info
33 |         im_data = dataset.ReadAsArray(0, 0, im_width, im_height)  # read the raster into an array [channels, height, width], band order as stored (e.g. RGB)
34 | 
35 |         im_data = im_data.transpose((1, 2, 0))  # [H,W,C]
36 |         del dataset
37 | 
38 |         if only_data:
39 |             return im_data
40 |         return im_proj, im_geotrans, im_data
41 | 
42 |     # write a raster file, GeoTiff as the example
43 |     def write_img(self, filename, im_proj, im_geotrans, im_data):
44 |         # gdal data types include
45 |         # gdal.GDT_Byte,
46 |         # gdal.GDT_UInt16, gdal.GDT_Int16, gdal.GDT_UInt32, gdal.GDT_Int32,
47 |         # gdal.GDT_Float32, gdal.GDT_Float64
48 |         # cv2 raises on int32/int64 but handles float32
49 | 
50 |         # infer the GDAL data type from the array dtype
51 |         if 'int8' in im_data.dtype.name:
52 |             datatype = gdal.GDT_Byte
53 |         elif 'int16' in im_data.dtype.name:
54 |             datatype = gdal.GDT_UInt16
55 |         else:
56 |             datatype = gdal.GDT_Float32
57 | 
58 |         # determine the array dimensionality; multi-band data is expected band-first [C,H,W]
59 |         if len(im_data.shape) == 3:
60 |             im_bands, im_height, im_width = im_data.shape
61 |         else:
62 |             im_bands, (im_height, im_width) = 1, im_data.shape
63 | 
64 |         # create the output file
65 |         driver = gdal.GetDriverByName("GTiff")  # the data type is required so GDAL can allocate the needed storage
66 |         dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
67 | 
68 |         dataset.SetGeoTransform(im_geotrans)  # write geotransform parameters
69 |         dataset.SetProjection(im_proj)  # write projection
70 | 
71 |         if im_bands == 1:
72 |             dataset.GetRasterBand(1).WriteArray(im_data)  # write the array data
73 |         else:
74 |             for i in range(im_bands):
75 |                 dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
76 | 
77 |         del dataset
78 |         return
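# --- Editor's sketch: a minimal Gdal_Read round trip (illustrative; 'a.tif' and
# --- 'a_copy.tif' are hypothetical paths). read_img returns [H,W,C], while
# --- write_img expects multi-band data band-first, hence the transpose back.
#
#   reader = Gdal_Read()
#   im_proj, im_geotrans, im_data = reader.read_img('a.tif', only_data=False)
#   reader.write_img('a_copy.tif', im_proj, im_geotrans, im_data.transpose((2, 0, 1)))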
79 | 
80 | class Augmentations_GDAL:
81 |     def __init__(self, input_hw=(256, 256)):
82 |         self.input_hw = input_hw
83 |         self.image_fill = 0  # image fill=0, 0 is the black-border value
84 |         self.label_fill = 0  # label fill=0, 0 is the black-border class
85 |     '''
86 |     Each of the following ops is meant to be applied alone, never composed! Every op resizes its output to input_hw,
87 |     image may have many channels, label has 1 channel.
88 |     Data is read via GDAL but augmented with cv2; cv2 supports int16 and float32, not int32.
89 |     image:[HWC], label:[HW]
90 |     '''
91 |     # TODO
92 |     def random_rotate(self, image, label, angle=None):
93 |         '''
94 |         :param image: GDAL ReadAsArray (ndarray), uint8 or int16 or float32
95 |         :param label: cv2.imread uint8
96 |         :param angle: None, list-float, tuple-float
97 |         :return: ndarray, ndarray
98 |         '''
99 |         if angle is None:
100 |             angle = transforms.RandomRotation.get_params([-180, 180])
101 |         elif isinstance(angle, list) or isinstance(angle, tuple):
102 |             angle = random.choice(angle)
103 | 
104 |         h, w = label.shape[:2]
105 |         matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)  # rotate about the center, scale unchanged
106 |         image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
107 |                                borderValue=self.image_fill)
108 |         label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT,
109 |                                borderValue=self.label_fill)
110 | 
111 |         # resize
112 |         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
113 |         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
114 | 
115 |         return image, label
116 | 
117 |     def random_flip(self, image, label):  # vertical and horizontal flips, sampled independently
118 |         if random.random() > 0.5:
119 |             image = cv2.flip(image, 0)
120 |             label = cv2.flip(label, 0)
121 |         if random.random() < 0.5:
122 |             image = cv2.flip(image, 1)
123 |             label = cv2.flip(label, 1)
124 | 
125 |         # resize
126 |         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
127 |         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
128 | 
129 |         return image, label
130 | 
131 |     # zoom in
132 |     def random_resize_crop(self, image, label, scale=(0.3, 1.0), ratio=(1, 1)):
133 |         i, j, h, w = RandomResizedCrop_get_params(image, scale=scale, ratio=ratio)  # torchvision's get_params needs PIL input, hence the custom helper in utils
134 |         image = image[i:i+h, j:j+w]
135 |         label = label[i:i+h, j:j+w]
136 | 
137 |         # resize
138 |         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
139 |         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
140 | 
141 |         return image, label
142 | 
143 |     # zoom out
144 |     def random_resize_minify(self, image, label, scale=(0.3, 1.0)):
145 |         in_hw = label.shape[:2]
146 |         factor = transforms.RandomRotation.get_params(scale)  # just samples a uniform factor from scale; isotropic here, could be made anisotropic
147 |         size = (int(in_hw[1] * factor), int(in_hw[0] * factor))  # (w,h)
148 | 
149 |         image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
150 |         label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
151 | 
152 |         # pad; size is (w,h), so pad height against size[1] and width against size[0]
153 |         top_bottom = (self.input_hw[0] - size[1])
154 |         left_right = (self.input_hw[1] - size[0])
155 | 
156 |         top = top_bottom >> 1 if top_bottom > 0 else 0
157 |         bottom = top_bottom - top if top_bottom > 0 else 0
158 |         left = left_right >> 1 if left_right > 0 else 0
159 |         right = left_right - left if left_right > 0 else 0
160 | 
161 |         image = cv2.copyMakeBorder(image, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_CONSTANT, value=self.image_fill)
162 |         label = cv2.copyMakeBorder(label, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_CONSTANT, value=self.label_fill)
163 | 
164 |         # resize
165 |         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
166 |         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
167 | 
168 |         return image, label
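# Editor's note, a worked example for the padding above: with input_hw=(256, 256)
# and factor=0.5 on a 256x256 pair, size=(128, 128), so top_bottom=left_right=128
# and each side receives 64 px of fill (x >> 1 halves), before the final resize.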
169 | 
170 |     # core func
171 |     def random_affine(self, image, label, perspective=0.0, degrees=0.373, scale=0.898, shear=0.602, translate=0.245):
172 |         # random affine: translation, rotation, scaling, shear (and optional perspective) combined
173 |         height, width = image.shape[:2]
174 | 
175 |         # Center refer yolov5's mosaic aug
176 |         C = np.eye(3)
177 |         C[0, 2] = -image.shape[1] / 2  # x translation (pixels)
178 |         C[1, 2] = -image.shape[0] / 2  # y translation (pixels)
179 | 
180 |         # Perspective
181 |         P = np.eye(3)
182 |         P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
183 |         P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)
184 | 
185 |         # Rotation and Scale
186 |         R = np.eye(3)
187 |         a = random.uniform(-degrees, degrees) / math.pi * 180  # degrees is given in radians; convert to degrees
188 |         # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
189 |         s = random.uniform(1 - scale, 1 + scale)
190 |         # s = 2 ** random.uniform(-scale, scale)
191 |         R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
192 | 
193 |         # Shear
194 |         S = np.eye(3)
195 |         S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
196 |         S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)
197 | 
198 |         # Translation: shift to the origin first (C), apply the random ops, then move the center back to roughly its original position
199 |         T = np.eye(3)
200 |         T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
201 |         T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)
202 | 
203 |         # Combined rotation matrix
204 |         M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
205 |         if (M != np.eye(3)).any():  # image changed
206 |             image = cv2.warpAffine(image, M[:2], dsize=self.input_hw[::-1], borderMode=cv2.BORDER_CONSTANT, borderValue=self.image_fill)
207 |             label = cv2.warpAffine(label, M[:2], dsize=self.input_hw[::-1], flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=self.label_fill)
208 |         else:
209 |             # no transform was sampled (very unlikely): just resize
210 |             image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
211 |             label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
212 | 
213 |         return image, label
214 | 
215 |     def random_color_jitter(self, image, label, brightness=0.4, contrast=0.3, saturation=0.2, hue=0.2):
216 |         # random color augmentation
217 |         # TODO how to color-jitter multi-band (>=4 channel) imagery? see the commented sketch below
218 | 
219 |         return image, label
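#     --- Editor's sketch for the TODO above, kept commented out so the method
#     --- scan in Transforms_GDAL does not register it (the name and defaults are
#     --- illustrative). One simple option for >=4-band imagery is a random
#     --- per-band gain/bias, since HSV-style jitter only makes sense for RGB:
#
#     def random_band_jitter(self, image, label, gain=0.2, bias=0.05):
#         image = image.astype(np.float32)
#         for c in range(image.shape[2]):  # jitter each band independently
#             g = 1.0 + random.uniform(-gain, gain)                    # multiplicative gain
#             b = random.uniform(-bias, bias) * image[..., c].mean()   # additive bias
#             image[..., c] = image[..., c] * g + b
#         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
#         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
#         return image, label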
220 | 
221 |     # gaussian noise
222 |     def random_noise(self, image, label, noise_sigma=10):
223 |         in_hw = label.shape[:2] + (1,)  # needs the same number of dims as image so it can broadcast
224 |         noise = (np.random.randn(*in_hw) * noise_sigma).astype(image.dtype)  # +-; note: negative values wrap around for uint8 images
225 |         image += noise  # broadcast
226 | 
227 |         # resize
228 |         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
229 |         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
230 | 
231 |         return image, label
232 | 
233 |     def random_blur(self, image, label, kernel_size=(5,5)):
234 |         assert len(kernel_size) == 2, "kernel size must be tuple and len()=2"
235 |         image = cv2.GaussianBlur(image, ksize=kernel_size, sigmaX=0)
236 | 
237 |         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
238 |         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
239 | 
240 |         return image, label
241 | 
242 |     # def random_mosaic(self, image4, label4):
243 |     #     # TODO mosaic data-aug
244 |     #     # image9 label9
245 |     #     pass
246 |     #     return
247 | 
248 | 
249 | class Transforms_GDAL(object):
250 |     def __init__(self, input_hw=(256, 256)):
251 |         self.aug_gdal = Augmentations_GDAL(input_hw)
252 |         self.aug_funcs = [a for a in self.aug_gdal.__dir__() if not a.startswith('_') and a not in self.aug_gdal.__dict__]  # public methods only; __dict__ holds just the attributes
253 |         print(self.aug_funcs)
254 | 
255 |     def __call__(self, image, label):
256 |         '''
257 |         :param image: ndarray [H,W,C], uint8/int16/float32
258 |         :param label: ndarray [H,W], uint8
259 |         :return: ndarray, ndarray
260 |         '''
261 |         aug_name = random.choice(self.aug_funcs)
262 |         # aug_name = 'random_rotate' #'random_flip' #'random_blur' #'random_noise' #'random_affine' #'random_resize_minify' #'random_resize_crop'
263 |         # print(aug_name)  # debug only: this is called once per sample, each call picking a fresh random op
264 |         image, label = getattr(self.aug_gdal, aug_name)(image, label)
265 |         return image, label
266 | 
267 | class TestRescale(object):
268 |     # test
269 |     def __init__(self, input_hw=(256, 256)):
270 |         self.input_hw = input_hw
271 |     def __call__(self, image, label):
272 |         '''
273 |         :param image: ndarray
274 |         :param label: ndarray uint8
275 |         :return:
276 |         '''
277 |         image = cv2.resize(image, self.input_hw[::-1], interpolation=cv2.INTER_LINEAR)
278 |         label = cv2.resize(label, self.input_hw[::-1], interpolation=cv2.INTER_NEAREST)
279 |         return image, label
280 | 
281 | class ToTensor(object):
282 |     # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W), but with no normalization to 0-1
283 |     def __call__(self, image, label):
284 |         image = torch.from_numpy(image.transpose((2, 0, 1)))
285 |         if not isinstance(image, torch.FloatTensor):
286 |             image = image.float()
287 |         label = torch.from_numpy(label)
288 |         if not isinstance(label, torch.LongTensor):
289 |             label = label.long()
290 |         return image, label
291 | 
292 | class Normalize(object):
293 |     # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
294 |     def __init__(self, mean, std=None, value_scale=255):
295 |         # mean's type is list or tuple
296 |         if std is None:
297 |             assert len(mean) > 0
298 |         else:
299 |             assert len(mean) == len(std)
300 | 
301 |         # equal to normalizing to [0,1] first, then applying pytorch-style normalization
302 |         self.mean = [item * value_scale for item in mean]
303 |         if std is not None:
304 |             self.std = [item * value_scale for item in std]
305 |         else:
306 |             self.std = None
307 | 
308 |     def __call__(self, image, label):
309 |         # tensor
310 |         if self.std is None:
311 |             for t, m in zip(image, self.mean):
312 |                 t.sub_(m)
313 |         else:
314 |             for t, m, s in zip(image, self.mean, self.std):
315 |                 t.sub_(m).div_(s)
316 |         return image, label
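# Editor's note on value_scale (illustrative arithmetic): ToTensor above does NOT
# divide by 255, so Normalize pre-scales the statistics instead. For a uint8
# channel x with mean 0.485 and std 0.229:
#     (x - 0.485*255) / (0.229*255) == (x/255 - 0.485) / 0.229
# i.e. identical to the usual torchvision pipeline on [0,1] tensors.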
317 | 
318 | # pytorch's own Compose only transforms the image; re-implemented here to pass (image, label) pairs through
319 | class Compose(object):
320 |     def __init__(self, transforms):
321 |         self.transforms = transforms
322 | 
323 |     def __call__(self, image, label):
324 |         for t in self.transforms:
325 |             image, label = t(image, label)
326 |         return image, label
327 | 
328 | if __name__ == '__main__':
329 |     # runer = Gdal_Read()
330 |     # # both jpg and tiff can be read
331 |     # # im_proj, im_geotrans, im_data = runer.read_img(filename=r'F:\DataSets\jishi_toukui\1bc523b1-7bb4-4a14-9b32-5476f04c853f.jpg')
332 |     # im_proj, im_geotrans, im_data = runer.read_img(filename=r'D:\A145984.jpg')
333 | 
334 |     # image and label must be transformed together
335 |     train_transforms = Compose([Transforms_GDAL(input_hw=(150, 150)),
336 |                                 ToTensor(),  # just hwc->chw
337 |                                 Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # note value_scale
338 |                                 ])
339 |     test_transforms = Compose([TestRescale(input_hw=(150, 150)),
340 |                                ToTensor(),
341 |                                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
342 |                                ])
343 |     image = np.ones([100,100,3], dtype=np.uint8)
344 |     label = np.ones([100,100], dtype=np.uint8)
345 |     im_out, lab_out = train_transforms(image, label)
346 |     print(im_out.shape)
--------------------------------------------------------------------------------
/utils/aug_PIL.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | from torchvision import transforms
7 | import torchvision.transforms.functional as tf
8 | '''
9 | author is leilei
10 | For segmentation augmentation the image and the label map must be transformed together; rotation, translation, etc. introduce black borders (all 0 values),
11 | and the introduced border is treated as its own class: label value 0 by default, real labels start from 1.
12 | Images use BILINEAR, label maps use NEAREST.
13 | The torchvision.transforms.functional API is used, which matches PIL's augmentation ops; everything is converted to PIL and handled as uint8.
14 | https://pytorch.org/docs/1.6.0/torchvision/transforms.html#functional-transforms
15 | '''
16 | class Augmentations_PIL:
17 |     def __init__(self, input_hw=(256, 256)):
18 |         self.input_hw = input_hw
19 |         self.image_fill = 0  # image fill=0, 0 is the black-border value
20 |         self.label_fill = 0  # label fill=0, 0 is the black-border class
21 |     '''
22 |     train stage
23 |     Each of the following ops is meant to be applied alone, never composed! Every op resizes its output to input_hw,
24 |     image has 3 channels, label has 1 channel,
25 |     and all inputs are RGB 3-channel.
26 |     image:[HWC], label:[HW]
27 |     '''
28 |     def random_rotate(self, image, label, angle=None):
29 |         '''
30 |         :param image: PIL RGB uint8
31 |         :param label: PIL, uint8
32 |         :param angle: None, list-float, tuple-float
33 |         :return: PIL
34 |         '''
35 |         if angle is None:
36 |             angle = transforms.RandomRotation.get_params([-180, 180])
37 |         elif isinstance(angle, list) or isinstance(angle, tuple):
38 |             angle = random.choice(angle)
39 | 
40 |         image = tf.rotate(image, angle, fill=self.image_fill)
41 |         label = tf.rotate(label, angle, fill=self.label_fill)
42 | 
43 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
44 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
45 | 
46 |         return image, label
47 | 
48 |     def random_flip(self, image, label):  # horizontal and vertical flips, sampled independently
49 |         if random.random() > 0.5:
50 |             image = tf.hflip(image)
51 |             label = tf.hflip(label)
52 |         if random.random() < 0.5:
53 |             image = tf.vflip(image)
54 |             label = tf.vflip(label)
55 | 
56 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
57 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
58 | 
59 |         return image, label
60 | 
61 |     # zoom in
62 |     def random_resize_crop(self, image, label, scale=(0.3, 1.0), ratio=(1, 1)):
63 |         # equivalent to random crop + resize to the target size; mostly zooms in
64 |         i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=scale, ratio=ratio)  # crops a random region of the original image (ratio constrains the region's aspect)
65 |         image = tf.resized_crop(image, i, j, h, w, self.input_hw, interpolation=Image.BILINEAR)
66 |         label = tf.resized_crop(label, i, j, h, w, self.input_hw, interpolation=Image.NEAREST)
67 | 
68 |         return image, label
69 | 
70 |     # zoom out
71 |     def random_resize_minify(self, image, label, scale=(0.3, 1.0)):
72 |         # equivalent to resize + centered padding; mostly zooms out
73 |         in_hw = image.size[::-1]
74 | 
75 |         factor = transforms.RandomRotation.get_params(scale)  # just samples a uniform factor from scale; isotropic here, could be made anisotropic
76 |         size = (int(in_hw[0]*factor), int(in_hw[1]*factor))  # (h,w)
77 |         image = tf.resize(image, size, interpolation=Image.BILINEAR)
78 |         label = tf.resize(label, size, interpolation=Image.NEAREST)
79 |         # pad
80 |         top_bottom = (self.input_hw[0] - size[0])
81 |         left_right = (self.input_hw[1] - size[1])
82 | 
83 |         top = top_bottom >> 1 if top_bottom > 0 else 0
84 |         bottom = top_bottom - top if top_bottom > 0 else 0
85 |         left = left_right >> 1 if left_right > 0 else 0
86 |         right = left_right - left if left_right > 0 else 0
87 | 
88 |         image = tf.pad(image, (left, top, right, bottom), fill=self.image_fill, padding_mode='constant')  # tf.pad returns a new image, so the result must be assigned
89 |         # the padded border defaults to class 0
90 |         label = tf.pad(label, (left, top, right, bottom), fill=self.label_fill, padding_mode='constant')
91 | 
92 |         # resize
93 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
94 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
95 | 
96 |         return image, label
97 | 
98 |     '''
99 |     core function, similar to cv2.warpAffine()
100 |     # All the other ops could be built on top of this, like cv2's affine transform matrix; but cv2 works from the top-left corner by default,
101 |     so the center is not preserved unless a final center shift is applied (with matching center ops beforehand).
102 |     See torchvision.transforms.functional -> _get_inverse_affine_matrix
103 |     '''
104 |     def random_affine(self, image, label):
105 |         # random affine: translation, rotation, scaling, etc. combined
106 |         if random.random() > 0.5:
107 |             # perspective transform (RandomPerspective)
108 |             width, height = image.size
109 |             startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
110 |             # fills with 0 and keeps the original size, so a resize is still needed
111 |             image = tf.perspective(image, startpoints, endpoints, interpolation=Image.BICUBIC, fill=self.image_fill)
112 |             label = tf.perspective(label, startpoints, endpoints, interpolation=Image.NEAREST, fill=self.label_fill)
113 |         elif random.random() < 0.5:  # fresh draw: 50% perspective, 25% affine, 25% plain resize overall
114 |             # TODO expose degrees etc. as user-settable parameters
115 |             # random rotation-translation-scale affine (no shear); pytorch's version keeps the center fixed
116 |             ret = transforms.RandomAffine.get_params(degrees=(-180, 180), translate=(0.3, 0.3), scale_ranges=(0.3, 3),
117 |                                                      shears=None, img_size=image.size)
118 |             # angle, translations, scale, shear = ret
119 |             # fills with 0 and keeps the original size, so a resize is still needed
120 |             image = tf.affine(image, *ret, resample=0, fillcolor=self.image_fill)  # PIL.Image.NEAREST
121 |             label = tf.affine(label, *ret, resample=0, fillcolor=self.label_fill)
122 | 
123 |         # resize to the required size
124 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
125 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
126 | 
127 |         return image, label
128 | 
129 |     def random_color_jitter(self, image, label, brightness=0.4, contrast=0.3, saturation=0.2, hue=0.2):
130 |         # random color augmentation; the randomness is in the jitter values, not in whether it is applied (cf. transforms.RandomApply)
131 |         transforms_func = transforms.ColorJitter(brightness=brightness,
132 |                                                  contrast=contrast,
133 |                                                  saturation=saturation,
134 |                                                  hue=hue)
135 |         image = transforms_func(image)
136 |         # label = label
137 | 
138 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
139 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
140 | 
141 |         return image, label
142 | 
143 |     # gaussian noise
144 |     def random_noise(self, image, label, noise_sigma=10):
145 |         in_hw = image.size[::-1] + (1,)
146 |         noise = np.uint8(np.random.randn(*in_hw) * noise_sigma)  # +-; note: negative values wrap around under uint8
147 | 
148 |         image = np.array(image) + noise  # broadcast
149 |         image = Image.fromarray(image, "RGB")
150 | 
151 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
152 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
153 | 
154 |         return image, label
155 | 
156 |     def random_blur(self, image, label, kernel_size=(5,5)):
157 |         assert len(kernel_size) == 2, "kernel size must be tuple and len()=2"
158 |         image = cv2.GaussianBlur(np.array(image), ksize=kernel_size, sigmaX=0)
159 |         image = Image.fromarray(image, "RGB")
160 | 
161 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
162 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
163 | 
164 |         return image, label
165 | 
166 | class Transforms_PIL(object):
167 |     def __init__(self, input_hw=(256, 256)):
168 |         self.aug_pil = Augmentations_PIL(input_hw)
169 |         self.aug_funcs = [a for a in self.aug_pil.__dir__() if not a.startswith('_') and a not in self.aug_pil.__dict__]  # public methods only; __dict__ holds just the attributes
170 |         print(self.aug_funcs)
171 | 
172 |     def __call__(self, image, label):
173 |         '''
174 |         :param image: PIL RGB uint8
175 |         :param label: PIL, uint8
176 |         :return: PIL
177 |         '''
178 |         aug_name = random.choice(self.aug_funcs)
179 |         # aug_name = 'random_resize_crop' #'random_rotate' #'random_flip' #'random_blur' #'random_noise' #'random_affine' #'random_resize_minify' #'random_resize_crop'
180 |         # print(aug_name)  # debug only: this is called once per sample, each call picking a fresh random op
181 |         image, label = getattr(self.aug_pil, aug_name)(image, label)
182 |         return image, label
183 | 
184 | class TestRescale(object):
185 |     # test
186 |     def __init__(self, input_hw=(256, 256)):
187 |         self.input_hw = input_hw
188 |     def __call__(self, image, label):
189 |         '''
190 |         :param image: PIL RGB uint8
191 |         :param label: PIL, uint8
192 |         :return: PIL
193 |         '''
194 |         image = tf.resize(image, self.input_hw, interpolation=Image.BILINEAR)
195 |         label = tf.resize(label, self.input_hw, interpolation=Image.NEAREST)
196 |         return image, label
197 | 
198 | class ToTensor(object):
199 |     # image label -> tensor, image div 255
200 |     def __call__(self, image, label):
201 |         # PIL uint8
202 |         image = tf.to_tensor(image)  # transpose HWC->CHW, /255
203 |         label = torch.from_numpy(np.array(label))  # PIL->ndarray->tensor
204 |         if not isinstance(label, torch.LongTensor):
205 |             label = label.long()
206 |         return image, label
207 | 
208 | class Normalize(object):
209 |     # (image-mean)/std
210 |     def __init__(self, mean, std, inplace=False):
211 |         self.mean = mean  # RGB
212 |         self.std = std
213 |         self.inplace = inplace
214 | 
215 |     def __call__(self, image, label):
216 |         image = tf.normalize(image, self.mean, self.std, self.inplace)
217 |         assert isinstance(label, torch.LongTensor)
218 |         label = label
219 |         return image, label
220 | 
221 | # pytorch's own Compose only transforms the image; re-implemented here to pass (image, label) pairs through
222 | class Compose(object):
223 |     def __init__(self, transforms):
224 |         self.transforms = transforms
225 | 
226 |     def __call__(self, image, label):
227 |         for t in self.transforms:
228 |             image, label = t(image, label)
229 |         return image, label
230 | 
231 | if __name__ == '__main__':
232 |     # aug_pil = Augmentations_PIL()
233 |     # # dir() lists attributes and all methods; __dict__ only the attributes
234 |     # print(aug_pil.__dict__)
235 |     # aug_funcs = [a for a in aug_pil.__dir__() if not a.startswith('_') and a not in aug_pil.__dict__]
236 |     #
237 |     # trans = Transforms_PIL(input_hw=(150,150))
238 |     image = np.uint8(np.random.rand(100,100,3)*255)
239 |     label = np.ones([100,100], dtype=np.uint8)
240 |     image = Image.fromarray(image, "RGB")  # PIL
241 |     label = Image.fromarray(label)  # PIL
242 |     # image1, label1 = trans(image, label)
243 | 
244 |     # image and label must be transformed together
245 |     train_transforms = Compose([Transforms_PIL(input_hw=(150,150)),
246 |                                 ToTensor(),  # /255 totensor
247 |                                 Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
248 |                                 ])
249 |     test_transforms = Compose([TestRescale(input_hw=(150,150)),
250 |                                ToTensor(),  # /255
251 |                                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
252 |                                ])
253 | 
254 |     im_out, lab_out = train_transforms(image, label)
255 |     print(im_out.shape)
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | from torch.utils import data
7 | 
8 | '''
9 | Images are read with PIL by default; the GDAL reading branch is not implemented here yet (GDAL is unfriendly to set up on linux)!
10 | root/
11 |     train_txt
12 |         /root/images/aaa.jpg  | replace(images, labels)
13 |     images/
14 |         aaa.jpg
15 |         bbb.jpg
16 |     labels/
17 |         aaa.jpg
18 |         bbb.jpg
19 | '''
20 | 
21 | # dataset class
22 | class LoadDataset(data.Dataset):
23 |     def __init__(self, txt_path, transform=None, is_gdal=False):
24 |         assert transform is not None, "transform must not be None"
25 |         self.transform = transform
26 |         self.is_gdal = is_gdal
27 |         with open(txt_path, 'r') as f:
28 |             self.image_paths = f.readlines()
29 | 
30 |     def __len__(self):
31 |         return len(self.image_paths)
32 | 
33 |     def __getitem__(self, item):
34 |         image_path = self.image_paths[item].strip()
35 |         label_path = image_path.replace('images', 'labels')  # image in images/ folder, label in labels/ folder
36 | 
37 |         if not self.is_gdal:  # the GDAL reading branch is still TODO (see the module docstring)
38 |             image = cv2.imread(image_path, cv2.IMREAD_COLOR)  # force 3 channels, even for single-channel inputs
39 |             label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)  # label images are single-channel anyway
40 |             image = Image.fromarray(image[:, :, ::-1], 'RGB')  # bgr->rgb->PIL
41 |             label = Image.fromarray(label)
42 | 
43 |         image, label = self.transform(image, label)
44 | 
45 |         return image, label
46 | 
47 | 
48 | # build the train/test transforms
49 | def get_transform(is_gdal=False, input_hw=(256,256), value_scale=None, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
50 | 
51 |     if is_gdal:
52 |         if value_scale is None:
53 |             value_scale = 255
54 |         assert isinstance(value_scale, int), "value_scale must be an int"
55 |         from utils.aug_GDAL import Compose, Transforms_GDAL, TestRescale, ToTensor, Normalize
56 |         train_transforms = Compose([Transforms_GDAL(input_hw=input_hw),
57 |                                     ToTensor(),  # just hwc->chw
58 |                                     Normalize(mean, std, value_scale),  # note value_scale
59 |                                     ])
60 |         test_transforms = Compose([TestRescale(input_hw=input_hw),
61 |                                    ToTensor(),
62 |                                    Normalize(mean, std, value_scale),
63 |                                    ])
64 | 
65 |     else:
66 |         from utils.aug_PIL import Compose, Transforms_PIL, TestRescale, ToTensor, Normalize
67 |         train_transforms = Compose([Transforms_PIL(input_hw=input_hw),
68 |                                     ToTensor(),  # /255 totensor
69 |                                     Normalize(mean, std),
70 |                                     ])
71 |         test_transforms = Compose([TestRescale(input_hw=input_hw),
72 |                                    ToTensor(),  # /255
73 |                                    Normalize(mean, std),
74 |                                    ])
75 | 
76 |     return train_transforms, test_transforms
77 | 
78 | 
79 | def load_data(params):
80 |     '''
81 |     :param params: configs/parameter.yaml parsed params
82 |     :return:
83 |     '''
84 |     # transform param
85 |     is_gdal = params['is_gdal']
86 |     input_hw = params['input_hw']
87 |     value_scale = params['value_scale']
88 |     mean = params['mean']
89 |     std = params['std']
90 |     # data loader
91 |     train_txt_path = params['train_txt_path']
92 |     test_txt_path = params['test_txt_path']
93 |     batch_size = params['batch_size']
94 |     num_workers = params['num_workers']
95 | 
96 |     # transform
97 |     train_transforms, test_transforms = get_transform(is_gdal, input_hw, value_scale, mean, std)
98 |     # train
99 |     train_dataset = LoadDataset(train_txt_path, train_transforms, is_gdal)
100 |     train_loader = data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers)
101 |     # test
102 |     test_dataset = LoadDataset(test_txt_path, test_transforms, is_gdal)
103 |     test_loader = data.DataLoader(test_dataset, batch_size*2, shuffle=False, num_workers=num_workers)
104 | 
105 |     data_loader = {}
106 |     data_loader['train'] = train_loader
107 |     data_loader['test'] = test_loader
108 | 
109 |     return data_loader
110 | 
111 | 
112 | 
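# --- Editor's sketch: wiring load_data to the yaml config, roughly what main.py
# --- is expected to do (illustrative; assumes PyYAML, and note that values such
# --- as input_hw: (256, 256) may parse as strings and need extra handling).
#
#   import yaml
#   with open('configs/parameter.yaml', 'r') as f:
#       params = yaml.safe_load(f)
#   data_loader = load_data(params)
#   images, labels = next(iter(data_loader['train']))  # [B,C,H,W], [B,H,W]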
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Semantic segmentation metrics functions
3 | '''
4 | import numpy as np
5 | from sklearn import metrics
6 | 
7 | def get_confusion_matrix(predicts, labels, class_number):
8 |     # predicts.shape == labels.shape; rows are ground truth, columns are predictions
9 |     confusion_matrix = metrics.confusion_matrix(labels.reshape([-1]), predicts.reshape([-1]), labels=range(class_number))
10 |     return confusion_matrix
11 | 
12 | def compute_acc_pr_iou(confusion_matrix):
13 |     # Calculate the various indices from the confusion matrix.
14 |     diag = np.diag(confusion_matrix)
15 |     p_s = np.sum(confusion_matrix, axis=0)  # column sums: per-class predicted counts
16 |     r_s = np.sum(confusion_matrix, axis=1)  # row sums: per-class ground-truth counts
17 | 
18 |     acc = np.sum(diag) / np.sum(confusion_matrix)
19 |     mean_precision = np.mean(diag / (p_s + 1e-6))  # mean of the per-class precisions
20 |     mean_recall = np.mean(diag / (r_s + 1e-6))  # mean of the per-class recalls
21 |     mean_iou = np.mean(diag / (p_s + r_s - diag + 1e-6))
22 | 
23 |     return round(acc, 4), round(mean_precision, 4), round(mean_recall, 4), round(mean_iou, 4)
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     m = get_confusion_matrix(predicts=np.ones([10,3]), labels=np.ones([10,3]), class_number=5)
28 |     acc, mean_precision, mean_recall, mean_iou = compute_acc_pr_iou(m)
29 |     print(m)
30 |     print(acc, mean_precision, mean_recall, mean_iou)
31 | 
32 | 
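# Editor's note, a worked 2-class example for the formulas above:
#   M = [[8, 2],
#        [1, 9]]   rows: ground truth, cols: prediction
# class 0: precision = 8/(8+1), recall = 8/(8+2), iou = 8/(9+10-8) = 8/11
# class 1: precision = 9/(2+9), recall = 9/(1+9), iou = 9/(11+10-9) = 9/12
# acc = (8+9)/20 = 0.85; mean_iou = (8/11 + 9/12) / 2 = 0.7386 (approx.)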
--------------------------------------------------------------------------------
/utils/plots.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | 
5 | # support chinese
6 | # refer https://blog.csdn.net/lucky__ing/article/details/78699198
7 | plt.rcParams['font.family'] = ['Microsoft YaHei']
8 | plt.rcParams['axes.unicode_minus'] = False
9 | 
10 | # confusion matrix visualization
11 | def plot_confusion_matrix(confusion_matrix, save_dir, class_names):
12 |     plt.figure()
13 |     plt.imshow(confusion_matrix, cmap=plt.cm.Reds)
14 |     indices = range(len(confusion_matrix))  # confusion_matrix [N,N]
15 |     plt.xticks(indices, list(class_names))
16 |     plt.yticks(indices, list(class_names))
17 | 
18 |     plt.colorbar()
19 | 
20 |     plt.xlabel('Predict-Value')
21 |     plt.ylabel('Ground-Truth')
22 |     plt.title('Confusion-Matrix')
23 | 
24 |     for first_index in range(len(confusion_matrix)):
25 |         for second_index in range(len(confusion_matrix[first_index])):
26 |             plt.text(second_index, first_index, confusion_matrix[first_index][second_index])  # x = column (prediction), y = row (ground truth)
27 | 
28 |     plt.plot()  # draw image
29 |     plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'), dpi=300)
30 |     # plt.show()  # show the image
31 | 
32 |     return
33 | 
34 | 
35 | 
36 | if __name__ == '__main__':
37 |     confusion_matrix = np.array([[90,10,5,1,1],[10,80,15,1,1],[20,5,80,1,1],[1,1,1,95,10],[1,1,1,1,100]])
38 |     save_dir = r'/home/gengyanlei/'
39 |     class_names = ['我', '你', '他', '它', '她']
40 |     # class_names = ['wo', 'ni', 'ta', '1', '2']
41 |     plot_confusion_matrix(confusion_matrix, save_dir, class_names)
42 | 
43 | 
--------------------------------------------------------------------------------
/utils/trainval.py:
--------------------------------------------------------------------------------
1 | '''
2 | train-test stage
3 | '''
4 | import torch
5 | import numpy as np
6 | from torch import nn
7 | from pathlib import Path
8 | from utils.plots import plot_confusion_matrix
9 | from utils.metrics import get_confusion_matrix, compute_acc_pr_iou
10 | import torch.nn.functional as F
11 | 
12 | def train(data_loader, model, optimizer, scheduler, tb_writer, param_dict, continue_epoch):
13 |     # weights folder create
14 |     save_dir = Path(param_dict['save_dir']) / 'weights'
15 |     save_dir.mkdir(parents=True, exist_ok=True)
16 |     last = save_dir / 'last.pt'
17 |     best = save_dir / 'best.pt'
18 | 
19 |     cross_entropy = nn.CrossEntropyLoss()
20 | 
21 |     # fast-forward the lr scheduler when resuming from continue_epoch
22 |     for epoch in range(0, continue_epoch):
23 |         scheduler.step()
24 | 
25 |     best_fitness = 0
26 |     for epoch in range(continue_epoch, param_dict['epoches']):
27 |         model.train()
28 |         scheduler.step()  # note: recent pytorch recommends stepping the scheduler after the epoch's optimizer steps
29 |         train_acc = 0
30 |         train_loss = 0  # TODO add list to reduce code
31 |         for step, data in enumerate(data_loader['train']):
32 |             loss = 0
33 |             images, labels = data
34 |             inputs = images.cuda()
35 |             labels = labels.cuda()
36 | 
37 |             optimizer.zero_grad()
38 |             outputs = model(inputs)
39 | 
40 |             loss += cross_entropy(outputs, labels)
41 |             train_loss += loss.cpu().item()
42 | 
43 |             loss.backward()
44 |             torch.cuda.synchronize()
45 |             optimizer.step()
46 |             torch.cuda.synchronize()
47 | 
48 |         # TODO train-test tensorboard summary
49 |         tb_writer.add_scalar('Loss/train_loss', train_loss/len(data_loader['train']), epoch)
50 | 
51 |         # val stage
52 |         if ((epoch - continue_epoch) % param_dict['test_interval'] == 0) and (epoch - continue_epoch) != 0:
53 | 
54 |             test_loss, test_indexs = test(data_loader['test'], model, param_dict)
55 |             tb_writer.add_scalar('Loss/test_loss', test_loss, epoch)
56 |             tags = ['Metrics/Accuracy', 'Metrics/Mean_Precision', 'Metrics/Mean_Recall', 'Metrics/Mean_IoU']
57 |             for tag, index in zip(tags, test_indexs):
58 |                 tb_writer.add_scalar(tag, index, epoch)
59 |             # TODO best_fitness=w1*acc+w2*precision+w3*recall+w4*mean_iou
60 |             # save best weight
61 |             if test_indexs[-1] > best_fitness:  # currently fitness = mean_iou
62 |                 best_fitness = test_indexs[-1]
63 |                 torch.save({'model': model,
64 |                             'epoch': epoch,
65 |                             'model_name': param_dict['model_name'],
66 |                             'optimizer': optimizer.state_dict(),
67 |                             'best_fitness': best_fitness}, best)
68 |         # save last weight
69 |         torch.save({'model': model,
70 |                     'epoch': epoch,
71 |                     'model_name': param_dict['model_name'],
72 |                     'optimizer': optimizer.state_dict(),
73 |                     'best_fitness': best_fitness if param_dict['test_interval'] == 1 else None}, last)
74 | 
75 |     # after training, re-save last.pt without the epoch/optimizer etc. information
76 |     torch.save({'model': model,
77 |                 'model_name': param_dict['model_name'],
78 |                 'optimizer': None}, last)
79 | 
80 |     return
81 | 
82 | def test(test_loader, model, param_dict):
83 |     confusion_matrix = np.zeros([param_dict['class_number'], param_dict['class_number']], dtype=np.int64)
84 | 
85 |     cross_entropy = nn.CrossEntropyLoss()  # reduction is already 'mean'
86 | 
87 |     model.train(False)  # = model.eval(): fixes BN statistics
88 |     with torch.no_grad():
89 |         test_acc = 0
90 |         test_loss = 0
91 |         for step, data in enumerate(test_loader):
92 |             images, labels = data
93 |             inputs = images.cuda()
94 |             labels_cuda = labels.cuda()
95 |             outputs = model(inputs)
96 |             loss = cross_entropy(outputs, labels_cuda)
97 |             # compute confusion matrix
98 |             outputs_p = F.softmax(outputs, dim=1)  # [N,C,H,W] cuda
99 |             P = torch.max(outputs_p, 1)[1].data.cpu().numpy()  # [N,H,W] numpy-cpu
100 | 
101 |             m = get_confusion_matrix(P, labels.data.numpy(), class_number=param_dict['class_number'])
102 |             confusion_matrix += m
103 | 
104 |             test_loss += loss.cpu().item()
105 |     # plot confusion_matrix and save
106 |     plot_confusion_matrix(confusion_matrix, param_dict['save_dir'], param_dict['class_names'])
107 | 
108 |     acc, mean_precision, mean_recall, mean_iou = compute_acc_pr_iou(confusion_matrix)
109 | 
110 |     return test_loss/len(test_loader), (acc, mean_precision, mean_recall, mean_iou)
111 | 
112 | 
113 | 
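# --- Editor's sketch: one way train() above might be driven, using the lr0 /
# --- lrf / momentum / epoches keys from configs/parameter.yaml (illustrative;
# --- the real wiring lives in main.py and may differ).
#
#   import math
#   optimizer = torch.optim.SGD(model.parameters(), lr=param_dict['lr0'], momentum=param_dict['momentum'])
#   one_cycle = lambda e: ((1 - math.cos(e * math.pi / param_dict['epoches'])) / 2) * (param_dict['lrf'] - 1) + 1  # cosine 1 -> lrf
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=one_cycle)
#   train(data_loader, model, optimizer, scheduler, tb_writer, param_dict, continue_epoch=0)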
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import random
4 | import math
5 | import numpy as np
6 | import torch
7 | from torch.backends import cudnn
8 | from torch.optim import lr_scheduler
9 | 
10 | '''
11 | Some auxiliary functions
12 | '''
13 | 
14 | 
15 | def RandomResizedCrop_get_params(img, scale, ratio):
16 |     """Get parameters for ``crop`` for a random sized crop.
17 | 
18 |     Args:
19 |         img (ndarray): Image to be cropped (adapted from torchvision's PIL version to numpy).
20 |         scale (tuple): range of size of the origin size cropped
21 |         ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
22 | 
23 |     Returns:
24 |         tuple: params (i, j, h, w) to be passed to ``crop`` for a random
25 |             sized crop.
26 |     """
27 |     height, width = img.shape[:2]
28 |     area = height * width
29 | 
30 |     for _ in range(10):
31 |         target_area = random.uniform(*scale) * area
32 |         log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
33 |         aspect_ratio = math.exp(random.uniform(*log_ratio))
34 | 
35 |         w = int(round(math.sqrt(target_area * aspect_ratio)))
36 |         h = int(round(math.sqrt(target_area / aspect_ratio)))
37 | 
38 |         if 0 < w <= width and 0 < h <= height:
39 |             i = random.randint(0, height - h)
40 |             j = random.randint(0, width - w)
41 |             return i, j, h, w
42 | 
43 |     # Fallback to central crop
44 |     in_ratio = float(width) / float(height)
45 |     if (in_ratio < min(ratio)):
46 |         w = width
47 |         h = int(round(w / min(ratio)))
48 |     elif (in_ratio > max(ratio)):
49 |         h = height
50 |         w = int(round(h * max(ratio)))
51 |     else:  # whole image
52 |         w = width
53 |         h = height
54 |     i = (height - h) // 2
55 |     j = (width - w) // 2
56 |     return i, j, h, w
57 | 
58 | 
59 | def init_seeds(seed=1):
60 |     random.seed(seed)
61 |     np.random.seed(seed)
62 |     init_torch_seeds(seed)
63 |     return
64 | 
65 | # TODO: why not also seed all GPUs (torch.cuda.manual_seed_all)?
66 | def init_torch_seeds(seed=0):
67 |     # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
68 |     torch.manual_seed(seed)
69 |     if seed == 0:  # slower, more reproducible
70 |         cudnn.deterministic = True
71 |         cudnn.benchmark = False
72 |     else:  # faster, less reproducible
73 |         cudnn.deterministic = False
74 |         cudnn.benchmark = True
75 |     return
76 | 
77 | def check_path(path, is_file=False):
78 |     p = Path(path)
79 |     if is_file:
80 |         p.touch()
81 |         return
82 |     if not p.exists():
83 |         p.mkdir(parents=True)
84 |     return
85 | 
86 | def increment_path(path, sep=''):
87 |     # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
88 |     path = Path(path)  # os-agnostic
89 |     if (not path.exists()):
90 |         return str(path)
91 |     else:
92 |         import glob, re
93 |         dirs = glob.glob(f"{path}{sep}*")  # similar paths
94 |         matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]  # on linux a file and a folder cannot share the same name
95 |         i = [int(m.groups()[0]) for m in matches if m]  # indices
96 |         n = max(i) + 1 if i else 1  # increment number
97 |         return f"{path}{sep}{n}"  # update path
98 | 
99 | 
--------------------------------------------------------------------------------