├── Pytorch-FCN
│   ├── utiles
│   │   ├── __init__.py
│   │   ├── hiddenlayer_polt_model.py
│   │   ├── BilinearUpSampling.py
│   │   ├── data.py
│   │   ├── evalution_segmentaion.py
│   │   └── functional.py
│   ├── imgs
│   │   └── pic
│   │       └── read.txt
│   ├── logs
│   │   └── README.md
│   ├── CamVid
│   │   ├── test
│   │   │   └── read.txt
│   │   ├── train
│   │   │   └── readme.txt
│   │   ├── val
│   │   │   └── read.txt
│   │   ├── test_labels
│   │   │   └── read.txt
│   │   ├── val_labels
│   │   │   └── read.txt
│   │   ├── train_labels
│   │   │   └── read.txt
│   │   └── class_dict.csv
│   ├── README.md
│   ├── test.py
│   ├── predict.py
│   ├── train.py
│   ├── data.py
│   └── FCN.py
└── README.md
/Pytorch-FCN/utiles/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/imgs/pic/read.txt:
--------------------------------------------------------------------------------
1 | This folder stores the predicted images.
2 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/logs/README.md:
--------------------------------------------------------------------------------
1 | This folder stores the trained model files.
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic-segmentation
2 | Semantic segmentation
--------------------------------------------------------------------------------
/Pytorch-FCN/CamVid/test/read.txt:
--------------------------------------------------------------------------------
1 | This folder stores the test-set images.
2 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/CamVid/train/readme.txt:
--------------------------------------------------------------------------------
1 | This folder stores the training set.
2 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/CamVid/val/read.txt:
--------------------------------------------------------------------------------
1 | This folder stores the validation-set images.
2 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/CamVid/test_labels/read.txt:
--------------------------------------------------------------------------------
1 | This folder stores the test-set labels.
2 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/CamVid/val_labels/read.txt:
--------------------------------------------------------------------------------
1 | This folder stores the validation-set labels.
2 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/CamVid/train_labels/read.txt:
--------------------------------------------------------------------------------
1 | This folder stores the training-set labels.
2 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/CamVid/class_dict.csv:
--------------------------------------------------------------------------------
1 | name,r,g,b
2 | Sky,128, 128, 128
3 | Building,128, 0, 0
4 | Pole,192, 192, 128
5 | Road,128, 64, 128
6 | Sidewalk,0,0,192
7 | Tree,128,128,0
8 | SignSymbol,192,128,128
9 | Fence,64,64,128
10 | Car,64,0,128
11 | Pedestrian,64,64,0
12 | Bicyclist,0,128,192
13 | unlabelled,0,0,0
--------------------------------------------------------------------------------
/Pytorch-FCN/utiles/hiddenlayer_polt_model.py:
--------------------------------------------------------------------------------
1 | import hiddenlayer as hl
2 | import torch
3 | from by19_best import By_3d
4 | 
5 | model = By_3d().cuda()
6 | hl_graph = hl.build_graph(model, (torch.zeros([1, 3, 352, 480]).cuda()))
7 | hl_graph.theme = hl.graph.THEMES["blue"].copy()  # Two options: basic and blue
8 | 
9 | hl_graph.save('/home/zjy/what/best.jpg')
--------------------------------------------------------------------------------
/Pytorch-FCN/README.md:
--------------------------------------------------------------------------------
1 | ### Reproducing the FCN network in PyTorch
2 | #### 1. Environment setup
3 | Windows 10, pytorch=1.3, python=3.6
4 | Reference: https://github.com/wkentaro/pytorch-fcn
5 | #### 2. File description
6 | CamVid folder: the dataset, containing the training, validation, and test sets;
7 | logs folder: stores the trained model files (.pth);
8 | imgs folder: stores the predicted images;
9 | data.py: data processing; FCN.py: network model file, containing FCN32s, FCN16s, and FCN8s;
10 | #### 3. Steps to reproduce:
11 | Train the model:
12 | python train.py
13 | Test the model:
14 | python test.py
15 | Predict with the model:
16 | python predict.py
17 | #### 4. CamVid dataset download:
18 | https://download.csdn.net/download/weixin_44753371/12299379
19 | #### 5. Training on your own dataset
20 | For details, see this blog post:
21 | https://blog.csdn.net/weixin_44753371/article/details/105292287
22 | 
--------------------------------------------------------------------------------
/Pytorch-FCN/test.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | from torch.utils.data import DataLoader
3 | from utiles.evalution_segmentaion import eval_semantic_segmentation
4 | import torch.nn.functional as F
5 | import torch as t
6 | from predict import test_dataset
7 | from FCN import FCN8s,VGGNet
8 | 
9 | BATCH_SIZE = 2
10 | miou_list = [0]
11 | test_data = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
12 | 
13 | vgg_model = VGGNet(requires_grad=True)
14 | net = FCN8s(pretrained_net=vgg_model,n_class=12)
15 | net.eval()
16 | net.cuda()
17 | net.load_state_dict(t.load('D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/logs/last.pth'))  # load the trained model
18 | 
19 | train_acc = 0
20 | train_miou = 0
21 | train_class_acc = 0
22 | train_mpa = 0
23 | error = 0
24 | 
25 | 
26 | for i, sample in enumerate(test_data):  # (data, label) --> sample
27 |     # data = Variable(data).cuda()
28 |     # label = Variable(label).cuda()
29 |     # out = net(data)
30 |     # out = F.log_softmax(out, dim=1)
31 | 
32 |     # added by me
33 |     data = sample['img'].cuda()  # valImg --> data
34 |     label = sample['label'].long().cuda()  # valLabel --> label
35 |     out = net(data)  # valImg --> data
36 |     out = F.log_softmax(out, dim=1)
37 | 
38 |     pre_label = out.max(dim=1)[1].data.cpu().numpy()
39 |     pre_label = [i for i in pre_label]
40 | 
41 |     true_label = label.data.cpu().numpy()
42 |     true_label = [i for i in true_label]
43 | 
44 |     eval_metrix = eval_semantic_segmentation(pre_label, true_label)
45 |     train_acc = eval_metrix['mean_class_accuracy'] + train_acc
46 |     train_miou = eval_metrix['miou'] + train_miou
47 |     train_mpa = eval_metrix['pixel_accuracy'] + train_mpa
48 |     if len(eval_metrix['class_accuracy']) < 12:
49 |         eval_metrix['class_accuracy'] = 0
50 |         train_class_acc = train_class_acc + eval_metrix['class_accuracy']
51 |         error += 1
52 |     else:
53 |         train_class_acc = train_class_acc + eval_metrix['class_accuracy']
54 | 
55 |     print(eval_metrix['class_accuracy'], '================', i)
56 | 
57 | 
58 | epoch_str = ('test_acc :{:.5f} ,test_miou:{:.5f}, test_mpa:{:.5f}, test_class_acc :{:}'.format(train_acc /(len(test_data)-error),
59 |              train_miou/(len(test_data)-error), train_mpa/(len(test_data)-error),
60 |              train_class_acc/(len(test_data)-error)))
61 | 
62 | if train_miou/(len(test_data)-error) > max(miou_list):
63 |     miou_list.append(train_miou/(len(test_data)-error))
64 | print(epoch_str+'==========last')
65 | 
--------------------------------------------------------------------------------
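test.py above accumulates the per-batch results of eval_semantic_segmentation and averages them at the end. Below is a minimal sketch of calling that helper directly, to show the input format it expects: lists of per-image (H, W) integer label maps. The random arrays are placeholders rather than repository data, and the snippet assumes it is run from the repository root so that the utiles package is importable.

```python
# Minimal sketch (not part of the repository scripts): call the evaluation helper
# on one predicted / ground-truth pair, the same way test.py does per batch.
import numpy as np
from utiles.evalution_segmentaion import eval_semantic_segmentation

h, w, n_class = 352, 480, 12
pred = [np.random.randint(0, n_class, (h, w))]  # one predicted label map (placeholder)
gt = [np.random.randint(0, n_class, (h, w))]    # one ground-truth label map (placeholder)

metrics = eval_semantic_segmentation(pred, gt)
print(metrics['miou'], metrics['pixel_accuracy'], metrics['mean_class_accuracy'])
```

Note that the confusion-matrix helper starts from n_class = 12, which matches the 12 rows of class_dict.csv.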
/Pytorch-FCN/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from torch.utils.data import Dataset
4 | from utiles.data import img_transform
5 | from PIL import Image
6 | import torch as t
7 | import torch.nn.functional as F
8 | import numpy as np
9 | from torch.utils.data import DataLoader
10 | from FCN import FCN8s,VGGNet
11 | 
12 | 
13 | TEST_ROOT = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/test'  # '/xxx/CamVid/test'  # path to the test-set images
14 | TEST_LABEL = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/test_labels'  # '/xxx/CamVid/test_labels'  # path to the test-set labels
15 | 
16 | imgs = os.listdir(TEST_ROOT)
17 | imgs = [os.path.join(TEST_ROOT, img) for img in imgs]
18 | imgs.sort()
19 | 
20 | labels = os.listdir(TEST_LABEL)
21 | labels = [os.path.join(TEST_LABEL, label) for label in labels]
22 | labels.sort()
23 | 
24 | 
25 | input_size = (352, 480)  # height, width
26 | 
27 | 
28 | class TestDataset(Dataset):
29 |     def __init__(self, transform, crop_size):
30 |         self.imgs = imgs
31 |         self.labels = labels
32 |         self.transforms = transform
33 |         self.crop_size = crop_size
34 | 
35 |     def __getitem__(self, index):
36 |         img, label = Image.open(self.imgs[index]), Image.open(self.labels[index]).convert('RGB')
37 |         # img, label = self.transforms(img, label, self.crop_size)
38 |         img, label, label1, label2, label3, label4, label5 = self.transforms(img, label, self.crop_size)  # modified
39 |         # sample = {'img': img, 'label': label}
40 |         sample = {'img': img, 'label': label, 'label1': label1, 'label2': label2, 'label3': label3, 'label4': label4,
41 |                   'label5': label5}  # modified
42 |         return sample
43 | 
44 |     def __len__(self):
45 |         return len(self.imgs)
46 | 
47 | 
48 | test_dataset = TestDataset(img_transform, input_size)
49 | test_data = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
50 | 
51 | 
52 | vgg_model = VGGNet(requires_grad=True)
53 | net = FCN8s(pretrained_net=vgg_model,n_class=12).cuda()
54 | net.load_state_dict(t.load("D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/logs/last.pth"))  # path of the trained model to load
55 | net.eval()
56 | 
57 | pd_label_color = pd.read_csv('D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/class_dict.csv', sep=',')  # path to class_dict.csv
58 | name_value = pd_label_color['name'].values
59 | num_class = len(name_value)
60 | 
61 | colormap = []
62 | for i in range(len(pd_label_color.index)):
63 |     # index the row by its position
64 |     tmp = pd_label_color.iloc[i]
65 |     color = []
66 |     color.append(tmp['r'])
67 |     color.append(tmp['g'])
68 |     color.append(tmp['b'])
69 |     colormap.append(color)
70 | 
71 | cm = np.array(colormap).astype('uint8')
72 | 
73 | dir = "D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/imgs/pic"  # directory where the predicted images are saved
74 | 
75 | for i, sample in enumerate(test_data):
76 |     valImg = sample['img'].cuda()
77 |     valLabel = sample['label'].long().cuda()
78 |     out = net(valImg)
79 |     out = F.log_softmax(out, dim=1)
80 |     pre_label = out.max(1)[1].squeeze().cpu().data.numpy()
81 |     pre = cm[pre_label]
82 |     pre1 = Image.fromarray(pre)
83 |     pre1.save(os.path.join(dir, str(i) + '.png'))
--------------------------------------------------------------------------------
/Pytorch-FCN/train.py:
--------------------------------------------------------------------------------
1 | from utiles.evalution_segmentaion import eval_semantic_segmentation
2 | from torch import optim
3 | from torch.autograd import Variable
4 | from datetime import datetime
5 | from torch.utils.data import DataLoader
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch as t
9 | import data
10 | from FCN import FCN8s,VGGNet  # FCN network
11 | 
12 | BATCH_SIZE = 2
13 | train_data = DataLoader(data.Cam_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
14 | val_data = DataLoader(data.Cam_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
15 | 
16 | 
17 | 
18 | def main():
19 | 
20 | 
21 |     vgg_model = VGGNet(requires_grad=True)
22 |     net = FCN8s(pretrained_net=vgg_model,n_class=12)
23 |     net = net.cuda()
24 |     criterion = nn.NLLLoss().cuda()
25 |     optimizer = optim.Adam(net.parameters(), lr=1e-4)
26 | 
27 |     eval_miou_list = []
28 |     best = [0]
29 |     print('-----------------------train-----------------------')
30 | 
31 | 
32 |     for epoch in range(30):
33 |         if epoch % 10 == 0 and epoch != 0:
34 |             for group in optimizer.param_groups:
35 |                 group['lr'] *= 0.5
36 | 
37 |         train_loss = 0
38 |         train_acc = 0
39 |         train_miou = 0
40 |         train_class_acc = 0
41 |         # global net  # added by me
42 |         net = net.train()
43 |         prec_time = datetime.now()
44 |         for i, sample in enumerate(train_data):
45 |             imgdata = Variable(sample['img'].cuda())
46 |             imglabel = Variable(sample['label'].long().cuda())
47 | 
48 |             optimizer.zero_grad()
49 |             out = net(imgdata)
50 |             out = F.log_softmax(out, dim=1)
51 | 
52 |             loss = criterion(out, imglabel)
53 | 
54 |             loss.backward()
55 |             optimizer.step()
56 |             train_loss = loss.item() + train_loss
57 | 
58 |             pre_label = out.max(dim=1)[1].data.cpu().numpy()
59 |             pre_label = [i for i in pre_label]
60 | 
61 |             true_label = imglabel.data.cpu().numpy()
62 |             true_label = [i for i in true_label]
63 | 
64 |             eval_metrix = eval_semantic_segmentation(pre_label, true_label)
65 |             train_acc = eval_metrix['mean_class_accuracy'] + train_acc
66 |             train_miou = eval_metrix['miou'] + train_miou
67 |             train_class_acc = train_class_acc + eval_metrix['class_accuracy']
68 | 
69 |         net = net.eval()
70 |         eval_loss = 0
71 |         eval_acc = 0
72 |         eval_miou = 0
73 |         eval_class_acc = 0
74 | 
75 |         for j, sample in enumerate(val_data):
76 |             valImg = Variable(sample['img'].cuda())
77 |             valLabel = Variable(sample['label'].long().cuda())
78 | 
79 |             out = net(valImg)
80 |             out = F.log_softmax(out, dim=1)
81 |             loss = criterion(out, valLabel)
82 |             eval_loss = loss.item() + eval_loss
83 |             pre_label = out.max(dim=1)[1].data.cpu().numpy()
84 |             pre_label = [i for i in pre_label]
85 | 
86 |             true_label = valLabel.data.cpu().numpy()
87 |             true_label = [i for i in true_label]
88 | 
89 |             eval_metrics = eval_semantic_segmentation(pre_label, true_label)
90 |             eval_acc = eval_metrics['mean_class_accuracy'] + eval_acc
91 |             eval_miou = eval_metrics['miou'] + eval_miou
92 |             eval_class_acc = eval_metrics['class_accuracy'] + eval_class_acc
93 | 
94 |         cur_time = datetime.now()
95 |         h, remainder = divmod((cur_time - prec_time).seconds, 3600)
96 |         m, s = divmod(remainder, 60)
97 | 
98 |         epoch_str = ('Epoch: {}, Train Loss: {:.5f}, Train Acc: {:.5f}, Train Mean IU: {:.5f}, Train_class_acc:{:} \
99 |         Valid Loss: {:.5f}, Valid Acc: {:.5f}, Valid Mean IU: {:.5f} ,Valid Class Acc:{:}'.format(
100 |             epoch, train_loss / len(train_data), train_acc / len(train_data), train_miou / len(train_data), train_class_acc / len(train_data),
101 |             eval_loss / len(val_data), eval_acc/len(val_data), eval_miou/len(val_data),eval_class_acc / len(val_data)))
102 |         time_str = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s)
103 |         print(epoch_str + time_str)
104 | 
105 |         if (max(best) <= eval_miou/len(val_data)):
106 |             best.append(eval_miou/len(val_data))
107 |             t.save(net.state_dict(), 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/logs/last.pth')  # 'xxx.pth'  # save the model
108 | 
109 | 
110 | if __name__ == '__main__':
111 |     main()
112 | 
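# Hedged sketch (not part of the original train.py): main() above halves the Adam
# learning rate every 10 epochs by editing optimizer.param_groups in place. The same
# step schedule can be expressed with torch.optim.lr_scheduler.StepLR; the dummy
# parameter below only stands in for net.parameters().
import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR

dummy_param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([dummy_param], lr=1e-4)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # lr *= 0.5 every 10 epochs

for epoch in range(30):
    # ... one epoch of training and validation, as in main() ...
    optimizer.step()   # placeholder for the real parameter updates
    scheduler.step()   # decay takes effect from epochs 10 and 20, matching the loop above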
113 | -------------------------------------------------------------------------------- /Pytorch-FCN/utiles/BilinearUpSampling.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import tensorflow as tf 3 | from keras.layers import * 4 | 5 | def resize_images_bilinear(X, height_factor=1, width_factor=1, target_height=None, target_width=None, data_format='default'): 6 | '''Resizes the images contained in a 4D tensor of shape 7 | - [batch, channels, height, width] (for 'channels_first' data_format) 8 | - [batch, height, width, channels] (for 'channels_last' data_format) 9 | by a factor of (height_factor, width_factor). Both factors should be 10 | positive integers. 11 | ''' 12 | if data_format == 'default': 13 | data_format = K.image_data_format() 14 | if data_format == 'channels_first': 15 | original_shape = K.int_shape(X) 16 | if target_height and target_width: 17 | new_shape = tf.constant(np.array((target_height, target_width)).astype('int32')) 18 | else: 19 | new_shape = tf.shape(X)[2:] 20 | new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) 21 | X = K.permute_dimensions(X, [0, 2, 3, 1]) 22 | X = tf.image.resize_bilinear(X, new_shape) 23 | X = K.permute_dimensions(X, [0, 3, 1, 2]) 24 | if target_height and target_width: 25 | X.set_shape((None, None, target_height, target_width)) 26 | else: 27 | X.set_shape((None, None, original_shape[2] * height_factor, original_shape[3] * width_factor)) 28 | return X 29 | elif data_format == 'channels_last': 30 | original_shape = K.int_shape(X) 31 | if target_height and target_width: 32 | new_shape = tf.constant(np.array((target_height, target_width)).astype('int32')) 33 | else: 34 | new_shape = tf.shape(X)[1:3] 35 | new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32')) 36 | X = tf.image.resize_bilinear(X, new_shape) 37 | if target_height and target_width: 38 | X.set_shape((None, target_height, target_width, None)) 39 | else: 40 | X.set_shape((None, original_shape[1] * height_factor, original_shape[2] * width_factor, None)) 41 | return X 42 | else: 43 | raise Exception('Invalid data_format: ' + data_format) 44 | 45 | class BilinearUpSampling2D(Layer): 46 | def __init__(self, size=(1, 1), target_size=None, data_format='default', **kwargs): 47 | if data_format == 'default': 48 | data_format = K.image_data_format() 49 | self.size = tuple(size) 50 | if target_size is not None: 51 | self.target_size = tuple(target_size) 52 | else: 53 | self.target_size = None 54 | assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {tf, th}' 55 | self.data_format = data_format 56 | self.input_spec = [InputSpec(ndim=4)] 57 | super(BilinearUpSampling2D, self).__init__(**kwargs) 58 | 59 | def compute_output_shape(self, input_shape): 60 | if self.data_format == 'channels_first': 61 | width = int(self.size[0] * input_shape[2] if input_shape[2] is not None else None) 62 | height = int(self.size[1] * input_shape[3] if input_shape[3] is not None else None) 63 | if self.target_size is not None: 64 | width = self.target_size[0] 65 | height = self.target_size[1] 66 | return (input_shape[0], 67 | input_shape[1], 68 | width, 69 | height) 70 | elif self.data_format == 'channels_last': 71 | width = int(self.size[0] * input_shape[1] if input_shape[1] is not None else None) 72 | height = int(self.size[1] * input_shape[2] if input_shape[2] is not None else None) 73 | if self.target_size is not None: 74 | width = self.target_size[0] 75 | 
height = self.target_size[1] 76 | return (input_shape[0], 77 | width, 78 | height, 79 | input_shape[3]) 80 | else: 81 | raise Exception('Invalid data_format: ' + self.data_format) 82 | 83 | def call(self, x, mask=None): 84 | if self.target_size is not None: 85 | return resize_images_bilinear(x, target_height=self.target_size[0], target_width=self.target_size[1], data_format=self.data_format) 86 | else: 87 | return resize_images_bilinear(x, height_factor=self.size[0], width_factor=self.size[1], data_format=self.data_format) 88 | 89 | def get_config(self): 90 | config = {'size': self.size, 'target_size': self.target_size} 91 | base_config = super(BilinearUpSampling2D, self).get_config() 92 | return dict(list(base_config.items()) + list(config.items())) 93 | -------------------------------------------------------------------------------- /Pytorch-FCN/data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import torch as t 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | from PIL import Image 8 | import torchvision.transforms as transforms 9 | from utiles import functional as ff 10 | import skimage.transform 11 | import numpy 12 | 13 | 14 | TRAIN_ROOT = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/train' #/xxx/CamVid/train #训练集的路径 15 | TRAIN_LABEL = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/train_labels' #/xxx/CamVid/train_labels #训练集的标签路径 16 | TEST_ROOT = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/val' #/xxx/CamVid/val #验证集路径 17 | TEST_LABEL = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/val_labels' #/xxx/CamVid/val_labels #验证集标签路径 18 | 19 | train_imgs = os.listdir(TRAIN_ROOT) 20 | train_imgs = [os.path.join(TRAIN_ROOT, img) for img in train_imgs] 21 | train_imgs.sort() 22 | 23 | train_labels = os.listdir(TRAIN_LABEL) 24 | train_labels = [os.path.join(TRAIN_LABEL, label) for label in train_labels] 25 | train_labels.sort() 26 | 27 | test_imgs = os.listdir(TEST_ROOT) 28 | test_imgs = [os.path.join(TEST_ROOT, img) for img in test_imgs] 29 | test_imgs.sort() 30 | 31 | test_labels = os.listdir(TEST_LABEL) 32 | test_labels = [os.path.join(TEST_LABEL, label) for label in test_labels] 33 | test_labels.sort() 34 | 35 | 36 | class FixedCrop(object): 37 | """ 38 | Args: 39 | img (PIL Image): Image to be cropped. 40 | i, j, h, w (int): Image position to be cropped 41 | padding (int or sequence, optional): Optional padding on each border 42 | of the image. Default is 0, i.e no padding. If a sequence of length 43 | 4 is provided, it is used to pad left, top, right, bottom borders 44 | respectively. 45 | 46 | Returns: 47 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 48 | """ 49 | 50 | def __init__(self, i, j, h, w, padding=0): 51 | self.i = i 52 | self.j = j 53 | self.h = h 54 | self.w = w 55 | self.padding = padding 56 | 57 | def __call__(self, img): 58 | """ 59 | Args: 60 | img (PIL Image): Image to be cropped. 61 | 62 | Returns: 63 | PIL Image: Cropped image. 
64 | """ 65 | if self.padding > 0: 66 | img = F.pad(img, self.padding) 67 | 68 | return ff.crop(img, self.i, self.j, self.h, self.w) 69 | 70 | 71 | pd_label_color = pd.read_csv('D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/class_dict.csv', sep=',') #/media/zjy/shuju/CamVid_2D/CamVid/class_dict.csv 72 | name_value = pd_label_color['name'].values # ndarray type 73 | num_class = len(name_value) 74 | 75 | colormap = [] 76 | for i in range(len(pd_label_color.index)): 77 | # 通过行号索引行数据 78 | tmp = pd_label_color.iloc[i] 79 | color = [] 80 | color.append(tmp['r']) 81 | color.append(tmp['g']) 82 | color.append(tmp['b']) 83 | colormap.append(color) 84 | 85 | 86 | def center_crop(data, label, crop_size): 87 | height, width = crop_size 88 | data, rect1 = ff.center_crop(data, (height, width)) 89 | label = FixedCrop(*rect1)(label) 90 | 91 | return data, label 92 | 93 | 94 | cm2lbl = np.zeros(256 ** 3) 95 | for i, cm in enumerate(colormap): 96 | cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i 97 | 98 | 99 | def image2label(img): 100 | data = np.array(img, dtype='int32') 101 | idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2] 102 | return np.array(cm2lbl[idx], dtype='int64') 103 | 104 | 105 | def img_transform(img, label, crop_size): 106 | img, label = center_crop(img, label, crop_size) 107 | label = numpy.array(label) 108 | label1 = skimage.transform.resize(label, (label.shape[0] // 2, label.shape[1] // 2), order=0, mode='reflect', 109 | preserve_range=True) 110 | 111 | label2 = skimage.transform.resize(label, (label.shape[0] // 4, label.shape[1] // 4), order=0, mode='reflect', 112 | preserve_range=True) 113 | 114 | label3 = skimage.transform.resize(label, (label.shape[0] // 8, label.shape[1] // 8), order=0, mode='reflect', 115 | preserve_range=True) 116 | 117 | label4 = skimage.transform.resize(label, (label.shape[0] // 16, label.shape[1] // 16), order=0, mode='reflect', 118 | preserve_range=True) 119 | 120 | label5 = skimage.transform.resize(label, (label.shape[0] // 32, label.shape[1] // 32), order=0, mode='reflect', 121 | preserve_range=True) 122 | 123 | label = Image.fromarray(label.astype('uint8')) 124 | label1 = Image.fromarray(label1.astype('uint8')) 125 | label2 = Image.fromarray(label2.astype('uint8')) 126 | label3 = Image.fromarray(label3.astype('uint8')) 127 | label4 = Image.fromarray(label4.astype('uint8')) 128 | label5 = Image.fromarray(label5.astype('uint8')) 129 | 130 | transform_img = transforms.Compose( 131 | [ 132 | transforms.ToTensor(), 133 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 134 | ] 135 | ) 136 | 137 | img = transform_img(img) 138 | label = image2label(label) 139 | label1 = image2label(label1) 140 | label2 = image2label(label2) 141 | label3 = image2label(label3) 142 | label4 = image2label(label4) 143 | label5 = image2label(label5) 144 | 145 | label = t.from_numpy(label) 146 | label1 = t.from_numpy(label1) 147 | label2 = t.from_numpy(label2) 148 | label3 = t.from_numpy(label3) 149 | label4 = t.from_numpy(label4) 150 | label5 = t.from_numpy(label5) 151 | 152 | return img, label, label1, label2, label3, label4, label5 153 | 154 | 155 | class CamvidDataset(Dataset): 156 | def __init__(self, train=True, crop_size=None, transform=None): 157 | self.train = train 158 | self.train_imgs = train_imgs 159 | self.train_labels = train_labels 160 | 161 | self.test_imgs = test_imgs 162 | self.test_labels = test_labels 163 | 164 | if self.train: 165 | self.imgs = self.train_imgs 166 | self.labels = self.train_labels 167 | else: 168 | self.imgs = 
self.test_imgs 169 | self.labels = self.test_labels 170 | 171 | self.crop_size = crop_size 172 | self.transforms = transform 173 | 174 | def __getitem__(self, index): 175 | img = self.imgs[index] 176 | label = self.labels[index] 177 | img = Image.open(img) 178 | label = Image.open(label).convert('RGB') 179 | 180 | img, label, label1, label2, label3, label4, label5 = self.transforms(img, label, self.crop_size) 181 | 182 | sample = {'img': img, 'label': label, 'label1': label1, 'label2': label2, 'label3': label3, 'label4': label4, 183 | 'label5': label5} 184 | 185 | return sample 186 | 187 | def __len__(self): 188 | return len(self.imgs) 189 | 190 | 191 | input_size = (352, 480) 192 | Cam_train = CamvidDataset(True, input_size, img_transform) 193 | Cam_val = CamvidDataset(False, input_size, img_transform) 194 | 195 | -------------------------------------------------------------------------------- /Pytorch-FCN/utiles/data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import torch as t 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | from PIL import Image 8 | import torchvision.transforms as transforms 9 | from utiles import functional as ff 10 | import skimage.transform 11 | import numpy 12 | 13 | 14 | TRAIN_ROOT = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/train' #/xxx/CamVid/train #训练集的路径 15 | TRAIN_LABEL = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/train_labels' #/xxx/CamVid/train_labels #训练集的标签路径 16 | TEST_ROOT = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/val' #/xxx/CamVid/val #验证集路径 17 | TEST_LABEL = 'D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/val_labels' #/xxx/CamVid/val_labels #验证集标签路径 18 | 19 | train_imgs = os.listdir(TRAIN_ROOT) 20 | train_imgs = [os.path.join(TRAIN_ROOT, img) for img in train_imgs] 21 | train_imgs.sort() 22 | 23 | train_labels = os.listdir(TRAIN_LABEL) 24 | train_labels = [os.path.join(TRAIN_LABEL, label) for label in train_labels] 25 | train_labels.sort() 26 | 27 | test_imgs = os.listdir(TEST_ROOT) 28 | test_imgs = [os.path.join(TEST_ROOT, img) for img in test_imgs] 29 | test_imgs.sort() 30 | 31 | test_labels = os.listdir(TEST_LABEL) 32 | test_labels = [os.path.join(TEST_LABEL, label) for label in test_labels] 33 | test_labels.sort() 34 | 35 | 36 | class FixedCrop(object): 37 | """ 38 | Args: 39 | img (PIL Image): Image to be cropped. 40 | i, j, h, w (int): Image position to be cropped 41 | padding (int or sequence, optional): Optional padding on each border 42 | of the image. Default is 0, i.e no padding. If a sequence of length 43 | 4 is provided, it is used to pad left, top, right, bottom borders 44 | respectively. 45 | 46 | Returns: 47 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 48 | """ 49 | 50 | def __init__(self, i, j, h, w, padding=0): 51 | self.i = i 52 | self.j = j 53 | self.h = h 54 | self.w = w 55 | self.padding = padding 56 | 57 | def __call__(self, img): 58 | """ 59 | Args: 60 | img (PIL Image): Image to be cropped. 61 | 62 | Returns: 63 | PIL Image: Cropped image. 
64 | """ 65 | if self.padding > 0: 66 | img = F.pad(img, self.padding) 67 | 68 | return ff.crop(img, self.i, self.j, self.h, self.w) 69 | 70 | 71 | pd_label_color = pd.read_csv('D:/机器学习/cvpaper/03语义分割-fcn论文原文及代码附件(1)/code/CamVid/class_dict.csv', sep=',') #/media/zjy/shuju/CamVid_2D/CamVid/class_dict.csv 72 | name_value = pd_label_color['name'].values # ndarray type 73 | num_class = len(name_value) 74 | 75 | colormap = [] 76 | for i in range(len(pd_label_color.index)): 77 | # 通过行号索引行数据 78 | tmp = pd_label_color.iloc[i] 79 | color = [] 80 | color.append(tmp['r']) 81 | color.append(tmp['g']) 82 | color.append(tmp['b']) 83 | colormap.append(color) 84 | 85 | 86 | def center_crop(data, label, crop_size): 87 | height, width = crop_size 88 | data, rect1 = ff.center_crop(data, (height, width)) 89 | label = FixedCrop(*rect1)(label) 90 | 91 | return data, label 92 | 93 | 94 | cm2lbl = np.zeros(256 ** 3) 95 | for i, cm in enumerate(colormap): 96 | cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i 97 | 98 | 99 | def image2label(img): 100 | data = np.array(img, dtype='int32') 101 | idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2] 102 | return np.array(cm2lbl[idx], dtype='int64') 103 | 104 | 105 | def img_transform(img, label, crop_size): 106 | img, label = center_crop(img, label, crop_size) 107 | label = numpy.array(label) 108 | label1 = skimage.transform.resize(label, (label.shape[0] // 2, label.shape[1] // 2), order=0, mode='reflect', 109 | preserve_range=True) 110 | 111 | label2 = skimage.transform.resize(label, (label.shape[0] // 4, label.shape[1] // 4), order=0, mode='reflect', 112 | preserve_range=True) 113 | 114 | label3 = skimage.transform.resize(label, (label.shape[0] // 8, label.shape[1] // 8), order=0, mode='reflect', 115 | preserve_range=True) 116 | 117 | label4 = skimage.transform.resize(label, (label.shape[0] // 16, label.shape[1] // 16), order=0, mode='reflect', 118 | preserve_range=True) 119 | 120 | label5 = skimage.transform.resize(label, (label.shape[0] // 32, label.shape[1] // 32), order=0, mode='reflect', 121 | preserve_range=True) 122 | 123 | label = Image.fromarray(label.astype('uint8')) 124 | label1 = Image.fromarray(label1.astype('uint8')) 125 | label2 = Image.fromarray(label2.astype('uint8')) 126 | label3 = Image.fromarray(label3.astype('uint8')) 127 | label4 = Image.fromarray(label4.astype('uint8')) 128 | label5 = Image.fromarray(label5.astype('uint8')) 129 | 130 | transform_img = transforms.Compose( 131 | [ 132 | transforms.ToTensor(), 133 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 134 | ] 135 | ) 136 | 137 | img = transform_img(img) 138 | label = image2label(label) 139 | label1 = image2label(label1) 140 | label2 = image2label(label2) 141 | label3 = image2label(label3) 142 | label4 = image2label(label4) 143 | label5 = image2label(label5) 144 | 145 | label = t.from_numpy(label) 146 | label1 = t.from_numpy(label1) 147 | label2 = t.from_numpy(label2) 148 | label3 = t.from_numpy(label3) 149 | label4 = t.from_numpy(label4) 150 | label5 = t.from_numpy(label5) 151 | 152 | return img, label, label1, label2, label3, label4, label5 153 | 154 | 155 | class CamvidDataset(Dataset): 156 | def __init__(self, train=True, crop_size=None, transform=None): 157 | self.train = train 158 | self.train_imgs = train_imgs 159 | self.train_labels = train_labels 160 | 161 | self.test_imgs = test_imgs 162 | self.test_labels = test_labels 163 | 164 | if self.train: 165 | self.imgs = self.train_imgs 166 | self.labels = self.train_labels 167 | else: 168 | self.imgs = 
self.test_imgs 169 | self.labels = self.test_labels 170 | 171 | self.crop_size = crop_size 172 | self.transforms = transform 173 | 174 | def __getitem__(self, index): 175 | img = self.imgs[index] 176 | label = self.labels[index] 177 | img = Image.open(img) 178 | label = Image.open(label).convert('RGB') 179 | 180 | img, label, label1, label2, label3, label4, label5 = self.transforms(img, label, self.crop_size) 181 | 182 | sample = {'img': img, 'label': label, 'label1': label1, 'label2': label2, 'label3': label3, 'label4': label4, 183 | 'label5': label5} 184 | 185 | return sample 186 | 187 | def __len__(self): 188 | return len(self.imgs) 189 | 190 | 191 | input_size = (352, 480) 192 | Cam_train = CamvidDataset(True, input_size, img_transform) 193 | Cam_val = CamvidDataset(False, input_size, img_transform) 194 | 195 | -------------------------------------------------------------------------------- /Pytorch-FCN/utiles/evalution_segmentaion.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import numpy as np 4 | import six 5 | 6 | 7 | def calc_semantic_segmentation_confusion(pred_labels, gt_labels): 8 | """Collect a confusion matrix. 9 | 10 | The number of classes :math:`n\_class` is 11 | :math:`max(pred\_labels, gt\_labels) + 1`, which is 12 | the maximum class id of the inputs added by one. 13 | 14 | Args: 15 | pred_labels (iterable of numpy.ndarray): A collection of predicted 16 | labels. The shape of a label array 17 | is :math:`(H, W)`. :math:`H` and :math:`W` 18 | are height and width of the label. 19 | gt_labels (iterable of numpy.ndarray): A collection of ground 20 | truth labels. The shape of a ground truth label array is 21 | :math:`(H, W)`, and its corresponding prediction label should 22 | have the same shape. 23 | A pixel with value :obj:`-1` will be ignored during evaluation. 24 | 25 | Returns: 26 | numpy.ndarray: 27 | A confusion matrix. Its shape is :math:`(n\_class, n\_class)`. 28 | The :math:`(i, j)` th element corresponds to the number of pixels 29 | that are labeled as class :math:`i` by the ground truth and 30 | class :math:`j` by the prediction. 31 | 32 | """ 33 | pred_labels = iter(pred_labels) 34 | 35 | gt_labels = iter(gt_labels) 36 | 37 | n_class = 12 38 | confusion = np.zeros((n_class, n_class), dtype=np.int64) 39 | for pred_label, gt_label in six.moves.zip(pred_labels, gt_labels): 40 | if pred_label.ndim != 2 or gt_label.ndim != 2: 41 | raise ValueError('ndim of labels should be two.') 42 | if pred_label.shape != gt_label.shape: 43 | raise ValueError('Shape of ground truth and prediction should' 44 | ' be same.') 45 | pred_label = pred_label.flatten() 46 | gt_label = gt_label.flatten() 47 | 48 | # Dynamically expand the confusion matrix if necessary. 49 | lb_max = np.max((pred_label, gt_label)) 50 | # print(lb_max) 51 | if lb_max >= n_class: 52 | expanded_confusion = np.zeros( 53 | (lb_max + 1, lb_max + 1), dtype=np.int64) 54 | expanded_confusion[0:n_class, 0:n_class] = confusion 55 | 56 | n_class = lb_max + 1 57 | confusion = expanded_confusion 58 | 59 | # Count statistics from valid pixels. 极度巧妙 × class_nums 正好使得每个ij能够对应. 60 | mask = gt_label >= 0 61 | confusion += np.bincount( 62 | n_class * gt_label[mask].astype(int) + 63 | pred_label[mask], minlength=n_class ** 2).reshape((n_class, n_class)) 64 | 65 | for iter_ in (pred_labels, gt_labels): 66 | # This code assumes any iterator does not contain None as its items. 
67 | if next(iter_, None) is not None: 68 | raise ValueError('Length of input iterables need to be same') 69 | 70 | # confusion = np.delete(confusion, 11, axis=0) 71 | # confusion = np.delete(confusion, 11, axis=1) 72 | return confusion 73 | 74 | 75 | def calc_semantic_segmentation_iou(confusion): 76 | """Calculate Intersection over Union with a given confusion matrix. 77 | 78 | The definition of Intersection over Union (IoU) is as follows, 79 | where :math:`N_{ij}` is the number of pixels 80 | that are labeled as class :math:`i` by the ground truth and 81 | class :math:`j` by the prediction. 82 | 83 | * :math:`\\text{IoU of the i-th class} = \ 84 | \\frac{N_{ii}}{\\sum_{j=1}^k N_{ij} + \\sum_{j=1}^k N_{ji} - N_{ii}}` 85 | 86 | Args: 87 | confusion (numpy.ndarray): A confusion matrix. Its shape is 88 | :math:`(n\_class, n\_class)`. 89 | The :math:`(i, j)` th element corresponds to the number of pixels 90 | that are labeled as class :math:`i` by the ground truth and 91 | class :math:`j` by the prediction. 92 | 93 | Returns: 94 | numpy.ndarray: 95 | An array of IoUs for the :math:`n\_class` classes. Its shape is 96 | :math:`(n\_class,)`. 97 | 98 | """ 99 | iou_denominator = (confusion.sum(axis=1) + confusion.sum(axis=0) 100 | - np.diag(confusion)) 101 | iou = np.diag(confusion) / iou_denominator 102 | return iou[:-1] 103 | # return iou 104 | 105 | 106 | def eval_semantic_segmentation(pred_labels, gt_labels): 107 | """Evaluate metrics used in Semantic Segmentation. 108 | 109 | This function calculates Intersection over Union (IoU), Pixel Accuracy 110 | and Class Accuracy for the task of semantic segmentation. 111 | 112 | The definition of metrics calculated by this function is as follows, 113 | where :math:`N_{ij}` is the number of pixels 114 | that are labeled as class :math:`i` by the ground truth and 115 | class :math:`j` by the prediction. 116 | 117 | * :math:`\\text{IoU of the i-th class} = \ 118 | \\frac{N_{ii}}{\\sum_{j=1}^k N_{ij} + \\sum_{j=1}^k N_{ji} - N_{ii}}` 119 | * :math:`\\text{mIoU} = \\frac{1}{k} \ 120 | \\sum_{i=1}^k \ 121 | \\frac{N_{ii}}{\\sum_{j=1}^k N_{ij} + \\sum_{j=1}^k N_{ji} - N_{ii}}` 122 | * :math:`\\text{Pixel Accuracy} = \ 123 | \\frac \ 124 | {\\sum_{i=1}^k N_{ii}} \ 125 | {\\sum_{i=1}^k \\sum_{j=1}^k N_{ij}}` 126 | * :math:`\\text{Class Accuracy} = \ 127 | \\frac{N_{ii}}{\\sum_{j=1}^k N_{ij}}` 128 | * :math:`\\text{Mean Class Accuracy} = \\frac{1}{k} \ 129 | \\sum_{i=1}^k \ 130 | \\frac{N_{ii}}{\\sum_{j=1}^k N_{ij}}` 131 | 132 | The more detailed description of the above metrics can be found in a 133 | review on semantic segmentation [#]_. 134 | 135 | The number of classes :math:`n\_class` is 136 | :math:`max(pred\_labels, gt\_labels) + 1`, which is 137 | the maximum class id of the inputs added by one. 138 | 139 | .. [#] Alberto Garcia-Garcia, Sergio Orts-Escolano, Sergiu Oprea, \ 140 | Victor Villena-Martinez, Jose Garcia-Rodriguez. \ 141 | `A Review on Deep Learning Techniques Applied to Semantic Segmentation \ 142 | `_. arXiv 2017. 143 | 144 | Args: 145 | pred_labels (iterable of numpy.ndarray): A collection of predicted 146 | labels. The shape of a label array 147 | is :math:`(H, W)`. :math:`H` and :math:`W` 148 | are height and width of the label. 149 | For example, this is a list of labels 150 | :obj:`[label_0, label_1, ...]`, where 151 | :obj:`label_i.shape = (H_i, W_i)`. 152 | gt_labels (iterable of numpy.ndarray): A collection of ground 153 | truth labels. 
The shape of a ground truth label array is 154 | :math:`(H, W)`, and its corresponding prediction label should 155 | have the same shape. 156 | A pixel with value :obj:`-1` will be ignored during evaluation. 157 | 158 | Returns: 159 | dict: 160 | 161 | The keys, value-types and the description of the values are listed 162 | below. 163 | 164 | * **iou** (*numpy.ndarray*): An array of IoUs for the \ 165 | :math:`n\_class` classes. Its shape is :math:`(n\_class,)`. 166 | * **miou** (*float*): The average of IoUs over classes. 167 | * **pixel_accuracy** (*float*): The computed pixel accuracy. 168 | * **class_accuracy** (*numpy.ndarray*): An array of class accuracies \ 169 | for the :math:`n\_class` classes. \ 170 | Its shape is :math:`(n\_class,)`. 171 | * **mean_class_accuracy** (*float*): The average of class accuracies. 172 | 173 | """ 174 | # Evaluation code is based on 175 | # https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/ 176 | # score.py#L37 177 | confusion = calc_semantic_segmentation_confusion( 178 | pred_labels, gt_labels) 179 | iou = calc_semantic_segmentation_iou(confusion) 180 | pixel_accuracy = np.diag(confusion).sum() / confusion.sum() 181 | class_accuracy = np.diag(confusion) / (np.sum(confusion, axis=1) + 1e-10) 182 | 183 | return {'iou': iou, 'miou': np.nanmean(iou), 184 | 'pixel_accuracy': pixel_accuracy, 185 | 'class_accuracy': class_accuracy, 186 | 'mean_class_accuracy': np.nanmean(class_accuracy[:-1])} 187 | # 'mean_class_accuracy': np.nanmean(class_accuracy)} -------------------------------------------------------------------------------- /Pytorch-FCN/FCN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torchvision import models 9 | from torchvision.models.vgg import VGG 10 | 11 | 12 | class FCN32s(nn.Module): 13 | 14 | def __init__(self, pretrained_net, n_class): 15 | super().__init__() 16 | self.n_class = n_class 17 | self.pretrained_net = pretrained_net 18 | self.relu = nn.ReLU(inplace=True) 19 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1) 20 | self.bn1 = nn.BatchNorm2d(512) 21 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1) 22 | self.bn2 = nn.BatchNorm2d(256) 23 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1) 24 | self.bn3 = nn.BatchNorm2d(128) 25 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1) 26 | self.bn4 = nn.BatchNorm2d(64) 27 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1) 28 | self.bn5 = nn.BatchNorm2d(32) 29 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 30 | 31 | def forward(self, x): 32 | output = self.pretrained_net(x) 33 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 34 | 35 | score = self.bn1(self.relu(self.deconv1(x5))) # size=(N, 512, x.H/16, x.W/16) 36 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8) 37 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 38 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 39 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 40 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 41 | 42 | return score # size=(N, n_class, x.H/1, x.W/1) 
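# A hedged sanity-check sketch (not part of the original file): with kernel_size=3,
# stride=2, padding=1 and no output_padding, ConvTranspose2d maps a spatial size H to
# 2*H - 1, whereas output_padding=1 (used by FCN16s/FCN8s/FCNs below) gives exactly 2*H.
# A quick way to verify whether a variant restores the input resolution, assuming a
# VGGNet instance as defined later in this file:
#
#   vgg = VGGNet(requires_grad=False)
#   fcn = FCN32s(pretrained_net=vgg, n_class=12)
#   out = fcn(torch.zeros(1, 3, 352, 480))
#   print(out.shape)  # compare against the intended (1, 12, 352, 480)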
43 | 44 | 45 | class FCN16s(nn.Module): 46 | 47 | def __init__(self, pretrained_net, n_class): 48 | super().__init__() 49 | self.n_class = n_class 50 | self.pretrained_net = pretrained_net 51 | self.relu = nn.ReLU(inplace=True) 52 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 53 | self.bn1 = nn.BatchNorm2d(512) 54 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 55 | self.bn2 = nn.BatchNorm2d(256) 56 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 57 | self.bn3 = nn.BatchNorm2d(128) 58 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 59 | self.bn4 = nn.BatchNorm2d(64) 60 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 61 | self.bn5 = nn.BatchNorm2d(32) 62 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 63 | 64 | def forward(self, x): 65 | output = self.pretrained_net(x) 66 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 67 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16) 68 | 69 | score = self.relu(self.deconv1(x5)) # size=(N, 512, x.H/16, x.W/16) 70 | score = self.bn1(score + x4) # element-wise add, size=(N, 512, x.H/16, x.W/16) 71 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8) 72 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 73 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 74 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 75 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 76 | 77 | return score # size=(N, n_class, x.H/1, x.W/1) 78 | 79 | 80 | class FCN8s(nn.Module): 81 | 82 | def __init__(self, pretrained_net, n_class): 83 | super().__init__() 84 | self.n_class = n_class 85 | self.pretrained_net = pretrained_net 86 | self.relu = nn.ReLU(inplace=True) 87 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 88 | self.bn1 = nn.BatchNorm2d(512) 89 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 90 | self.bn2 = nn.BatchNorm2d(256) 91 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 92 | self.bn3 = nn.BatchNorm2d(128) 93 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 94 | self.bn4 = nn.BatchNorm2d(64) 95 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 96 | self.bn5 = nn.BatchNorm2d(32) 97 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 98 | 99 | def forward(self, x): 100 | output = self.pretrained_net(x) 101 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 102 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16) 103 | x3 = output['x3'] # size=(N, 256, x.H/8, x.W/8) 104 | 105 | score = self.relu(self.deconv1(x5)) # size=(N, 512, x.H/16, x.W/16) 106 | score = self.bn1(score + x4) # element-wise add, size=(N, 512, x.H/16, x.W/16) 107 | score = self.relu(self.deconv2(score)) # size=(N, 256, x.H/8, x.W/8) 108 | score = self.bn2(score + x3) # element-wise add, size=(N, 256, x.H/8, x.W/8) 109 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 110 | score = 
self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 111 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 112 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 113 | 114 | return score # size=(N, n_class, x.H/1, x.W/1) 115 | 116 | 117 | class FCNs(nn.Module): 118 | 119 | def __init__(self, pretrained_net, n_class): 120 | super().__init__() 121 | self.n_class = n_class 122 | self.pretrained_net = pretrained_net 123 | self.relu = nn.ReLU(inplace=True) 124 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 125 | self.bn1 = nn.BatchNorm2d(512) 126 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 127 | self.bn2 = nn.BatchNorm2d(256) 128 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 129 | self.bn3 = nn.BatchNorm2d(128) 130 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 131 | self.bn4 = nn.BatchNorm2d(64) 132 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 133 | self.bn5 = nn.BatchNorm2d(32) 134 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 135 | 136 | def forward(self, x): 137 | output = self.pretrained_net(x) 138 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 139 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16) 140 | x3 = output['x3'] # size=(N, 256, x.H/8, x.W/8) 141 | x2 = output['x2'] # size=(N, 128, x.H/4, x.W/4) 142 | x1 = output['x1'] # size=(N, 64, x.H/2, x.W/2) 143 | 144 | score = self.bn1(self.relu(self.deconv1(x5))) # size=(N, 512, x.H/16, x.W/16) 145 | score = score + x4 # element-wise add, size=(N, 512, x.H/16, x.W/16) 146 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8) 147 | score = score + x3 # element-wise add, size=(N, 256, x.H/8, x.W/8) 148 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 149 | score = score + x2 # element-wise add, size=(N, 128, x.H/4, x.W/4) 150 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 151 | score = score + x1 # element-wise add, size=(N, 64, x.H/2, x.W/2) 152 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 153 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 154 | 155 | return score # size=(N, n_class, x.H/1, x.W/1) 156 | 157 | 158 | class VGGNet(VGG): 159 | def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False): 160 | super().__init__(make_layers(cfg[model])) 161 | self.ranges = ranges[model] 162 | 163 | if pretrained: 164 | exec("self.load_state_dict(models.%s(pretrained=True).state_dict())" % model) 165 | 166 | if not requires_grad: 167 | for param in super().parameters(): 168 | param.requires_grad = False 169 | 170 | if remove_fc: # delete redundant fully-connected layer params, can save memory 171 | del self.classifier 172 | 173 | if show_params: 174 | for name, param in self.named_parameters(): 175 | print(name, param.size()) 176 | 177 | def forward(self, x): 178 | output = {} 179 | 180 | # get the output of each maxpooling layer (5 maxpool in VGG net) 181 | for idx in range(len(self.ranges)): 182 | for layer in range(self.ranges[idx][0], self.ranges[idx][1]): 183 | x = self.features[layer](x) 184 | output["x%d"%(idx+1)] = x 185 | 186 | return output 187 | 188 | 189 
| ranges = { 190 | 'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)), 191 | 'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)), 192 | 'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)), 193 | 'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37)) 194 | } 195 | 196 | # cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 197 | cfg = { 198 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 199 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 200 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 201 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 202 | } 203 | 204 | def make_layers(cfg, batch_norm=False): 205 | layers = [] 206 | in_channels = 3 207 | for v in cfg: 208 | if v == 'M': 209 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 210 | else: 211 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 212 | if batch_norm: 213 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 214 | else: 215 | layers += [conv2d, nn.ReLU(inplace=True)] 216 | in_channels = v 217 | return nn.Sequential(*layers) 218 | 219 | 220 | -------------------------------------------------------------------------------- /Pytorch-FCN/utiles/functional.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | from PIL import Image, ImageOps, ImageEnhance 5 | 6 | # from mxnet import nd 7 | 8 | try: 9 | import accimage 10 | except ImportError: 11 | accimage = None 12 | import numpy as np 13 | import numbers 14 | import collections 15 | 16 | 17 | def _is_pil_image(img): 18 | if accimage is not None: 19 | return isinstance(img, (Image.Image, accimage.Image)) 20 | else: 21 | return isinstance(img, Image.Image) 22 | 23 | 24 | def _is_tensor_image(img): 25 | return torch.is_tensor(img) and img.ndimension() == 3 26 | 27 | 28 | # def _is_ndarray_image(img): 29 | # return isinstance(img, nd.NDArray) and img.ndim == 3 30 | 31 | 32 | def _is_numpy_image(img): 33 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 34 | 35 | 36 | def to_tensor(pic): 37 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 38 | See ``ToTensor`` for more details. 39 | Args: 40 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 41 | Returns: 42 | Tensor: Converted image. 43 | """ 44 | if not (_is_pil_image(pic) or _is_numpy_image(pic)): 45 | raise TypeError('pic should be PIL Image or ndarray. 
Got {}'.format( 46 | type(pic))) 47 | 48 | if isinstance(pic, np.ndarray): 49 | # handle numpy array 50 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 51 | # backward compatibility 52 | return img.float().div(255) 53 | 54 | if accimage is not None and isinstance(pic, accimage.Image): 55 | nppic = np.zeros( 56 | [pic.channels, pic.height, pic.width], dtype=np.float32) 57 | pic.copyto(nppic) 58 | return torch.from_numpy(nppic) 59 | 60 | # handle PIL Image 61 | if pic.mode == 'I': 62 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 63 | elif pic.mode == 'I;16': 64 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 65 | else: 66 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 67 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 68 | if pic.mode == 'YCbCr': 69 | nchannel = 3 70 | elif pic.mode == 'I;16': 71 | nchannel = 1 72 | else: 73 | nchannel = len(pic.mode) 74 | img = img.view(pic.size[1], pic.size[0], nchannel) 75 | # put it from HWC to CHW format 76 | # yikes, this transpose takes 80% of the loading time/CPU 77 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 78 | if isinstance(img, torch.ByteTensor): 79 | return img.float().div(255) 80 | else: 81 | return img 82 | 83 | 84 | # def to_array(pic): 85 | # """Convert a ``PIL Image`` or ``numpy.ndarray`` to nd.array. 86 | # See ``ToArray`` for more details. 87 | # Args: 88 | # pic (PIL Image or numpy.ndarray): Image to be converted to nd.array. 89 | # Returns: 90 | # Array: Converted image. 91 | # """ 92 | # if not (_is_pil_image(pic) or _is_numpy_image(pic)): 93 | # raise TypeError('pic should be PIL Image or ndarray. Got {}'.format( 94 | # type(pic))) 95 | # 96 | # if isinstance(pic, np.ndarray): 97 | # # handle numpy array 98 | # img = nd.array(pic.transpose((2, 0, 1))) 99 | # # backward compatibility 100 | # return img.astype(np.float32) / 255 101 | # 102 | # if accimage is not None and isinstance(pic, accimage.Image): 103 | # nppic = np.zeros( 104 | # [pic.channels, pic.height, pic.width], dtype=np.float32) 105 | # pic.copyto(nppic) 106 | # return nd.array(nppic) 107 | # 108 | # # handle PIL Image 109 | # if pic.mode == 'I': 110 | # img = nd.array(pic, dtype=np.int32) 111 | # elif pic.mode == 'I;16': 112 | # img = nd.array(pic, dtype=np.int16) 113 | # else: 114 | # img = nd.array(pic, dtype=np.float32) 115 | # # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 116 | # if pic.mode == 'YCbCr': 117 | # nchannel = 3 118 | # elif pic.mode == 'I;16': 119 | # nchannel = 1 120 | # else: 121 | # nchannel = len(pic.mode) 122 | # img = img.reshape((img.shape[0], img.shape[1], nchannel)) 123 | # # put it from HWC to CHW format 124 | # # yikes, this transpose takes 80% of the loading time/CPU 125 | # img = img.transpose((2, 0, 1)) 126 | # return img.astype(np.float32) / 255 127 | 128 | 129 | def to_pil_image(pic, mode=None): 130 | """Convert a tensor or an ndarray to PIL Image. 131 | See :class:`~torchvision.transforms.ToPIlImage` for more details. 132 | Args: 133 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 134 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 135 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 136 | Returns: 137 | PIL Image: Image converted to PIL Image. 138 | """ 139 | if not (_is_numpy_image(pic) or _is_tensor_image(pic)): 140 | raise TypeError('pic should be Tensor or ndarray. 
Got {}.'.format( 141 | type(pic))) 142 | 143 | npimg = pic 144 | if isinstance(pic, torch.FloatTensor): 145 | pic = pic.mul(255).byte() 146 | if torch.is_tensor(pic): 147 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 148 | 149 | if not isinstance(npimg, np.ndarray): 150 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 151 | 'not {}'.format(type(npimg))) 152 | 153 | if npimg.shape[2] == 1: 154 | expected_mode = None 155 | npimg = npimg[:, :, 0] 156 | if npimg.dtype == np.uint8: 157 | expected_mode = 'L' 158 | if npimg.dtype == np.int16: 159 | expected_mode = 'I;16' 160 | if npimg.dtype == np.int32: 161 | expected_mode = 'I' 162 | elif npimg.dtype == np.float32: 163 | expected_mode = 'F' 164 | if mode is not None and mode != expected_mode: 165 | raise ValueError( 166 | "Incorrect mode ({}) supplied for input type {}. Should be {}" 167 | .format(mode, np.dtype, expected_mode)) 168 | mode = expected_mode 169 | 170 | elif npimg.shape[2] == 4: 171 | permitted_4_channel_modes = ['RGBA', 'CMYK'] 172 | if mode is not None and mode not in permitted_4_channel_modes: 173 | raise ValueError( 174 | "Only modes {} are supported for 4D inputs".format( 175 | permitted_4_channel_modes)) 176 | 177 | if mode is None and npimg.dtype == np.uint8: 178 | mode = 'RGBA' 179 | else: 180 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 181 | if mode is not None and mode not in permitted_3_channel_modes: 182 | raise ValueError( 183 | "Only modes {} are supported for 3D inputs".format( 184 | permitted_3_channel_modes)) 185 | if mode is None and npimg.dtype == np.uint8: 186 | mode = 'RGB' 187 | 188 | if mode is None: 189 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 190 | 191 | return Image.fromarray(npimg, mode=mode) 192 | 193 | 194 | def normalize(tensor, mean, std): 195 | """Normalize a tensor image with mean and standard deviation. 196 | See ``Normalize`` for more details. 197 | Args: 198 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 199 | mean (sequence): Sequence of means for each channel. 200 | std (sequence): Sequence of standard deviations for each channely. 201 | Returns: 202 | Tensor: Normalized Tensor image. 203 | """ 204 | if not _is_tensor_image(tensor): 205 | raise TypeError('tensor is not a torch image.') 206 | # TODO: make efficient 207 | for t, m, s in zip(tensor, mean, std): 208 | t.sub_(m).div_(s) 209 | return tensor 210 | 211 | 212 | def resize(img, size, interpolation=Image.BILINEAR): 213 | """Resize the input PIL Image to the given size. 214 | Args: 215 | img (PIL Image): Image to be resized. 216 | size (sequence or int): Desired output size. If size is a sequence like 217 | (h, w), the output size will be matched to this. If size is an int, 218 | the smaller edge of the image will be matched to this number maintaing 219 | the aspect ratio. i.e, if height > width, then image will be rescaled to 220 | (size * height / width, size) 221 | interpolation (int, optional): Desired interpolation. Default is 222 | ``PIL.Image.BILINEAR`` 223 | Returns: 224 | PIL Image: Resized image. 225 | """ 226 | if not _is_pil_image(img): 227 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 228 | if not (isinstance(size, int) or 229 | (isinstance(size, collections.Iterable) and len(size) == 2)): 230 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 231 | 232 | if isinstance(size, int): 233 | w, h = img.size 234 | if (w <= h and w == size) or (h <= w and h == size): 235 | return img 236 | if w < h: 237 | ow = size 238 | oh = int(size * h / w) 239 | return img.resize((ow, oh), interpolation) 240 | else: 241 | oh = size 242 | ow = int(size * w / h) 243 | return img.resize((ow, oh), interpolation) 244 | else: 245 | return img.resize(size[::-1], interpolation) 246 | 247 | 248 | def pad(img, padding, fill=0): 249 | """Pad the given PIL Image on all sides with the given "pad" value. 250 | Args: 251 | img (PIL Image): Image to be padded. 252 | padding (int or tuple): Padding on each border. If a single int is provided this 253 | is used to pad all borders. If tuple of length 2 is provided this is the padding 254 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 255 | this is the padding for the left, top, right and bottom borders 256 | respectively. 257 | fill: Pixel fill value. Default is 0. If a tuple of 258 | length 3, it is used to fill R, G, B channels respectively. 259 | Returns: 260 | PIL Image: Padded image. 261 | """ 262 | if not _is_pil_image(img): 263 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 264 | 265 | if not isinstance(padding, (numbers.Number, tuple)): 266 | raise TypeError('Got inappropriate padding arg') 267 | if not isinstance(fill, (numbers.Number, str, tuple)): 268 | raise TypeError('Got inappropriate fill arg') 269 | 270 | if isinstance(padding, 271 | collections.Sequence) and len(padding) not in [2, 4]: 272 | raise ValueError( 273 | "Padding must be an int or a 2, or 4 element tuple, not a " + 274 | "{} element tuple".format(len(padding))) 275 | 276 | return ImageOps.expand(img, border=padding, fill=fill) 277 | 278 | 279 | def crop(img, i, j, h, w): 280 | """Crop the given PIL Image. 281 | Args: 282 | img (PIL Image): Image to be cropped. 283 | i: Upper pixel coordinate. 284 | j: Left pixel coordinate. 285 | h: Height of the cropped image. 286 | w: Width of the cropped image. 287 | Returns: 288 | PIL Image: Cropped image. 289 | """ 290 | if not _is_pil_image(img): 291 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 292 | 293 | return img.crop((j, i, j + w, i + h)) 294 | 295 | 296 | def center_crop(img, output_size): 297 | if isinstance(output_size, numbers.Number): 298 | output_size = (int(output_size), int(output_size)) 299 | w, h = img.size 300 | th, tw = output_size 301 | i = int(round((h - th) / 2.)) 302 | j = int(round((w - tw) / 2.)) 303 | return crop(img, i, j, th, tw), (i, j, th, tw) 304 | 305 | 306 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 307 | """Crop the given PIL Image and resize it to desired size. 308 | Notably used in RandomResizedCrop. 309 | Args: 310 | img (PIL Image): Image to be cropped. 311 | i: Upper pixel coordinate. 312 | j: Left pixel coordinate. 313 | h: Height of the cropped image. 314 | w: Width of the cropped image. 315 | size (sequence or int): Desired output size. Same semantics as ``scale``. 316 | interpolation (int, optional): Desired interpolation. Default is 317 | ``PIL.Image.BILINEAR``. 318 | Returns: 319 | PIL Image: Cropped image. 
320 | """ 321 | assert _is_pil_image(img), 'img should be PIL Image' 322 | img = crop(img, i, j, h, w) 323 | img = resize(img, size, interpolation) 324 | return img 325 | 326 | 327 | def hflip(img): 328 | """Horizontally flip the given PIL Image. 329 | Args: 330 | img (PIL Image): Image to be flipped. 331 | Returns: 332 | PIL Image: Horizontall flipped image. 333 | """ 334 | if not _is_pil_image(img): 335 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 336 | 337 | return img.transpose(Image.FLIP_LEFT_RIGHT) 338 | 339 | 340 | def vflip(img): 341 | """Vertically flip the given PIL Image. 342 | Args: 343 | img (PIL Image): Image to be flipped. 344 | Returns: 345 | PIL Image: Vertically flipped image. 346 | """ 347 | if not _is_pil_image(img): 348 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 349 | 350 | return img.transpose(Image.FLIP_TOP_BOTTOM) 351 | 352 | 353 | def five_crop(img, size): 354 | """Crop the given PIL Image into four corners and the central crop. 355 | .. Note:: 356 | This transform returns a tuple of images and there may be a 357 | mismatch in the number of inputs and targets your ``Dataset`` returns. 358 | Args: 359 | size (sequence or int): Desired output size of the crop. If size is an 360 | int instead of sequence like (h, w), a square crop (size, size) is 361 | made. 362 | Returns: 363 | tuple: tuple (tl, tr, bl, br, center) corresponding top left, 364 | top right, bottom left, bottom right and center crop. 365 | """ 366 | if isinstance(size, numbers.Number): 367 | size = (int(size), int(size)) 368 | else: 369 | assert len( 370 | size) == 2, "Please provide only two dimensions (h, w) for size." 371 | 372 | w, h = img.size 373 | crop_h, crop_w = size 374 | if crop_w > w or crop_h > h: 375 | raise ValueError( 376 | "Requested crop size {} is bigger than input size {}".format( 377 | size, (h, w))) 378 | tl = img.crop((0, 0, crop_w, crop_h)) 379 | tr = img.crop((w - crop_w, 0, w, crop_h)) 380 | bl = img.crop((0, h - crop_h, crop_w, h)) 381 | br = img.crop((w - crop_w, h - crop_h, w, h)) 382 | center = center_crop(img, (crop_h, crop_w)) 383 | return (tl, tr, bl, br, center) 384 | 385 | 386 | def ten_crop(img, size, vertical_flip=False): 387 | """Crop the given PIL Image into four corners and the central crop plus the 388 | flipped version of these (horizontal flipping is used by default). 389 | .. Note:: 390 | This transform returns a tuple of images and there may be a 391 | mismatch in the number of inputs and targets your ``Dataset`` returns. 392 | Args: 393 | size (sequence or int): Desired output size of the crop. If size is an 394 | int instead of sequence like (h, w), a square crop (size, size) is 395 | made. 396 | vertical_flip (bool): Use vertical flipping instead of horizontal 397 | Returns: 398 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, 399 | br_flip, center_flip) corresponding top left, top right, 400 | bottom left, bottom right and center crop and same for the 401 | flipped image. 402 | """ 403 | if isinstance(size, numbers.Number): 404 | size = (int(size), int(size)) 405 | else: 406 | assert len( 407 | size) == 2, "Please provide only two dimensions (h, w) for size." 408 | 409 | first_five = five_crop(img, size) 410 | 411 | if vertical_flip: 412 | img = vflip(img) 413 | else: 414 | img = hflip(img) 415 | 416 | second_five = five_crop(img, size) 417 | return first_five + second_five 418 | 419 | 420 | def adjust_brightness(img, brightness_factor): 421 | """Adjust brightness of an Image. 
422 |     Args:
423 |         img (PIL Image): PIL Image to be adjusted.
424 |         brightness_factor (float):  How much to adjust the brightness. Can be
425 |             any non-negative number. 0 gives a black image, 1 gives the
426 |             original image while 2 increases the brightness by a factor of 2.
427 |     Returns:
428 |         PIL Image: Brightness adjusted image.
429 |     """
430 |     if not _is_pil_image(img):
431 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
432 |
433 |     enhancer = ImageEnhance.Brightness(img)
434 |     img = enhancer.enhance(brightness_factor)
435 |     return img
436 |
437 |
438 | def adjust_contrast(img, contrast_factor):
439 |     """Adjust contrast of an Image.
440 |     Args:
441 |         img (PIL Image): PIL Image to be adjusted.
442 |         contrast_factor (float): How much to adjust the contrast. Can be any
443 |             non-negative number. 0 gives a solid gray image, 1 gives the
444 |             original image while 2 increases the contrast by a factor of 2.
445 |     Returns:
446 |         PIL Image: Contrast adjusted image.
447 |     """
448 |     if not _is_pil_image(img):
449 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
450 |
451 |     enhancer = ImageEnhance.Contrast(img)
452 |     img = enhancer.enhance(contrast_factor)
453 |     return img
454 |
455 |
456 | def adjust_saturation(img, saturation_factor):
457 |     """Adjust color saturation of an image.
458 |     Args:
459 |         img (PIL Image): PIL Image to be adjusted.
460 |         saturation_factor (float):  How much to adjust the saturation. 0 will
461 |             give a black and white image, 1 will give the original image while
462 |             2 will enhance the saturation by a factor of 2.
463 |     Returns:
464 |         PIL Image: Saturation adjusted image.
465 |     """
466 |     if not _is_pil_image(img):
467 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
468 |
469 |     enhancer = ImageEnhance.Color(img)
470 |     img = enhancer.enhance(saturation_factor)
471 |     return img
472 |
473 |
474 | def adjust_hue(img, hue_factor):
475 |     """Adjust hue of an image.
476 |     The image hue is adjusted by converting the image to HSV and
477 |     cyclically shifting the intensities in the hue channel (H).
478 |     The image is then converted back to original image mode.
479 |     `hue_factor` is the amount of shift in H channel and must be in the
480 |     interval `[-0.5, 0.5]`.
481 |     See https://en.wikipedia.org/wiki/Hue for more details on Hue.
482 |     Args:
483 |         img (PIL Image): PIL Image to be adjusted.
484 |         hue_factor (float): How much to shift the hue channel. Should be in
485 |             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
486 |             HSV space in positive and negative direction respectively.
487 |             0 means no shift. Therefore, both -0.5 and 0.5 will give an image
488 |             with complementary colors while 0 gives the original image.
489 |     Returns:
490 |         PIL Image: Hue adjusted image.
491 |     """
492 |     if not (-0.5 <= hue_factor <= 0.5):
493 |         raise ValueError(
494 |             'hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
495 |
496 |     if not _is_pil_image(img):
497 |         raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 498 | 499 | input_mode = img.mode 500 | if input_mode in {'L', '1', 'I', 'F'}: 501 | return img 502 | 503 | h, s, v = img.convert('HSV').split() 504 | 505 | np_h = np.array(h, dtype=np.uint8) 506 | # uint8 addition take cares of rotation across boundaries 507 | with np.errstate(over='ignore'): 508 | np_h += np.uint8(hue_factor * 255) 509 | h = Image.fromarray(np_h, 'L') 510 | 511 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 512 | return img 513 | 514 | 515 | def adjust_gamma(img, gamma, gain=1): 516 | """Perform gamma correction on an image. 517 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 518 | based on the following equation: 519 | I_out = 255 * gain * ((I_in / 255) ** gamma) 520 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 521 | Args: 522 | img (PIL Image): PIL Image to be adjusted. 523 | gamma (float): Non negative real number. gamma larger than 1 make the 524 | shadows darker, while gamma smaller than 1 make dark regions 525 | lighter. 526 | gain (float): The constant multiplier. 527 | """ 528 | if not _is_pil_image(img): 529 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 530 | 531 | if gamma < 0: 532 | raise ValueError('Gamma should be a non-negative real number') 533 | 534 | input_mode = img.mode 535 | img = img.convert('RGB') 536 | 537 | np_img = np.array(img, dtype=np.float32) 538 | np_img = 255 * gain * ((np_img / 255) ** gamma) 539 | np_img = np.uint8(np.clip(np_img, 0, 255)) 540 | 541 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 542 | return img 543 | 544 | 545 | def rotate(img, angle, resample=False, expand=False, center=None): 546 | """Rotate the image by angle and then (optionally) translate it by (n_columns, n_rows) 547 | Args: 548 | img (PIL Image): PIL Image to be rotated. 549 | angle ({float, int}): In degrees degrees counter clockwise order. 550 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 551 | An optional resampling filter. 552 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 553 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 554 | expand (bool, optional): Optional expansion flag. 555 | If true, expands the output image to make it large enough to hold the entire rotated image. 556 | If false or omitted, make the output image the same size as the input image. 557 | Note that the expand flag assumes rotation around the center and no translation. 558 | center (2-tuple, optional): Optional center of rotation. 559 | Origin is the upper left corner. 560 | Default is the center of the image. 561 | """ 562 | 563 | if not _is_pil_image(img): 564 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 565 | 566 | return img.rotate(angle, resample, expand, center) 567 | 568 | 569 | def to_grayscale(img, num_output_channels=1): 570 | """Convert image to grayscale version of image. 571 | Args: 572 | img (PIL Image): Image to be converted to grayscale. 573 | Returns: 574 | PIL Image: Grayscale version of the image. 575 | if num_output_channels == 1 : returned image is single channel 576 | if num_output_channels == 3 : returned image is 3 channel with r == g == b 577 | """ 578 | if not _is_pil_image(img): 579 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 580 | 581 | if num_output_channels == 1: 582 | img = img.convert('L') 583 | elif num_output_channels == 3: 584 | img = img.convert('L') 585 | np_img = np.array(img, dtype=np.uint8) 586 | np_img = np.dstack([np_img, np_img, np_img]) 587 | img = Image.fromarray(np_img, 'RGB') 588 | else: 589 | raise ValueError('num_output_channels should be either 1 or 3') 590 | 591 | return img --------------------------------------------------------------------------------
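The helpers in utiles/functional.py mirror torchvision's functional transforms. A minimal usage sketch of the tensor/PIL conversion path follows; it assumes the script is run from the Pytorch-FCN root so that utiles.functional is importable, 'path/to/image.png' is only a placeholder, and the mean/std values are the usual ImageNet statistics given purely as an example.

# Rough usage sketch (not part of the repository): PIL image -> normalized CHW tensor -> back.
from PIL import Image
from utiles import functional as F

img = Image.open('path/to/image.png').convert('RGB')   # placeholder path
tensor = F.to_tensor(img)                    # HWC uint8 in [0, 255] -> CHW float in [0, 1]
tensor = F.normalize(tensor,
                     mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])         # per channel; modifies the tensor in place
restored = F.to_pil_image(F.to_tensor(img))  # CHW float tensor -> PIL image in 'RGB' mode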
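A similar sketch for the geometric helpers; the 224-pixel sizes are illustrative and assume the input is at least that large. Two quirks of this copy are worth noting: resize interprets a tuple size as (h, w), and center_crop returns the crop box alongside the image.

# Geometric helpers, continuing from the sketch above (img is a PIL RGB image).
resized = F.resize(img, 360)              # int size: shorter edge scaled to 360, aspect ratio kept
resized = F.resize(img, (352, 480))       # tuple is (h, w); reversed internally for PIL
padded = F.pad(img, (10, 20), fill=0)     # 10 px left/right, 20 px top/bottom
patch = F.crop(img, 50, 30, 224, 224)     # (i, j, h, w): top-left corner plus height/width
center, box = F.center_crop(img, 224)     # this version also returns the (i, j, th, tw) box
crops = F.ten_crop(img, (224, 224))       # 5 corner/center crops plus their horizontal flips
mirrored = F.hflip(img)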
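Finally, a sketch of the photometric helpers. adjust_gamma applies I_out = 255 * gain * (I_in / 255) ** gamma, so with gamma = 2.0 a mid-grey value of 128 maps to roughly 64, darkening the shadows.

# Photometric helpers, continuing from the sketch above.
brighter = F.adjust_brightness(img, 1.5)   # 0 -> black image, 1 -> unchanged, 2 -> twice as bright
flatter = F.adjust_contrast(img, 0.8)
vivid = F.adjust_saturation(img, 2.0)
shifted = F.adjust_hue(img, 0.1)           # hue_factor must lie in [-0.5, 0.5]
darker = F.adjust_gamma(img, gamma=2.0)    # 255 * (128 / 255) ** 2 is about 64
gray = F.to_grayscale(img, num_output_channels=3)    # r == g == b, still 3 channels
rotated = F.rotate(img, 15, resample=Image.BILINEAR, expand=True)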