├── .idea
│   ├── .gitignore
│   ├── UNet_Seal_Remove.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── config.py
├── data
│   ├── mytest
│   │   └── output
│   ├── test
│   │   └── 11.png
│   ├── train
│   │   ├── 1.png
│   │   └── 1.xml
│   ├── train_cleaned
│   │   └── 1.png
│   ├── valid
│   │   ├── 2.png
│   │   └── 2.xml
│   └── valid_cleaned
│       └── 2.png
├── datasets.py
├── predict.py
├── train.py
├── transforms.py
├── unet
│   ├── __init__.py
│   ├── unet_model.py
│   └── unet_parts.py
└── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UNet Seal Remove
2 | Seal (stamp) removal from documents with UNet.
3 |
4 | # Requirements
5 | pytorch==1.5
6 |
7 | opencv-python 4.2
8 |
9 | numpy
10 |
11 | # Use
12 | 1. Layout of the data directory:
13 |
14 | - test: test images
15 | - mytest: output directory for test results
16 | - train: training images containing seals, plus xml files annotating the seal positions
17 | - train_cleaned: the corresponding labels, i.e. the training images with the seals removed by hand
18 | - valid: validation images containing seals, plus xml files annotating the seal positions
19 | - valid_cleaned: the corresponding labels, i.e. the validation images with the seals removed by hand
20 |
21 | 2. config.py
22 |
23 | Sets the parameters: file paths, model structure, and training hyperparameters.
24 |
25 | 3. train.py
26 |
27 | Run `python train.py` to train the model.
28 |
29 | 4. predict.py
30 |
31 | Run `python predict.py` to test.
32 |
33 | # Note
34 | 1. Because the author's images have a very high resolution, a region containing the seal (image_size = 512, i.e. 512*512 crops) is cut out of the full image for training and testing (a minimal sketch of this step follows this file).
35 | If your images have a moderate resolution, or your GPU memory is large enough, you can skip this step and the seal annotations, and train UNet directly on the full images.
36 |
37 | 2. Cropping the seal region out of the original image can also be done with a detector such as YOLO instead of the xml annotations.
38 |
39 |
40 | # Reference
41 | [unet-denoising-dirty-documents](https://github.com/1024210879/unet-denoising-dirty-documents)
42 |
43 | [Pytorch-Unet](https://github.com/milesial/Pytorch-UNet)
44 |
--------------------------------------------------------------------------------
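A minimal sketch of the cropping step described in Note 1 of the README, using `cropImage` from `utils.py` (the file names below are illustrative):

```python
import cv2
from utils import cropImage

image = cv2.imread('./data/train/1.png')          # document image containing a seal
label = cv2.imread('./data/train_cleaned/1.png')  # hand-cleaned label image
# Crop a region around one annotated seal box (or a random region) and
# resize both crops to image_size x image_size (512 x 512 in config.py).
crop_img, crop_label = cropImage(image, './data/train/1.xml', label)
```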
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | # =============================== Paths ========================================
5 | # model checkpoint files
6 | if not os.path.exists('./weight'):
7 | os.mkdir('./weight')
8 |
9 | weight = './weight/weight.pth'
10 | weight_with_optimizer = './weight/weight_with_optimizer.pth'
11 | best_model = './weight/best_model.pth'
12 | best_model_with_optimizer = './weight/best_model_with_optimizer.pth'
13 |
14 | # =============================== Training =====================================
15 | # training device
16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | # size of the image after cropping the seal region
18 | image_size = 512
19 | # training hyperparameters
20 | n_channels = 1
21 | n_classes = 2
22 | LR = 1e-5
23 | EPOCH = 1000
24 | BATCH_SIZE = 4
25 |
26 | # =============================== Testing ======================================
27 | test_image = './data/train/10.png'
28 | test_boxes = './data/train/10.xml'
29 | output_path = './data/mytest'
30 |
--------------------------------------------------------------------------------
/data/mytest/output:
--------------------------------------------------------------------------------
(placeholder; test results are written to this directory)
--------------------------------------------------------------------------------
/data/test/11.png:
--------------------------------------------------------------------------------
(binary image; content not included)
--------------------------------------------------------------------------------
/data/train/1.png:
--------------------------------------------------------------------------------
(binary image; content not included)
--------------------------------------------------------------------------------
/data/train/1.xml:
--------------------------------------------------------------------------------
(seal bounding-box annotation; content not included)
--------------------------------------------------------------------------------
/data/train_cleaned/1.png:
--------------------------------------------------------------------------------
(binary image; content not included)
--------------------------------------------------------------------------------
/data/valid/2.png:
--------------------------------------------------------------------------------
(binary image; content not included)
--------------------------------------------------------------------------------
/data/valid/2.xml:
--------------------------------------------------------------------------------
(seal bounding-box annotation; content not included)
--------------------------------------------------------------------------------
/data/valid_cleaned/2.png:
--------------------------------------------------------------------------------
(binary image; content not included)
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import cv2
3 | import os
4 | import transforms as Transforms
5 | from utils import *
6 |
7 |
8 | class UNetDataset(Dataset):
9 | def __init__(self, dir_train, dir_mask, transform=None):
10 | self.dirTrain = dir_train
11 | self.dirMask = dir_mask
12 | self.transform = transform
13 |         self.dataTrain = sorted(os.path.join(self.dirTrain, filename)  # sorted() keeps images, boxes and masks paired by index
14 |                                 for filename in os.listdir(self.dirTrain)
15 |                                 if filename.endswith('.jpg') or filename.endswith('.png'))
16 |         self.dataBox = sorted(os.path.join(self.dirTrain, filename)
17 |                               for filename in os.listdir(self.dirTrain)
18 |                               if filename.endswith('.xml'))
19 |         self.dataMask = sorted(os.path.join(self.dirMask, filename)
20 |                                for filename in os.listdir(self.dirMask)
21 |                                if filename.endswith('.jpg') or filename.endswith('.png'))
22 | self.trainDataSize = len(self.dataTrain)
23 | self.maskDataSize = len(self.dataMask)
24 | self.dataBoxSize = len(self.dataBox)
25 |
26 | def __getitem__(self, index):
27 | assert self.trainDataSize == self.maskDataSize
28 | assert self.trainDataSize == self.dataBoxSize
29 |
30 | image = cv2.imread(self.dataTrain[index])
31 | label = cv2.imread(self.dataMask[index])
32 | boxfile = self.dataBox[index]
33 |
34 | image, label = cropImage(image, boxfile, label)
35 |
36 | if self.transform:
37 | for method in self.transform:
38 | image, label = method(image, label)
39 |
40 | image = image.transpose((2, 0, 1))
41 | label = label.transpose((2, 0, 1))
42 |
43 | return image, label
44 |
45 | def __len__(self):
46 | assert self.trainDataSize == self.maskDataSize
47 | assert self.trainDataSize == self.dataBoxSize
48 |
49 | return self.trainDataSize
50 |
51 |
52 | if __name__ == '__main__':
53 | from torch.utils.data import DataLoader
54 |
55 | transforms = [
56 | # Transforms.RandomCrop(2300, 2300),
57 |         Transforms.RandomFlip(),
58 | Transforms.RandomRotate(15),
59 | Transforms.Log(0.5),
60 | Transforms.Blur(0.2),
61 |         Transforms.ToGray(),   # ToGray must run before ToTensor: cv2 cannot convert the int64 label
62 |         Transforms.ToTensor()
63 | ]
64 | dataset = UNetDataset('./data/train', './data/train_cleaned', transform=transforms)
65 | dataLoader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)
66 |
67 | for index, (batch_x, batch_y) in enumerate(dataLoader):
68 |         print(batch_x.size(), batch_y.size())  # shape: (batch_size, 1, h, w); the 1 means the images are grayscale
69 |
70 | dis = batch_y[0][0].numpy() # shape:(h,w)
71 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import os
5 | from unet import UNet
6 | from utils import *
7 | import config
8 |
9 | # load net
10 | print('load net')
11 | net = UNet(n_channels=config.n_channels, n_classes=config.n_classes).to(config.device)
12 | if os.path.exists(config.best_model):
13 | checkpoint = torch.load(config.best_model, map_location='cpu')
14 | net.load_state_dict(checkpoint['net'])
15 | else:
16 |     exit('no trained model found at ' + config.best_model)
17 |
18 | net.eval()
19 | # load img
20 | print('load img')
21 | image = cv2.imread(config.test_image)
22 | img, _ = cropImage(image, config.test_boxes, image)
23 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
24 | img = img.reshape(img.shape[0], img.shape[1])
25 | print(img.shape)
26 |
27 | inp = torch.from_numpy(img[np.newaxis][np.newaxis]).float() / 255
28 | output = net(inp.to(config.device))
29 | print(output.shape)
30 | output = torch.softmax(output, dim=1)[0, 0].detach().cpu().numpy()  # probability of class 0 (background): white paper, dark ink
31 | res = np.concatenate((img / 255, output), axis=1)
32 |
33 | cv2.imwrite(os.path.join(config.output_path, "0.png"), res * 255)
34 |
--------------------------------------------------------------------------------
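predict.py above writes the class-0 probability map next to the input crop. A hypothetical alternative, not used by the author, is a hard binary rendering via argmax over the two classes; `output_logits` below is a stand-in for the raw network output:

```python
import torch

output_logits = torch.randn(1, 2, 512, 512)  # stand-in for net(inp); shape (N, n_classes, H, W)
pred = output_logits.argmax(dim=1)[0]        # (H, W), 0 = background, 1 = ink
cleaned = (1 - pred).float().numpy() * 255   # render: ink black, background white
```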
/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """
3 | Training command:
4 | python train.py
5 |
6 | @author: libo
7 | """
8 | import torch
9 | import torch.nn as nn
10 | from torch import optim
11 | import os
12 | from unet import UNet
13 | from datasets import UNetDataset
14 | import transforms as Transforms
15 | from torch.utils.data import DataLoader
16 | from torch.utils.tensorboard import SummaryWriter
17 | import config
18 |
19 |
20 | def train():
21 | transforms = [
22 |         Transforms.RandomFlip(),
23 | Transforms.RandomRotate(15),
24 | Transforms.Log(0.5),
25 | Transforms.Blur(0.2),
26 | Transforms.ToGray(),
27 | Transforms.ToTensor()
28 | ]
29 | train_dataset = UNetDataset('./data/train/', './data/train_cleaned/', transform=transforms)
30 | train_dataLoader = DataLoader(dataset=train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=0)
31 |
32 |     valid_dataset = UNetDataset('./data/valid/', './data/valid_cleaned/', transform=[Transforms.ToGray(), Transforms.ToTensor()])  # validation without random augmentation
33 | valid_dataLoader = DataLoader(dataset=valid_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=0)
34 |
35 | net = UNet(n_channels=config.n_channels, n_classes=config.n_classes).to(config.device)
36 | writer = SummaryWriter()
37 | optimizer = optim.Adam(net.parameters(), lr=config.LR)
38 | if config.n_classes > 1:
39 | loss_func = nn.CrossEntropyLoss().to(config.device)
40 | else:
41 | loss_func = nn.BCEWithLogitsLoss().to(config.device)
42 | best_loss = float('inf')
43 |
44 | if os.path.exists(config.weight_with_optimizer):
45 | checkpoint = torch.load(config.weight_with_optimizer, map_location='cpu')
46 | net.load_state_dict(checkpoint['net'])
47 | optimizer.load_state_dict(checkpoint['optimizer'])
48 | print('load weight')
49 |
50 | for epoch in range(config.EPOCH):
51 | train_loss = 0
52 | net.train()
53 | for step, (batch_x, batch_y) in enumerate(train_dataLoader):
54 | batch_x = batch_x.to(device=config.device)
55 | batch_y = batch_y.squeeze(1).to(device=config.device)
56 | output = net(batch_x)
57 | loss = loss_func(output, batch_y)
58 | train_loss += loss.item()
59 |             if loss.item() < best_loss:  # checkpoint whenever a batch sets a new best loss
60 |                 best_loss = loss.item()
61 | torch.save({'net': net.state_dict(), 'optimizer': optimizer.state_dict()},
62 | config.best_model_with_optimizer)
63 | torch.save({'net': net.state_dict()}, config.best_model)
64 | optimizer.zero_grad()
65 | loss.backward()
66 | optimizer.step()
67 |
68 | net.eval()
69 | eval_loss = 0
70 | for step, (batch_x, batch_y) in enumerate(valid_dataLoader):
71 | batch_x = batch_x.to(device=config.device)
72 | batch_y = batch_y.squeeze(1).to(device=config.device)
73 |             with torch.no_grad():  # gradients are not needed for validation
74 |                 valid_loss = loss_func(net(batch_x), batch_y)
75 | eval_loss += valid_loss.item()
76 |
77 | writer.add_scalar("train_loss", train_loss, epoch)
78 | writer.add_scalar("eval_loss", eval_loss, epoch)
79 | print("*" * 80)
80 | print('epoch: %d | train loss: %.4f | valid loss: %.4f' % (epoch, train_loss, eval_loss))
81 | print("*" * 80)
82 |
83 | if (epoch + 1) % 10 == 0:
84 | torch.save({
85 | 'net': net.state_dict(),
86 | 'optimizer': optimizer.state_dict()
87 | }, config.weight_with_optimizer)
88 | torch.save({
89 | 'net': net.state_dict()
90 | }, config.weight)
91 | print('saved')
92 |
93 | writer.close()
94 |
95 |
96 | if __name__ == '__main__':
97 | train()
98 |
--------------------------------------------------------------------------------
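A quick shape check for the loss call in train(): nn.CrossEntropyLoss for segmentation expects logits of shape (N, C, H, W) and integer class targets of shape (N, H, W), which is why batch_y is squeezed before the call. The sizes below mirror config.py:

```python
import torch
import torch.nn as nn

loss_func = nn.CrossEntropyLoss()
logits = torch.randn(4, 2, 512, 512)         # (BATCH_SIZE, n_classes, H, W)
target = torch.randint(0, 2, (4, 512, 512))  # (BATCH_SIZE, H, W), values in {0, 1}
print(loss_func(logits, target))             # scalar loss
```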
/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 |
5 |
6 | class ToGray(object):
7 | def __init__(self):
8 | pass
9 |
10 | def __call__(self, image, label):
11 | if len(image.shape) == 3:
12 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
13 | image = image.reshape(image.shape[0], image.shape[1], 1)
14 | if len(label.shape) == 3:
15 | label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
16 | label = label.reshape(label.shape[0], label.shape[1], 1)
17 | return image, label
18 |
19 |
20 | class RandomFlip(object):
21 | def __init__(self):
22 | pass
23 |
24 | def __call__(self, image, label):
25 | degree = random.random()
26 | if degree <= 0.33:
27 | image = cv2.flip(image, 0)
28 | label = cv2.flip(label, 0)
29 | elif degree <= 0.66:
30 | image = cv2.flip(image, 1)
31 | label = cv2.flip(label, 1)
32 | return image, label
33 |
34 |
35 | class RandomRotate(object):
36 | def __init__(self, angle):
37 | self.angle = angle
38 |
39 | def __call__(self, image, label):
40 | angle = random.random() * self.angle
41 | angle = angle if random.random() < 0.5 else -angle
42 | h, w = image.shape[0], image.shape[1]
43 |         scale = random.random() * 0.4 + 0.9  # random scale in [0.9, 1.3)
44 |         matRotate = cv2.getRotationMatrix2D((w * 0.5, h * 0.5), angle, scale)
45 |         # pad rotation borders with white in both image and label, matching the paper background
46 |         image = cv2.warpAffine(image, matRotate, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255))
47 |         label = cv2.warpAffine(label, matRotate, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255))
48 | return image, label
49 |
50 |
51 | class RandomCrop(object):
52 | def __init__(self, crop_h, crop_w):
53 | self.crop_h, self.crop_w = crop_h, crop_w
54 |
55 | def __call__(self, image, label):
56 | h, w = image.shape[0], image.shape[1]
57 | crop_x = int(random.random() * (w - self.crop_w))
58 | crop_y = int(random.random() * (h - self.crop_h))
59 | image = image[crop_y: crop_y + self.crop_h, crop_x: crop_x + self.crop_w]
60 | label = label[crop_y: crop_y + self.crop_h, crop_x: crop_x + self.crop_w]
61 | return image, label
62 |
63 |
64 | class EqualizeHist(object):
65 | def __init__(self, degree):
66 | self.degree = degree
67 |
68 | def __call__(self, image, label):
69 | if random.random() < self.degree:
70 |             image = cv2.equalizeHist(image)  # note: equalizeHist expects a single-channel uint8 image
71 | return image, label
72 |
73 |
74 | class Blur(object):
75 | def __init__(self, degree):
76 | self.degree = degree
77 |
78 | def __call__(self, image, label):
79 | if random.random() < self.degree:
80 | image = cv2.blur(image, (3, 3))
81 | return image, label
82 |
83 |
84 | class Log(object):
85 | def __init__(self, degree):
86 | self.degree = degree
87 |
88 | def __call__(self, image, label):
89 | if random.random() < self.degree:
90 | image = np.log(1 + image.astype(np.float32) / 255) * 255
91 | return image.astype(np.uint8), label
92 |
93 |
94 | class ToTensor(object):
95 | def __init__(self):
96 | pass
97 |
98 | def __call__(self, image, label):
99 |         image = image / 255                # scale the image to [0, 1]
100 |         label = (255 - label) // 50       # invert: white paper -> 0, darker ink -> larger values
101 |         label[label > 1] = 1              # clip to the two class ids: 0 = background, 1 = ink
102 |
103 | return image.astype(np.float32), label.astype(np.int64)
104 |
105 |
106 | if __name__ == '__main__':
107 | transforms = [
108 | RandomCrop(2300, 2300),
109 |         RandomFlip(),
110 | RandomRotate(15),
111 | Log(0.5),
112 | Blur(0.2),
113 |         ToGray(),    # ToGray must run before ToTensor (see datasets.py)
114 |         ToTensor()
115 | ]
116 |
117 | image = cv2.imread('./data/train/2.png')
118 | label = cv2.imread('./data/train_cleaned/2.png')
119 |
120 | for i in range(1):
121 | img1, img2 = image, label
122 | for index, transform in enumerate(transforms):
123 | img1, img2 = transform(img1, img2)
124 | img = np.concatenate((img1, img2), axis=1)
125 | cv2.imwrite("./data/mytest/{}.png".format(index), img * 255)
126 |
--------------------------------------------------------------------------------
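A worked example of the label encoding in ToTensor: cleaned labels are white paper with dark ink, and `(255 - label) // 50` followed by clipping maps them to the two class ids consumed by CrossEntropyLoss:

```python
import numpy as np

label = np.array([[255, 250, 30, 0]], dtype=np.int64)  # white paper ... dark ink
enc = (255 - label) // 50                              # -> [[0, 0, 4, 5]]
enc[enc > 1] = 1                                       # -> [[0, 0, 1, 1]]: 0 = background, 1 = ink
print(enc)
```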
/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_model import UNet
2 |
--------------------------------------------------------------------------------
/unet/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 | from .unet_parts import *
3 |
4 |
5 | class UNet(nn.Module):
6 | def __init__(self, n_channels, n_classes, bilinear=True):
7 | super(UNet, self).__init__()
8 | self.n_channels = n_channels
9 | self.n_classes = n_classes
10 | self.bilinear = bilinear
11 |
12 | self.inc = DoubleConv(n_channels, 64)
13 | self.down1 = Down(64, 128)
14 | self.down2 = Down(128, 256)
15 | self.down3 = Down(256, 512)
16 | factor = 2 if bilinear else 1
17 | self.down4 = Down(512, 1024 // factor)
18 | self.up1 = Up(1024, 512 // factor, bilinear)
19 | self.up2 = Up(512, 256 // factor, bilinear)
20 | self.up3 = Up(256, 128 // factor, bilinear)
21 | self.up4 = Up(128, 64, bilinear)
22 | self.outc = OutConv(64, n_classes)
23 |
24 | def forward(self, x):
25 | x1 = self.inc(x)
26 | x2 = self.down1(x1)
27 | x3 = self.down2(x2)
28 | x4 = self.down3(x3)
29 | x5 = self.down4(x4)
30 | x = self.up1(x5, x4)
31 | x = self.up2(x, x3)
32 | x = self.up3(x, x2)
33 | x = self.up4(x, x1)
34 | out = self.outc(x)
35 | return out
36 |
--------------------------------------------------------------------------------
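A quick sanity check of the assembled model with the channel and class counts from config.py (1-channel 512x512 input, 2 output classes):

```python
import torch
from unet import UNet

net = UNet(n_channels=1, n_classes=2)
x = torch.randn(1, 1, 512, 512)
print(net(x).shape)  # torch.Size([1, 2, 512, 512])
```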
/unet/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DoubleConv(nn.Module):
9 | """ (convolution => [BN] => ReLU) * 2 """
10 |
11 | def __init__(self, in_channels, out_channels, mid_channels=None):
12 | super().__init__()
13 | if not mid_channels:
14 | mid_channels = out_channels
15 | self.double_conv = nn.Sequential(
16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
17 | nn.BatchNorm2d(mid_channels),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
20 | nn.BatchNorm2d(out_channels),
21 | nn.ReLU(inplace=True)
22 | )
23 |
24 | def forward(self, x):
25 | return self.double_conv(x)
26 |
27 |
28 | class Down(nn.Module):
29 | """Downscaling with maxpool then double conv"""
30 |
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConv(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class Up(nn.Module):
43 | """Upscaling then double conv"""
44 |
45 | def __init__(self, in_channels, out_channels, bilinear=True):
46 | super().__init__()
47 |         # two upsampling options: bilinear interpolation or transposed convolution
48 | # if bilinear, use the normal convolutions to reduce the number of channels
49 | if bilinear:
50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 | else:
53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
54 | self.conv = DoubleConv(in_channels, out_channels)
55 |
56 | def forward(self, x1, x2):
57 |         # x1 is the upsampled decoder feature map; x2 is the skip connection from the encoder
58 |         x1 = self.up(x1)
59 |         # pad the smaller feature map before concatenating the two
60 | # input is CHW
61 | diffY = x2.size()[2] - x1.size()[2]
62 | diffX = x2.size()[3] - x1.size()[3]
63 |
64 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
65 | diffY // 2, diffY - diffY // 2])
66 | # if you have padding issues, see
67 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
68 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
69 | x = torch.cat([x2, x1], dim=1)
70 | return self.conv(x)
71 |
72 |
73 | class OutConv(nn.Module):
74 |     # 1x1 convolution mapping the features to the number of segmentation classes
75 | def __init__(self, in_channels, out_channels):
76 | super(OutConv, self).__init__()
77 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
78 |
79 | def forward(self, x):
80 | return self.conv(x)
81 |
--------------------------------------------------------------------------------
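A small demonstration of the padding step in Up.forward: when the encoder produces odd spatial sizes, the upsampled map can end up one pixel smaller than its skip connection, and F.pad realigns the two before the concat:

```python
import torch
import torch.nn.functional as F

x2 = torch.randn(1, 64, 9, 9)  # skip connection from the encoder
x1 = torch.randn(1, 64, 8, 8)  # upsampled decoder features, one pixel short
diffY = x2.size(2) - x1.size(2)
diffX = x2.size(3) - x1.size(3)
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
print(torch.cat([x2, x1], dim=1).shape)  # torch.Size([1, 128, 9, 9])
```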
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import xml.etree.ElementTree as ET
4 | import numpy.random as npr
5 | import cv2
6 | import config
7 |
8 | image_size = config.image_size
9 |
10 |
11 | def getBoxes(xml_file):
12 | boxes = []
13 | tree = ET.parse(xml_file)
14 | root = tree.getroot()
15 |     for member in root.findall('object'):   # member[4] is the <bndbox> element in VOC-style annotations
16 |         box = [int(float(member[4][0].text)),   # xmin
17 |                int(float(member[4][1].text)),   # ymin
18 |                int(float(member[4][2].text)),   # xmax
19 |                int(float(member[4][3].text))    # ymax
20 |                ]
21 | boxes.append(box)
22 | return boxes
23 |
24 |
25 | def getImage(image, label, boxes):
26 |     cropCode = random.choice([-1] + [i for i in range(len(boxes))])  # -1 means: crop a random region instead of a seal box
27 | height, width = image.shape[0], image.shape[1]
28 | if cropCode != -1:
29 | box = boxes[cropCode]
30 | xmin, ymin, xmax, ymax = box[0], box[1], box[2], box[3]
31 | w = xmax - xmin
32 | h = ymax - ymin
33 |         if max(w, h) < image_size:  # the seal box fits inside an image_size crop
34 | nx = npr.randint(max((xmax - image_size), 0), xmin)
35 | ny = npr.randint(max((ymax - image_size), 0), ymin)
36 | if nx + image_size > width:
37 | nx = width - image_size
38 | if ny + image_size > height:
39 | ny = height - image_size
40 | cropped_im = image[ny: ny + image_size, nx: nx + image_size, :]
41 | cropped_la = label[ny: ny + image_size, nx: nx + image_size, :]
42 |         else:  # the box is larger than image_size: crop the box plus a 100 px margin
43 | nx = npr.randint(xmax - max(w, h) - 100, xmin)
44 | ny = npr.randint(ymax - max(w, h) - 100, ymin)
45 | if nx + max(w, h) + 100 > width:
46 | nx = width - max(w, h) - 100
47 | if ny + max(w, h) + 100 > height:
48 | ny = height - max(w, h) - 100
49 | cropped_im = image[ny: ny + max(w, h) + 100, nx: nx + max(w, h) + 100, :]
50 | cropped_la = label[ny: ny + max(w, h) + 100, nx: nx + max(w, h) + 100, :]
51 |     else:  # no box chosen: crop a random image_size region
52 | crop_x = int(random.random() * (width - image_size))
53 | crop_y = int(random.random() * (height - image_size))
54 | cropped_im = image[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
55 | cropped_la = label[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
56 |
57 | return cropped_im, cropped_la
58 |
59 |
60 | def cropImage(image, boxfile, label):
61 | boxes = getBoxes(boxfile)
62 | image, label = getImage(image, label, boxes)
63 | image = cv2.resize(image, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
64 | label = cv2.resize(label, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
65 | return image, label
66 |
67 |
68 | if __name__ == '__main__':
69 | dirTrain = './data/train'
70 |     dataBox = sorted(os.path.join(dirTrain, filename)   # sorted, so boxes and images pair up by index
71 |                      for filename in os.listdir(dirTrain)
72 |                      if filename.endswith('.xml'))
73 |     dataTrain = sorted(os.path.join(dirTrain, filename)
74 |                        for filename in os.listdir(dirTrain)
75 |                        if filename.endswith('.jpg') or filename.endswith('.png'))
76 |     os.makedirs('./test', exist_ok=True)
77 |     for i in range(len(dataTrain)):
78 |         image = cv2.imread(dataTrain[i])
79 |         boxes = getBoxes(dataBox[i])
80 |         crop_image, cropped_label = getImage(image, image, boxes)  # no cleaned label here, so pass the image twice
81 |         cv2.imwrite('./test/{}.png'.format(i), crop_image)
82 |
--------------------------------------------------------------------------------
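getBoxes indexes the annotation positionally (member[4] as the bounding box), which assumes Pascal VOC-style xml files. A hypothetical, more defensive variant looks the elements up by tag name instead:

```python
import xml.etree.ElementTree as ET

def get_boxes_by_name(xml_file):
    boxes = []
    for obj in ET.parse(xml_file).getroot().findall('object'):
        bndbox = obj.find('bndbox')  # VOC: <bndbox> holds xmin/ymin/xmax/ymax
        boxes.append([int(float(bndbox.find(tag).text))
                      for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
    return boxes
```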