├── README.md
├── layers.py
├── losses.py
├── main.py
├── models.py
├── params
│   └── fcn_params.yaml
├── trainer.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# GraphBasedGlobalReasoning
Unofficial PyTorch implementation of Graph-Based Global Reasoning Networks (http://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_Graph-Based_Global_Reasoning_Networks_CVPR_2019_paper.pdf)

I changed some of the details from the authors' implementation.

--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
from torch import nn
from torch.nn.functional import interpolate

class GloRe(nn.Module):
    def __init__(self, in_channels):
        super(GloRe, self).__init__()
        self.N = in_channels // 4
        self.S = in_channels // 2

        self.theta = nn.Conv2d(in_channels, self.N, 1, 1, 0, bias=False)
        self.phi = nn.Conv2d(in_channels, self.S, 1, 1, 0, bias=False)

        self.relu = nn.ReLU()

        self.node_conv = nn.Conv1d(self.N, self.N, 1, 1, 0, bias=False)
        self.channel_conv = nn.Conv1d(self.S, self.S, 1, 1, 0, bias=False)

        # conv layer that restores the channel count of the input to this unit
        self.conv_2 = nn.Conv2d(self.S, in_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        batch, C, H, W = x.size()
        L = H * W

        B = self.theta(x).view(-1, self.N, L)

        phi = self.phi(x).view(-1, self.S, L)
        phi = torch.transpose(phi, 1, 2)

        V = torch.bmm(B, phi) / L  # unexplained division by L, as in the authors' code
        V = self.relu(self.node_conv(V))
        V = self.relu(self.channel_conv(torch.transpose(V, 1, 2)))

        y = torch.bmm(torch.transpose(B, 1, 2), torch.transpose(V, 1, 2))
        # (batch, L, S) -> (batch, S, L) before reshaping back to the spatial map
        y = torch.transpose(y, 1, 2).contiguous().view(-1, self.S, H, W)
        y = self.conv_2(y)

        return x + y


class ResBlock(nn.Module):
    """ResNet Bottleneck
    """
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1, dilation=1,
                 downsample=None, previous_dilation=1, norm_layer=None):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride,
            padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        if downsample is None:
            # when the identity shortcut does not match (stride != 1 or channel mismatch),
            # downsample with a conv layer
            if stride != 1 or in_channels != out_channels:
                downsample = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_channels))

        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def _sum_each(self, x, y):
        assert(len(x) == len(y))
        z = []
        for i in range(len(x)):
            z.append(x[i]+y[i])
        return z

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet50(nn.Module):
    def __init__(self, base_channels=64, multi_grid=False):

        super(ResNet50, self).__init__()
        block = ResBlock
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, base_channels//2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels//2),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels//2, base_channels//2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(base_channels//2),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels//2, base_channels, kernel_size=3, stride=1, padding=1, bias=False),
        )

        self.bn1 = nn.BatchNorm2d(base_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.block1 = make_resblock(block, base_channels*1, base_channels*2, num_blocks=3)
        self.block2 = make_resblock(block, base_channels*2, base_channels*4, num_blocks=4, stride=2)


        self.block3 = make_resblock(block, base_channels*4, base_channels*8,
                                    num_blocks=6, stride=1, dilation=2)

        if multi_grid:
            self.block4 = make_resblock(block, base_channels*8, base_channels*16, num_blocks=3, stride=1,
                                        dilation=4, multi_grid=[4, 8, 16])

        else:
            self.block4 = make_resblock(block, base_channels*8, base_channels*16, num_blocks=3, stride=1,
                                        dilation=4)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x

def make_resblock(block, in_channels, out_channels, num_blocks, stride=1, dilation=1,
                  downsample=None, multi_grid=None):
    layers = []
    if multi_grid is not None:
        multi_dilations = multi_grid
    else:
        multi_dilations = [dilation] * num_blocks
    assert len(multi_dilations) == num_blocks, "multi_dilations must have as many elements as num_blocks"

    if multi_grid:
        layers.append(block(in_channels, out_channels, stride, dilation=multi_dilations[0],
                            downsample=downsample, previous_dilation=dilation))
    elif dilation == 1 or dilation == 2:
        layers.append(block(in_channels, out_channels, stride, dilation=1,
                            downsample=downsample, previous_dilation=dilation))
    elif dilation == 4:
        layers.append(block(in_channels, out_channels, stride, dilation=2,
                            downsample=downsample, previous_dilation=dilation))
    else:
        raise RuntimeError("=> unknown dilation size: {}".format(dilation))

    for i in range(1, num_blocks):
        layers.append(block(out_channels, out_channels, dilation=multi_dilations[i],
                            previous_dilation=dilation))


    return nn.Sequential(*layers)

class FCNHead(nn.Module):
    def __init__(self, in_channels, image_size, num_class, use_glore=True):
        super(FCNHead, self).__init__()
        self.image_size = image_size

        inter_channels = in_channels // 4
        self.conv51 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())
        self.use_glore = use_glore
        if self.use_glore:
            self.gcn = GloRe(inter_channels)

        self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(inter_channels),
                                    nn.ReLU())

        self.conv53 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                    nn.Dropout2d(0.2),
                                    nn.ReLU())

        self.conv6 = nn.Sequential(nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_class, 3, padding=1, bias=False))

    def forward(self, x, image_size):
        x = self.conv51(x)
        if self.use_glore:
            x = self.gcn(x)
        x = self.conv52(x)
        x = interpolate(x, image_size)
        x = self.conv53(x)
        #x = x[:, :, 1:-1, 1:-1]  # remove the border added by conv53's padding

        output = self.conv6(x)

        return output

--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------

import torch
from torch.nn import Module
import torch.nn.functional as F



class FocalLoss(Module):
    def __init__(self, weight=None, eps=1e-8, gamma=2):
        super(FocalLoss, self).__init__()
        self.weight = weight
        self.gamma = gamma

    def forward(self, input, target):
        prob = F.softmax(input, dim=1)
        prob = torch.gather(prob, 1, torch.unsqueeze(target, dim=1))
        prob = torch.squeeze(prob, dim=1)
        loss = (1 - prob) ** self.gamma * F.cross_entropy(input, target, weight=self.weight, reduction="none")
        return torch.mean(loss)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import sys
import yaml
import datetime
import os

from models import FCNwithGloRe
from trainer import Trainer
from utils import get_dataloader

def train(params):
    model = FCNwithGloRe(params)
    trainer = Trainer(params)
    image_size = params["common"]["image_size"][0]
    train_data_path = params["common"]["train_data_path"]
    val_data_path = params["common"]["val_data_path"]
    train_batch_size = params["common"]["train_batch_size"]
    val_batch_size = params["common"]["val_batch_size"]
    num_class = params["common"]["num_class"]
    train_dataloader = get_dataloader(train_data_path, train_batch_size, num_class, image_size, is_train=True)
    val_dataloader = get_dataloader(val_data_path, val_batch_size, num_class, is_train=False)

    dt_now = datetime.datetime.now()
    result_dir = f"./result/{dt_now.year}{dt_now.month:0>2}{dt_now.day:0>2}-{dt_now.hour:0>2}{dt_now.minute:0>2}/"

    os.makedirs(result_dir, exist_ok=True)

    with open(f"{result_dir}/params.yaml", "w") as f:
        f.write(yaml.dump(params, default_flow_style=False))

    trainer.train(model, result_dir, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

def predict():
    pass


if __name__ == "__main__":

    args = sys.argv
    mode = args[1]
    param_file = args[2]

    assert mode in ["train", "predict", "calculate_iou"]

    if mode == "train":
        with open(param_file, "r") as f:
            params = yaml.safe_load(f)
        train(params)

    elif mode == "predict":
        model_path = args[3]
        # prediction mode is not implemented yet


--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

from distutils.util import strtobool

import torch
from torch import nn

from layers import ResNet50, FCNHead

class FCNwithGloRe(nn.Module):
    def __init__(self, params):
        super(FCNwithGloRe, self).__init__()
        common_params = params["common"]
        network_params = params["network"]

        num_class = common_params["num_class"]
        image_size = common_params["image_size"]

        use_glore = network_params["use_glore"]
        base_channels = network_params["base_channels"]
        multi_grid = network_params["multi_grid"]

        self.resnet = ResNet50(base_channels, multi_grid)
        base_channels *= 16
        self.head = FCNHead(base_channels, image_size, num_class, use_glore)

    def forward(self, x):
        image_size = x.size()[2:]
        x = self.resnet(x)
        out = self.head(x, image_size)
        return out

class UNetwithGloRe(nn.Module):
    def __init__(self):
        super(UNetwithGloRe, self).__init__()

--------------------------------------------------------------------------------
/params/fcn_params.yaml:
--------------------------------------------------------------------------------
common:
  train_data_path: ./data/train/
  val_data_path: ./data/val/
  train_batch_size: 8
  val_batch_size: 1
  image_size: [256, 256]
  num_class: 34

network:
  base_channels: 128
  multi_grid: True
  use_glore: True

training:
  lr: 0.001
  epoch: 80
  beta_1: 0.9
  checkpoint_root: ./result/
  checkpoint: 10
  use_cuda: True

--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
import time
import datetime
import shutil
import json

import numpy as np
import pandas as pd
import torch
from torch import optim, nn
from torch.optim.lr_scheduler import CosineAnnealingLR

from losses import FocalLoss


class Trainer(object):
    def __init__(self, params):

        training_params = params["training"]

        self.lr = training_params["lr"]
        self.max_epoch = training_params["epoch"]
        self.beta_1 = training_params["beta_1"]
        self.checkpoint_root = training_params["checkpoint_root"]
        self.checkpoint_epoch = training_params["checkpoint"]
        if training_params["use_cuda"]:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def train(self, model, result_dir, train_dataloader, val_dataloader=None):
        print("create data loader")
        model.to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=self.lr, betas=(self.beta_1, 0.999), weight_decay=5e-4)
        scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
        #class_weights = train_dataloader.dataset.class_weights.to(self.device)
        #criterion = nn.CrossEntropyLoss(weight=class_weights)
        criterion = FocalLoss()
        start_time = time.time()
        train_loss_list = []
        train_acc_list = []
        val_iou_list = []
        val_iou_dict_dict = {}

        print("training starts")
        for epoch in range(1, self.max_epoch+1):
            train_loss = 0
            train_acc = 0
            val_loss = 0
            val_acc = 0
            for i, (images, labels) in enumerate(train_dataloader):
                images = images.to(self.device)
                labels = labels.to(self.device)
                pred = model(images)
                loss = criterion(pred, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                ps = torch.exp(pred)  # pred holds raw logits; exp() is monotonic, so the argmax below is unchanged
                equality = (labels.data == ps.max(dim=1)[1])
                train_acc += equality.type(torch.FloatTensor).mean().item()
            train_loss /= len(train_dataloader)
            train_acc /= len(train_dataloader)

            model.eval()
            confusion_dict = calculate_confusion(model, val_dataloader, self.device)
            iou_dict = calculate_IoU(confusion_dict)
            val_acc = iou_dict["mean"]
            val_iou_dict_dict.update({epoch: iou_dict})
            model.train()

            train_loss_list.append(train_loss)
            train_acc_list.append(train_acc)
            val_iou_list.append(val_acc)

            with open(f"{result_dir}/class_iou.json", "w") as f:
                json.dump(val_iou_dict_dict, f)

            if epoch % self.checkpoint_epoch == 0:
                path = f"{result_dir}/{epoch:0>4}.pth"
                torch.save(model.state_dict(), path)

            scheduler.step()

            elapsed = time.time() - start_time
            print(f"epoch{epoch} done. "
                  f"train_loss:{train_loss:.4f}, train_acc: {train_acc:.4f}, "
                  f"val_acc: {val_acc:.4f}"
                  f" ({elapsed:.4f} sec)")
        result_df = pd.DataFrame({"epoch": list(range(1, self.max_epoch+1)),
                                  "train_loss": train_loss_list,
                                  "train_acc": train_acc_list,
                                  "val_acc": val_iou_list})
        result_df.to_csv(f"{result_dir}/result.csv", index=False)
        path = f"{result_dir}/{self.max_epoch:0>4}.pth"
        torch.save(model.state_dict(), path)
        path_optim = f"{result_dir}/{self.max_epoch:0>4}_optim.pth"
        torch.save(optimizer.state_dict(), path_optim)
        with open(f"{result_dir}/class_iou.json", "w") as f:
            json.dump(val_iou_dict_dict, f)
        print(f"training ends ({elapsed} sec)")
        return model

def calculate_confusion(model, dataloader, device, num_class=34):
    result_dict = {}

    for i in range(num_class):
        result_dict.update({i: {"TP": 0,
                                "FPFN": 0}})
    model.eval()
    with torch.no_grad():
        for i, (images, labels) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.cpu().numpy()

            pred = model(images).cpu().detach().numpy()
            pred = np.argmax(pred, 1)
            for i in range(num_class):

                TP = np.logical_and(labels==i, pred==i).sum()
                FPFN = np.logical_xor(labels==i, pred==i).sum()
                result_dict[i]["TP"] += TP
                result_dict[i]["FPFN"] += FPFN

            del pred

    return result_dict

def calculate_IoU(confusion_dict):
    iou_dict = {}
    mean_IoU = 0
    for i in confusion_dict.keys():
        TP = confusion_dict[i]["TP"]
        FPFN = confusion_dict[i]["FPFN"]
        if TP+FPFN != 0:
            IoU = TP/(TP+FPFN)
        else:
            IoU = 0
        iou_dict.update({i: IoU})
        mean_IoU += IoU

    mean_IoU /= len(iou_dict.keys())
    iou_dict.update({"mean": mean_IoU})
    return iou_dict
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
import re

import numbers
import random
import glob

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
from torchvision.transforms import functional as F
from skimage.io import imread
from PIL import Image


class SegmentationDataSet(Dataset):
    gt_dir_name = "GTs"
    image_dir_name = "images"
    def __init__(self, path_root, num_class, transforms=None, use_weights=True):
        super(SegmentationDataSet, self).__init__()

        self.transforms = transforms
        self.images, self.labels = self._make_image_list(path_root)
        #if use_weights:
            #self.class_weights = self._get_class_weights(num_class)


    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        label = imread(self.labels[idx], as_gray=True)

        if self.transforms is not None:
            image, label = self.transforms(image, label)

        return image, label

    def __len__(self):
        return len(self.images)

    def _make_image_list(self, path_root):
        gt_dir = path_root + self.gt_dir_name + "/"
        img_dir = path_root + self.image_dir_name + "/"

        img_path_list = glob.glob(f"{img_dir}/*/*.png")
        gt_path_list = [self.__convert_GT_name(path, gt_dir) for path in img_path_list]

        return img_path_list, gt_path_list

    def _get_class_weights(self, num_class):
        count_dict = {}
        class_weights = torch.zeros(num_class)
        total_counts = 0
        for i in range(num_class):
            count_dict.update({i: 0})
        for label_path in self.labels:
            label = imread(label_path)
            for i in range(num_class):
                count = (label == i).sum()
                count_dict[i] += count / 1000
                total_counts += count / 1000
        for i in range(num_class):
            class_weights[i] += total_counts / (count_dict[i] * num_class)
        class_weights = 1 + torch.log(class_weights)
        print(f"class weights: {class_weights}")
        return class_weights

    def __convert_GT_name(self, path, gt_dir):
        # NOTE: splits on the Windows path separator returned by glob; adjust for POSIX paths if needed
        path_list = path.split("\\")
        basename = path_list[-1]
        parent_dir = path_list[-2]
        gt_filename = "_".join(basename.split("_")[:3]) + "_gtFine_labelIds.png"
        return f"{gt_dir}{parent_dir}/{gt_filename}"


def get_dataloader(path_root, batch_size, num_class, image_size=None, is_train=True):
    if is_train:
        train_transforms = Compose(
            [RandomCrop(image_size), RandomHorizontalFlip(), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        return DataLoader(SegmentationDataSet(path_root,
                                              num_class,
                                              transforms=train_transforms,
                                              use_weights=True),
                          batch_size=batch_size)
    else:
        valid_transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        return DataLoader(SegmentationDataSet(path_root,
                                              num_class,
                                              transforms=valid_transforms,
                                              use_weights=False),
                          batch_size=batch_size)

class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively. If a sequence of length 2 is provided, it is used to
            pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, label):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        # crop the label with the same window as the image (width slice uses w, not h)
        return F.crop(img, i, j, h, w), label[i:i+h, j:j+w]


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, label):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            # subtracting zeros forces a copy so the flipped label has positive strides
            return F.hflip(img), np.fliplr(label) - np.zeros_like(label)
        return img, label

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.
    """

    def __call__(self, pic, label):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic), torch.LongTensor(label)

    def __repr__(self):
        return self.__class__.__name__ + '()'

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor, label):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std), label


    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:

    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, label):
        for t in self.transforms:
            img, label = t(img, label)
        return img, label

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

--------------------------------------------------------------------------------
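
Usage sketch (not part of the repository): a minimal smoke test of the GloRe unit and the FCNwithGloRe model defined above. The params dictionary only mirrors the keys that models.py reads from params/fcn_params.yaml; the concrete values (base_channels=64, random inputs) are assumptions chosen to keep the test light.

# smoke_test.py (hypothetical helper, not in the repository)
import torch

from layers import GloRe
from models import FCNwithGloRe

# The GloRe unit is residual, so its output keeps the input shape.
glore = GloRe(in_channels=64)
feat = torch.randn(2, 64, 32, 32)
print(glore(feat).shape)  # torch.Size([2, 64, 32, 32])

# Keys mirror params/fcn_params.yaml; values here are illustrative assumptions.
params = {
    "common": {"num_class": 34, "image_size": [256, 256]},
    "network": {"use_glore": True, "base_channels": 64, "multi_grid": True},
}
model = FCNwithGloRe(params)
model.eval()
images = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(images)
print(logits.shape)  # torch.Size([1, 34, 256, 256]) -- per-pixel class scores at input resolution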