├── .idea ├── .gitignore ├── UNet_Seal_Remove.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── config.py ├── data ├── mytest │ └── output ├── test │ └── 11.png ├── train │ ├── 1.png │ └── 1.xml ├── train_cleaned │ └── 1.png ├── valid │ ├── 2.png │ └── 2.xml └── valid_cleaned │ └── 2.png ├── datasets.py ├── predict.py ├── train.py ├── transforms.py ├── unet ├── __init__.py ├── unet_model.py └── unet_parts.py └── utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml 3 | -------------------------------------------------------------------------------- /.idea/UNet_Seal_Remove.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UNet Seal Remove 2 | UNet 实现文档印章消除 3 | 4 | # Requirement 5 | pytorch==1.5 6 | 7 | opencv-python 4.2 8 | 9 | numpy 10 | 11 | # Use 12 | 1. data 的目录: 13 | 14 | - test:测试集图片的路径 15 | - mytest:测试结果的输出路径 16 | - train:训练集图片的路径,包含含印章的图片以及标注印章位置的 xml 文件 17 | - train_cleaned:训练集图片人工去除印章后的标签 18 | - valid:验证集图片的路径,包含含印章的图片以及标注印章位置的xml文件 19 | - valid_cleaned:验证集图片人工去除印章后的标签 20 | 21 | 2. config.py 22 | 23 | 设置参数,包括文件路径、模型结构参数和训练的参数等。 24 | 25 | 3. train.py 26 | 27 | 运行 python train.py 训练模型。 28 | 29 | 4. predict.py 30 | 31 | 运行 python predict.py 测试。 32 | 33 | # Note: 34 | 1. 由于作者所使用的图像分辨率极高,在训练和测试时从完整图像中扣出包含印章的区域(ImageSize=512*512),然后进行训练。 35 | 如果图片的分辨率适中或者显存足够大,可以跳过此步骤,无需进行印章标注,直接使用原图进行 UNet 训练。 36 | 37 | 2. 从原图中扣出印章区域也可以使用 yolo 代替。 38 | 39 | 40 | # Reference 41 | [unet-denoising-dirty-documents](https://github.com/1024210879/unet-denoising-dirty-documents). 42 | 43 | [Pytorch-Unet](https://github.com/milesial/Pytorch-UNet) 44 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | # =============================== 路径 ========================================= 5 | # 模型保存文件 6 | if not os.path.exists('./weight'): 7 | os.mkdir('./weight') 8 | 9 | weight = './weight/weight.pth' 10 | weight_with_optimizer = './weight/weight_with_optimizer.pth' 11 | best_model = './weight/best_model.pth' 12 | best_model_with_optimizer = './weight/best_model_with_optimizer.pth' 13 | 14 | # =============================== 训练 ========================================= 15 | # 选择训练硬件设备 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | # 裁剪印章区域后图像的大小 18 | image_size = 512 19 | # 训练参数 20 | n_channels = 1 21 | n_classes = 2 22 | LR = 1e-5 23 | EPOCH = 1000 24 | BATCH_SIZE = 4 25 | 26 | # =============================== 测试 ========================================= 27 | test_image = './data/train/10.png' 28 | test_boxes = './data/train/10.xml' 29 | output_path = './data/mytest' 30 | -------------------------------------------------------------------------------- /data/mytest/output: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/test/11.png: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/train/1.png: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/train/1.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/train_cleaned/1.png: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/valid/2.png: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/valid/2.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/valid_cleaned/2.png: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import cv2 3 | import os 4 | import transforms as Transforms 5 | from utils import * 6 | 7 | 8 | class UNetDataset(Dataset): 9 | def __init__(self, dir_train, dir_mask, transform=None): 10 | self.dirTrain = dir_train 11 | self.dirMask = dir_mask 12 | self.transform = transform 13 | self.dataTrain = [os.path.join(self.dirTrain, filename) 14 | for filename in os.listdir(self.dirTrain) 15 | if filename.endswith('.jpg') or filename.endswith('.png')] 16 | self.dataBox = [os.path.join(self.dirTrain, filename) 17 | for filename in os.listdir(self.dirTrain) 18 | if filename.endswith('.xml')] 19 | self.dataMask = [os.path.join(self.dirMask, filename) 20 | for filename in os.listdir(self.dirMask) 21 | if filename.endswith('.jpg') or filename.endswith('.png')] 22 | self.trainDataSize = len(self.dataTrain) 23 | self.maskDataSize = len(self.dataMask) 24 | self.dataBoxSize = len(self.dataBox) 25 | 26 | def __getitem__(self, index): 27 | assert self.trainDataSize == self.maskDataSize 28 | assert self.trainDataSize == self.dataBoxSize 29 | 30 | image = cv2.imread(self.dataTrain[index]) 31 | label = cv2.imread(self.dataMask[index]) 32 | boxfile = self.dataBox[index] 33 | 34 | image, label = cropImage(image, boxfile, label) 35 | 36 | if self.transform: 37 | for method in self.transform: 38 | image, label = method(image, label) 39 | 40 | image = image.transpose((2, 0, 1)) 41 | label = label.transpose((2, 0, 1)) 42 | 43 | return image, label 44 | 45 | def __len__(self): 46 | assert self.trainDataSize == self.maskDataSize 47 | assert self.trainDataSize == self.dataBoxSize 48 | 49 | return self.trainDataSize 50 | 51 | 52 | if __name__ == '__main__': 53 | from torch.utils.data import DataLoader 54 | 55 | transforms = [ 56 | # Transforms.RandomCrop(2300, 2300), 57 | Transforms.RondomFlip(), 58 | Transforms.RandomRotate(15), 59 | Transforms.Log(0.5), 60 | Transforms.Blur(0.2), 61 | Transforms.ToTensor(), 62 | Transforms.ToGray() 63 | ] 64 | dataset = UNetDataset('./data/train', './data/train_cleaned', transform=transforms) 65 | dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) 66 | 67 | for index, (batch_x, batch_y) in enumerate(dataLoader): 68 | print(batch_x.size(), batch_y.size()) # shape:(batch_size, 1, h, w) 1表示图像是灰度图 69 | 70 | dis = batch_y[0][0].numpy() # shape:(h,w) 71 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import os 5 | from unet import UNet 6 | from utils import * 7 | import config 8 | 9 | # load net 10 | print('load net') 11 | net = UNet(n_channels=config.n_channels, n_classes=config.n_classes).to(config.device) 12 | if os.path.exists(config.best_model): 13 | checkpoint = torch.load(config.best_model, map_location='cpu') 14 | net.load_state_dict(checkpoint['net']) 15 | else: 16 | exit(0) 17 | 18 | net.eval() 19 | # load img 20 | print('load img') 21 | image = cv2.imread(config.test_image) 22 | img, _ = cropImage(image, config.test_boxes, image) 23 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 24 | img = img.reshape(img.shape[0], img.shape[1]) 25 | print(img.shape) 26 | 27 | input = torch.from_numpy(img[np.newaxis][np.newaxis]).float() / 255 28 | output = net(input.to(config.device)) 29 | print(output.shape) 30 | output = output[0, 0].detach().data.cpu().numpy() 31 | res = np.concatenate((img / 255, output), axis=1) 32 | 33 | cv2.imwrite(os.path.join(config.output_path, "0.png"), res * 255) 34 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | 训练模型指令: 4 | python train.py 5 | 6 | @author: libo 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | from torch import optim 11 | import os 12 | from unet import UNet 13 | from datasets import UNetDataset 14 | import transforms as Transforms 15 | from torch.utils.data import DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | import config 18 | 19 | 20 | def train(): 21 | transforms = [ 22 | Transforms.RondomFlip(), 23 | Transforms.RandomRotate(15), 24 | Transforms.Log(0.5), 25 | Transforms.Blur(0.2), 26 | Transforms.ToGray(), 27 | Transforms.ToTensor() 28 | ] 29 | train_dataset = UNetDataset('./data/train/', './data/train_cleaned/', transform=transforms) 30 | train_dataLoader = DataLoader(dataset=train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=0) 31 | 32 | valid_dataset = UNetDataset('./data/valid/', './data/valid_cleaned/', transform=transforms) 33 | valid_dataLoader = DataLoader(dataset=valid_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=0) 34 | 35 | net = UNet(n_channels=config.n_channels, n_classes=config.n_classes).to(config.device) 36 | writer = SummaryWriter() 37 | optimizer = optim.Adam(net.parameters(), lr=config.LR) 38 | if config.n_classes > 1: 39 | loss_func = nn.CrossEntropyLoss().to(config.device) 40 | else: 41 | loss_func = nn.BCEWithLogitsLoss().to(config.device) 42 | best_loss = float('inf') 43 | 44 | if os.path.exists(config.weight_with_optimizer): 45 | checkpoint = torch.load(config.weight_with_optimizer, map_location='cpu') 46 | net.load_state_dict(checkpoint['net']) 47 | optimizer.load_state_dict(checkpoint['optimizer']) 48 | print('load weight') 49 | 50 | for epoch in range(config.EPOCH): 51 | train_loss = 0 52 | net.train() 53 | for step, (batch_x, batch_y) in enumerate(train_dataLoader): 54 | batch_x = batch_x.to(device=config.device) 55 | batch_y = batch_y.squeeze(1).to(device=config.device) 56 | output = net(batch_x) 57 | loss = loss_func(output, batch_y) 58 | train_loss += loss.item() 59 | if loss < best_loss: 60 | best_loss = loss 61 | torch.save({'net': net.state_dict(), 'optimizer': optimizer.state_dict()}, 62 | config.best_model_with_optimizer) 63 | torch.save({'net': net.state_dict()}, config.best_model) 64 | optimizer.zero_grad() 65 | loss.backward() 66 | optimizer.step() 67 | 68 | net.eval() 69 | eval_loss = 0 70 | for step, (batch_x, batch_y) in enumerate(valid_dataLoader): 71 | batch_x = batch_x.to(device=config.device) 72 | batch_y = batch_y.squeeze(1).to(device=config.device) 73 | output = net(batch_x) 74 | valid_loss = loss_func(output, batch_y) 75 | eval_loss += valid_loss.item() 76 | 77 | writer.add_scalar("train_loss", train_loss, epoch) 78 | writer.add_scalar("eval_loss", eval_loss, epoch) 79 | print("*" * 80) 80 | print('epoch: %d | train loss: %.4f | valid loss: %.4f' % (epoch, train_loss, eval_loss)) 81 | print("*" * 80) 82 | 83 | if (epoch + 1) % 10 == 0: 84 | torch.save({ 85 | 'net': net.state_dict(), 86 | 'optimizer': optimizer.state_dict() 87 | }, config.weight_with_optimizer) 88 | torch.save({ 89 | 'net': net.state_dict() 90 | }, config.weight) 91 | print('saved') 92 | 93 | writer.close() 94 | 95 | 96 | if __name__ == '__main__': 97 | train() 98 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | 5 | 6 | class ToGray(object): 7 | def __init__(self): 8 | pass 9 | 10 | def __call__(self, image, label): 11 | if len(image.shape) == 3: 12 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 13 | image = image.reshape(image.shape[0], image.shape[1], 1) 14 | if len(label.shape) == 3: 15 | label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY) 16 | label = label.reshape(label.shape[0], label.shape[1], 1) 17 | return image, label 18 | 19 | 20 | class RondomFlip(object): 21 | def __init__(self): 22 | pass 23 | 24 | def __call__(self, image, label): 25 | degree = random.random() 26 | if degree <= 0.33: 27 | image = cv2.flip(image, 0) 28 | label = cv2.flip(label, 0) 29 | elif degree <= 0.66: 30 | image = cv2.flip(image, 1) 31 | label = cv2.flip(label, 1) 32 | return image, label 33 | 34 | 35 | class RandomRotate(object): 36 | def __init__(self, angle): 37 | self.angle = angle 38 | 39 | def __call__(self, image, label): 40 | angle = random.random() * self.angle 41 | angle = angle if random.random() < 0.5 else -angle 42 | h, w = image.shape[0], image.shape[1] 43 | scale = random.random() * 0.4 + 0.9 44 | matRotate = cv2.getRotationMatrix2D((w * 0.5, h * 0.5), angle, scale) 45 | image = cv2.warpAffine(image, matRotate, (w, h)) 46 | label = cv2.warpAffine(label, matRotate, (w, h), 47 | borderMode=cv2.BORDER_CONSTANT, borderValue=255) 48 | return image, label 49 | 50 | 51 | class RandomCrop(object): 52 | def __init__(self, crop_h, crop_w): 53 | self.crop_h, self.crop_w = crop_h, crop_w 54 | 55 | def __call__(self, image, label): 56 | h, w = image.shape[0], image.shape[1] 57 | crop_x = int(random.random() * (w - self.crop_w)) 58 | crop_y = int(random.random() * (h - self.crop_h)) 59 | image = image[crop_y: crop_y + self.crop_h, crop_x: crop_x + self.crop_w] 60 | label = label[crop_y: crop_y + self.crop_h, crop_x: crop_x + self.crop_w] 61 | return image, label 62 | 63 | 64 | class EqualizeHist(object): 65 | def __init__(self, degree): 66 | self.degree = degree 67 | 68 | def __call__(self, image, label): 69 | if random.random() < self.degree: 70 | image = cv2.equalizeHist(image) 71 | return image, label 72 | 73 | 74 | class Blur(object): 75 | def __init__(self, degree): 76 | self.degree = degree 77 | 78 | def __call__(self, image, label): 79 | if random.random() < self.degree: 80 | image = cv2.blur(image, (3, 3)) 81 | return image, label 82 | 83 | 84 | class Log(object): 85 | def __init__(self, degree): 86 | self.degree = degree 87 | 88 | def __call__(self, image, label): 89 | if random.random() < self.degree: 90 | image = np.log(1 + image.astype(np.float32) / 255) * 255 91 | return image.astype(np.uint8), label 92 | 93 | 94 | class ToTensor(object): 95 | def __init__(self): 96 | pass 97 | 98 | def __call__(self, image, label): 99 | image = image / 255 100 | label = (255 - label) // 50 101 | label[label > 1] = 1 102 | 103 | return image.astype(np.float32), label.astype(np.int64) 104 | 105 | 106 | if __name__ == '__main__': 107 | transforms = [ 108 | RandomCrop(2300, 2300), 109 | RondomFlip(), 110 | RandomRotate(15), 111 | Log(0.5), 112 | Blur(0.2), 113 | ToTensor(), 114 | ToGray() 115 | ] 116 | 117 | image = cv2.imread('./data/train/2.png') 118 | label = cv2.imread('./data/train_cleaned/2.png') 119 | 120 | for i in range(1): 121 | img1, img2 = image, label 122 | for index, transform in enumerate(transforms): 123 | img1, img2 = transform(img1, img2) 124 | img = np.concatenate((img1, img2), axis=1) 125 | cv2.imwrite("./data/mytest/{}.png".format(index), img * 255) 126 | -------------------------------------------------------------------------------- /unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNet 2 | -------------------------------------------------------------------------------- /unet/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | from .unet_parts import * 3 | 4 | 5 | class UNet(nn.Module): 6 | def __init__(self, n_channels, n_classes, bilinear=True): 7 | super(UNet, self).__init__() 8 | self.n_channels = n_channels 9 | self.n_classes = n_classes 10 | self.bilinear = bilinear 11 | 12 | self.inc = DoubleConv(n_channels, 64) 13 | self.down1 = Down(64, 128) 14 | self.down2 = Down(128, 256) 15 | self.down3 = Down(256, 512) 16 | factor = 2 if bilinear else 1 17 | self.down4 = Down(512, 1024 // factor) 18 | self.up1 = Up(1024, 512 // factor, bilinear) 19 | self.up2 = Up(512, 256 // factor, bilinear) 20 | self.up3 = Up(256, 128 // factor, bilinear) 21 | self.up4 = Up(128, 64, bilinear) 22 | self.outc = OutConv(64, n_classes) 23 | 24 | def forward(self, x): 25 | x1 = self.inc(x) 26 | x2 = self.down1(x1) 27 | x3 = self.down2(x2) 28 | x4 = self.down3(x3) 29 | x5 = self.down4(x4) 30 | x = self.up1(x5, x4) 31 | x = self.up2(x, x3) 32 | x = self.up3(x, x2) 33 | x = self.up4(x, x1) 34 | out = self.outc(x) 35 | return out 36 | -------------------------------------------------------------------------------- /unet/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """ (convolution => [BN] => ReLU) * 2 """ 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | # 定义了两种上采样的方法:双线性插值和反卷积 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | def forward(self, x1, x2): 57 | # x1接收的是上采样的数据,x2接收的是特征融合的数据。 58 | x1 = self.up(x1) 59 | # 先对小的feature map进行padding,再进行两者concat。 60 | # input is CHW 61 | diffY = x2.size()[2] - x1.size()[2] 62 | diffX = x2.size()[3] - x1.size()[3] 63 | 64 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 65 | diffY // 2, diffY - diffY // 2]) 66 | # if you have padding issues, see 67 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 68 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 69 | x = torch.cat([x2, x1], dim=1) 70 | return self.conv(x) 71 | 72 | 73 | class OutConv(nn.Module): 74 | # 根据分割数量整合输出通道。 75 | def __init__(self, in_channels, out_channels): 76 | super(OutConv, self).__init__() 77 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 78 | 79 | def forward(self, x): 80 | return self.conv(x) 81 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import xml.etree.ElementTree as ET 4 | import numpy.random as npr 5 | import cv2 6 | import config 7 | 8 | image_size = config.image_size 9 | 10 | 11 | def getBoxes(xml_file): 12 | boxes = [] 13 | tree = ET.parse(xml_file) 14 | root = tree.getroot() 15 | for member in root.findall('object'): 16 | box = [int(member[4][0].text), 17 | int(float(member[4][1].text)), 18 | int(member[4][2].text), 19 | int(member[4][3].text) 20 | ] 21 | boxes.append(box) 22 | return boxes 23 | 24 | 25 | def getImage(image, label, boxes): 26 | cropCode = random.choice([-1] + [i for i in range(len(boxes))]) 27 | height, width = image.shape[0], image.shape[1] 28 | if cropCode != -1: 29 | box = boxes[cropCode] 30 | xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3] 31 | w = xmax - xmin 32 | h = ymax - ymin 33 | if max(w, h) < image_size: 34 | nx = npr.randint(max((xmax - image_size), 0), xmin) 35 | ny = npr.randint(max((ymax - image_size), 0), ymin) 36 | if nx + image_size > width: 37 | nx = width - image_size 38 | if ny + image_size > height: 39 | ny = height - image_size 40 | cropped_im = image[ny: ny + image_size, nx: nx + image_size, :] 41 | cropped_la = label[ny: ny + image_size, nx: nx + image_size, :] 42 | else: 43 | nx = npr.randint(xmax - max(w, h) - 100, xmin) 44 | ny = npr.randint(ymax - max(w, h) - 100, ymin) 45 | if nx + max(w, h) + 100 > width: 46 | nx = width - max(w, h) - 100 47 | if ny + max(w, h) + 100 > height: 48 | ny = height - max(w, h) - 100 49 | cropped_im = image[ny: ny + max(w, h) + 100, nx: nx + max(w, h) + 100, :] 50 | cropped_la = label[ny: ny + max(w, h) + 100, nx: nx + max(w, h) + 100, :] 51 | else: 52 | crop_x = int(random.random() * (width - image_size)) 53 | crop_y = int(random.random() * (height - image_size)) 54 | cropped_im = image[crop_y: crop_y + image_size, crop_x: crop_x + image_size] 55 | cropped_la = label[crop_y: crop_y + image_size, crop_x: crop_x + image_size] 56 | 57 | return cropped_im, cropped_la 58 | 59 | 60 | def cropImage(image, boxfile, label): 61 | boxes = getBoxes(boxfile) 62 | image, label = getImage(image, label, boxes) 63 | image = cv2.resize(image, (image_size, image_size), interpolation=cv2.INTER_LINEAR) 64 | label = cv2.resize(label, (image_size, image_size), interpolation=cv2.INTER_LINEAR) 65 | return image, label 66 | 67 | 68 | if __name__ == '__main__': 69 | dirTrain = './data/train' 70 | dataBox = [os.path.join(dirTrain, filename) 71 | for filename in os.listdir(dirTrain) 72 | if filename.endswith('.xml')] 73 | dataTrain = [os.path.join(dirTrain, filename) 74 | for filename in os.listdir(dirTrain) 75 | if filename.endswith('.jpg') or filename.endswith('.png')] 76 | 77 | for i in range(len(dataTrain)): 78 | image = cv2.imread(dataTrain[i]) 79 | boxes = getBoxes(dataBox[i]) 80 | crop_image, cropped_label = getImage(image, boxes) 81 | cv2.imwrite('./test/{}.png'.format(i), crop_image) 82 | --------------------------------------------------------------------------------