├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   ├── ctw1500.py
│   ├── ic15.py
│   ├── msra.py
│   ├── synthtext.py
│   ├── testdataset.py
│   └── totaltext.py
├── eval
│   ├── ctw
│   │   ├── eval.py
│   │   └── file_util.py
│   ├── ic15
│   │   ├── gt.zip
│   │   ├── rrc_evaluation_funcs.py
│   │   ├── rrc_evaluation_funcs_v1.py
│   │   ├── rrc_evaluation_funcs_v2.py
│   │   ├── script.py
│   │   └── script_self_adapt.py
│   ├── msra
│   │   ├── eval.py
│   │   └── file_util.py
│   └── totaltext
│       └── Deteval.py
├── inference.py
├── loss
│   ├── __init__.py
│   ├── dice_loss.py
│   ├── emb_loss_v1.py
│   ├── iou.py
│   ├── loss.py
│   └── ohem.py
├── misc
│   ├── ctw_statistics.png
│   ├── synthtext_statistics.png
│   └── tt_statistics.png
├── models
│   ├── __init__.py
│   ├── backbone.py
│   ├── ffm.py
│   ├── fpem.py
│   ├── head.py
│   └── pan.py
├── requirements.txt
├── train.py
└── utils
    ├── __init__.py
    ├── average_meter.py
    ├── helper.py
    └── pa
        ├── __init__.py
        ├── pa.cpp
        ├── pa.py
        ├── pa.pyx
        ├── readme.txt
        └── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .DS_Store
3 | logs
4 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Chun-Hao Liu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pixel Aggregation Network
2 | 
3 | This is an unofficial PyTorch (>= 1.4.0) re-implementation of the ICCV 2019 paper "Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network".
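Note: the `pa.py` / `pa.pyx` FPS columns in the results below refer to the pure-Python and Cython versions of the pixel-aggregation post-processing under `utils/pa`. The Cython extension must be compiled before it can be imported; assuming `utils/pa/setup.py` follows the standard Cython build pattern, something like `cd utils/pa && python setup.py build_ext --inplace` should produce it (see `utils/pa/readme.txt`).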
4 | 
5 | ## Task
6 | 
7 | - [x] Backbone model
8 | - [x] FPEM model
9 | - [x] FFM model
10 | - [x] Integrated model
11 | - [x] Loss Function
12 | - [x] Data preprocessing
13 | - [x] Data postprocessing
14 | - [x] Training pipeline
15 | - [x] Inference pipeline
16 | - [x] Evaluation pipeline
17 | 
18 | ## Command
19 | 
20 | ### Training
21 | 
22 | ```
23 | python train.py --batch 32 --epoch 5000 --dataset_type ctw --gpu True
24 | ```
25 | 
26 | ### Inference
27 | 
28 | ```
29 | python inference.py --input ./data/CTW1500/test/text_image --model ./outputs/model_epoch_0.pth --bbox_type poly
30 | ```
31 | 
32 | ## Results
33 | 
34 | ### CTW1500
35 | ![Statistics for CTW training](https://github.com/liuch37/pan-pytorch/blob/master/misc/ctw_statistics.png)
36 | 
37 | Model | Precision | Recall | F score | FPS (CPU) + pa.py | FPS (1 GPU) + pa.py | FPS (1 GPU) + pa.pyx |
38 | ------- | --------- | ------ | ------- | ------------------- | ------------------- | -------------------- |
39 | PAN-640 | 0.8509 | 0.7927 | 0.8208 | 0.3493 | 4.6347 | 21.167 |
40 | 
41 | ### TotalText
42 | ![Statistics for TT training](https://github.com/liuch37/pan-pytorch/blob/master/misc/tt_statistics.png)
43 | 
44 | Model | Precision | Recall | F score | FPS (CPU) + pa.py | FPS (1 GPU) + pa.py | FPS (1 GPU) + pa.pyx |
45 | ------- | --------- | ------ | ------- | ------------------- | ------------------- | -------------------- |
46 | PAN-640 | 0.9011 | 0.8040 | 0.8498 | 0.2883 | 7.6481 | 20.390 |
47 | 
48 | ### SynthText
49 | ![Statistics for SynthText training](https://github.com/liuch37/pan-pytorch/blob/master/misc/synthtext_statistics.png)
50 | 
51 | ## Supported Dataset
52 | 
53 | - [x] CTW1500: https://github.com/Yuliang-Liu/Curve-Text-Detector
54 | - [x] Total-Text: https://github.com/cs-chan/Total-Text-Dataset
55 | - [x] SynthText: https://www.robots.ox.ac.uk/~vgg/data/scenetext/
56 | - [x] MSRA-TD500: http://www.iapr-tc11.org/mediawiki/index.php/MSRA_Text_Detection_500_Database_(MSRA-TD500)
57 | - [x] ICDAR-2015: https://rrc.cvc.uab.es/
58 | 
59 | ## Source
60 | 
61 | [1] Original paper: https://arxiv.org/abs/1908.05900
62 | 
63 | [2] Official PyTorch code: https://github.com/whai362/pan_pp.pytorch
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/ctw1500.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code is to build data loader for CTW1500 dataset.
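A minimal usage sketch (assuming the default data paths defined below and that
the images and labels are in place):

    import torch
    from dataset.ctw1500 import PAN_CTW
    train_set = PAN_CTW(split='train', is_transform=True, img_size=640,
                        short_size=640, kernel_scale=0.7)
    sample = train_set[0]   # dict: imgs, gt_texts, gt_kernels,
                            # training_masks, gt_instances, gt_bboxes
    loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)

kernel_scale=0.7 controls the shrunken kernel map: shrink() below offsets each
polygon inward by roughly area * (1 - 0.7**2) / perimeter pixels (capped at 20)
using pyclipper.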
3 | ''' 4 | 5 | import numpy as np 6 | from torch.utils import data 7 | import cv2 8 | import random 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | import torch 12 | import pyclipper 13 | import os 14 | import pdb 15 | import matplotlib.pyplot as plt 16 | 17 | ctw_root_dir = './data/CTW1500/' 18 | ctw_train_data_dir = ctw_root_dir + 'train/text_image/' 19 | ctw_train_gt_dir = ctw_root_dir + 'train/text_label_curve/' 20 | ctw_test_data_dir = ctw_root_dir + 'test/text_image/' 21 | ctw_test_gt_dir = ctw_root_dir + 'test/text_label_circum/' 22 | 23 | def PolyArea(x,y): 24 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) 25 | 26 | def get_img(img_path): 27 | try: 28 | img = cv2.imread(img_path) 29 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 30 | except Exception as e: 31 | print(img_path) 32 | raise 33 | return img 34 | 35 | def get_ann(img, gt_path): 36 | h, w = img.shape[0:2] 37 | #lines = mmcv.list_from_file(gt_path) # replaced by python readlines 38 | with open(gt_path, "r") as file: 39 | lines = file.readlines() 40 | bboxes = [] 41 | words = [] 42 | for line in lines: 43 | line = line.replace('\xef\xbb\xbf', '') 44 | gt = line.split(',') 45 | 46 | x1 = np.int(gt[0]) 47 | y1 = np.int(gt[1]) 48 | 49 | bbox = [np.int(gt[i]) for i in range(4, 32)] 50 | bbox = np.asarray(bbox) + ([x1 * 1.0, y1 * 1.0] * 14) 51 | bbox = np.asarray(bbox) / ([w * 1.0, h * 1.0] * 14) 52 | 53 | bboxes.append(bbox) 54 | words.append('???') 55 | return bboxes, words 56 | 57 | 58 | def random_horizontal_flip(imgs): 59 | if random.random() < 0.5: 60 | for i in range(len(imgs)): 61 | imgs[i] = np.flip(imgs[i], axis=1).copy() 62 | return imgs 63 | 64 | 65 | def random_rotate(imgs): 66 | max_angle = 10 67 | angle = random.random() * 2 * max_angle - max_angle 68 | for i in range(len(imgs)): 69 | img = imgs[i] 70 | w, h = img.shape[:2] 71 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 72 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 73 | imgs[i] = img_rotation 74 | return imgs 75 | 76 | 77 | def scale_aligned(img, scale): 78 | h, w = img.shape[0:2] 79 | h = int(h * scale + 0.5) 80 | w = int(w * scale + 0.5) 81 | if h % 32 != 0: 82 | h = h + (32 - h % 32) 83 | if w % 32 != 0: 84 | w = w + (32 - w % 32) 85 | img = cv2.resize(img, dsize=(w, h)) 86 | return img 87 | 88 | 89 | def random_scale(img, short_size=640): 90 | h, w = img.shape[0:2] 91 | 92 | random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]) 93 | scale = (np.random.choice(random_scale) * short_size) / min(h, w) 94 | 95 | img = scale_aligned(img, scale) 96 | return img 97 | 98 | 99 | def scale_aligned_short(img, short_size=640): 100 | h, w = img.shape[0:2] 101 | scale = short_size * 1.0 / min(h, w) 102 | h = int(h * scale + 0.5) 103 | w = int(w * scale + 0.5) 104 | if h % 32 != 0: 105 | h = h + (32 - h % 32) 106 | if w % 32 != 0: 107 | w = w + (32 - w % 32) 108 | img = cv2.resize(img, dsize=(w, h)) 109 | return img 110 | 111 | 112 | def random_crop_padding(imgs, target_size): 113 | h, w = imgs[0].shape[0:2] 114 | t_w, t_h = target_size 115 | p_w, p_h = target_size 116 | if w == t_w and h == t_h: 117 | return imgs 118 | 119 | t_h = t_h if t_h < h else h 120 | t_w = t_w if t_w < w else w 121 | 122 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 123 | # make sure to crop the text region 124 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 125 | tl[tl < 0] = 0 126 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 127 | br[br 
< 0] = 0 128 | br[0] = min(br[0], h - t_h) 129 | br[1] = min(br[1], w - t_w) 130 | 131 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 132 | j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 133 | else: 134 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 135 | j = random.randint(0, w - t_w) if w - t_w > 0 else 0 136 | 137 | n_imgs = [] 138 | for idx in range(len(imgs)): 139 | if len(imgs[idx].shape) == 3: 140 | s3_length = int(imgs[idx].shape[-1]) 141 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 142 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 143 | value=tuple(0 for i in range(s3_length))) 144 | else: 145 | img = imgs[idx][i:i + t_h, j:j + t_w] 146 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 147 | n_imgs.append(img_p) 148 | return n_imgs 149 | 150 | def dist(a, b): 151 | return np.linalg.norm((a - b), ord=2, axis=0) 152 | 153 | def perimeter(bbox): 154 | peri = 0.0 155 | for i in range(bbox.shape[0]): 156 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 157 | return peri 158 | 159 | def shrink(bboxes, rate, max_shr=20): 160 | rate = rate * rate 161 | shrinked_bboxes = [] 162 | for bbox in bboxes: 163 | # Replace ply.Polygon with simple area calculation function 164 | #area = plg.Polygon(bbox).area() 165 | x = bbox[:,0] 166 | y = bbox[:,1] 167 | area = PolyArea(x,y) 168 | peri = perimeter(bbox) 169 | 170 | try: 171 | pco = pyclipper.PyclipperOffset() 172 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 173 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 174 | 175 | shrinked_bbox = pco.Execute(-offset) 176 | if len(shrinked_bbox) == 0: 177 | shrinked_bboxes.append(bbox) 178 | continue 179 | 180 | shrinked_bbox = np.array(shrinked_bbox[0]) 181 | if shrinked_bbox.shape[0] <= 2: 182 | shrinked_bboxes.append(bbox) 183 | continue 184 | 185 | shrinked_bboxes.append(shrinked_bbox) 186 | except Exception as e: 187 | print(type(shrinked_bbox), shrinked_bbox) 188 | print('area:', area, 'peri:', peri) 189 | shrinked_bboxes.append(bbox) 190 | 191 | return shrinked_bboxes 192 | 193 | class PAN_CTW(data.Dataset): 194 | def __init__(self, 195 | split='train', 196 | is_transform=False, 197 | img_size=None, 198 | short_size=640, 199 | kernel_scale=0.7, 200 | report_speed=False): 201 | self.split = split 202 | self.is_transform = is_transform 203 | 204 | self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size) 205 | self.kernel_scale = kernel_scale 206 | self.short_size = short_size 207 | 208 | if split == 'train': 209 | data_dirs = [ctw_train_data_dir] 210 | gt_dirs = [ctw_train_gt_dir] 211 | elif split == 'test': 212 | data_dirs = [ctw_test_data_dir] 213 | gt_dirs = [ctw_test_gt_dir] 214 | else: 215 | print('Error: split must be train or test!') 216 | raise 217 | 218 | self.img_paths = [] 219 | self.gt_paths = [] 220 | 221 | for data_dir, gt_dir in zip(data_dirs, gt_dirs): 222 | #img_names = [img_name for img_name in mmcv.utils.scandir(data_dir, '.jpg')] # to be handled by removing mmcv 223 | #img_names.extend([img_name for img_name in mmcv.utils.scandir(data_dir, '.png')]) # to be handled by removing mmcv 224 | img_names = os.listdir(data_dir) 225 | 226 | img_paths = [] 227 | gt_paths = [] 228 | for idx, img_name in enumerate(img_names): 229 | img_path = data_dir + img_name 230 | img_paths.append(img_path) 231 | 232 | gt_name = img_name.split('.')[0] + '.txt' 233 | gt_path = 
gt_dir + gt_name 234 | gt_paths.append(gt_path) 235 | 236 | self.img_paths.extend(img_paths) 237 | self.gt_paths.extend(gt_paths) 238 | 239 | ''' 240 | if report_speed: 241 | target_size = 3000 242 | data_size = len(self.img_paths) 243 | extend_scale = (target_size + data_size - 1) // data_size 244 | self.img_paths = (self.img_paths * extend_scale)[:target_size] 245 | self.gt_paths = (self.gt_paths * extend_scale)[:target_size] 246 | ''' 247 | 248 | self.max_word_num = 200 249 | 250 | def __len__(self): 251 | return len(self.img_paths) 252 | 253 | def prepare_train_data(self, index): 254 | img_path = self.img_paths[index] 255 | gt_path = self.gt_paths[index] 256 | 257 | img = get_img(img_path) 258 | bboxes, words = get_ann(img, gt_path) 259 | 260 | if len(bboxes) > self.max_word_num: 261 | bboxes = bboxes[:self.max_word_num] 262 | 263 | if self.is_transform: 264 | img = random_scale(img, self.short_size) 265 | 266 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 267 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 268 | if len(bboxes) > 0: 269 | for i in range(len(bboxes)): 270 | bboxes[i] = np.reshape(bboxes[i] * ([img.shape[1], img.shape[0]] * (bboxes[i].shape[0] // 2)), 271 | (bboxes[i].shape[0] // 2, 2)).astype('int32') 272 | for i in range(len(bboxes)): 273 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 274 | if words[i] == '###': 275 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 276 | 277 | gt_kernels = [] 278 | for rate in [self.kernel_scale]: 279 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 280 | kernel_bboxes = shrink(bboxes, rate) 281 | for i in range(len(bboxes)): 282 | cv2.drawContours(gt_kernel, [kernel_bboxes[i]], -1, 1, -1) 283 | gt_kernels.append(gt_kernel) 284 | 285 | if self.is_transform: 286 | imgs = [img, gt_instance, training_mask] 287 | imgs.extend(gt_kernels) 288 | 289 | imgs = random_horizontal_flip(imgs) 290 | imgs = random_rotate(imgs) 291 | imgs = random_crop_padding(imgs, self.img_size) 292 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 293 | 294 | gt_text = gt_instance.copy() 295 | gt_text[gt_text > 0] = 1 296 | gt_kernels = np.array(gt_kernels) 297 | 298 | max_instance = np.max(gt_instance) 299 | gt_bboxes = np.zeros((self.max_word_num + 1, 4), dtype=np.int32) 300 | for i in range(1, max_instance + 1): 301 | ind = gt_instance == i 302 | if np.sum(ind) == 0: 303 | continue 304 | points = np.array(np.where(ind)).transpose((1, 0)) 305 | tl = np.min(points, axis=0) 306 | br = np.max(points, axis=0) + 1 307 | gt_bboxes[i] = (tl[0], tl[1], br[0], br[1]) 308 | 309 | if self.is_transform: 310 | img = Image.fromarray(img) 311 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) 312 | 313 | img = transforms.ToTensor()(img) 314 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 315 | 316 | gt_text = torch.from_numpy(gt_text).long() 317 | gt_kernels = torch.from_numpy(gt_kernels).long() 318 | training_mask = torch.from_numpy(training_mask).long() 319 | gt_instance = torch.from_numpy(gt_instance).long() 320 | gt_bboxes = torch.from_numpy(gt_bboxes).long() 321 | 322 | data = dict( 323 | imgs=img, 324 | gt_texts=gt_text, 325 | gt_kernels=gt_kernels, 326 | training_masks=training_mask, 327 | gt_instances=gt_instance, 328 | gt_bboxes=gt_bboxes, 329 | ) 330 | 331 | return data 332 | 333 | def prepare_test_data(self, index): 334 | img_path = self.img_paths[index] 335 | 336 | img = get_img(img_path) 337 | img_meta = dict( 338 | 
org_img_size=np.array(img.shape[:2]) 339 | ) 340 | 341 | img = scale_aligned_short(img, self.short_size) 342 | img_meta.update(dict( 343 | img_size=np.array(img.shape[:2]) 344 | )) 345 | 346 | img = transforms.ToTensor()(img) 347 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 348 | 349 | data = dict( 350 | imgs=img, 351 | img_metas=img_meta 352 | ) 353 | 354 | return data 355 | 356 | def __getitem__(self, index): 357 | if self.split == 'train': 358 | return self.prepare_train_data(index) 359 | elif self.split == 'test': 360 | return self.prepare_test_data(index) 361 | 362 | # unit testing 363 | if __name__ == '__main__': 364 | 365 | train_dataset = PAN_CTW(split='train', 366 | is_transform=False, 367 | img_size=None, 368 | short_size=640, 369 | kernel_scale=0.7, 370 | report_speed=False) 371 | 372 | for i, data in enumerate(train_dataset): 373 | # convert to numpy and plot 374 | print("Process image index:", i) 375 | imgs = data['imgs'] 376 | gt_texts = data['gt_texts'] 377 | gt_kernels = data['gt_kernels'] 378 | training_masks = data['training_masks'] 379 | gt_instances = data['gt_instances'] 380 | gt_bboxes = data['gt_bboxes'] 381 | imgs = imgs.permute(1,2,0).detach().cpu().numpy() 382 | gt_texts = gt_texts.detach().cpu().numpy() 383 | gt_kernels = gt_kernels.detach().cpu().numpy()[0] 384 | training_masks = training_masks.detach().cpu().numpy() 385 | gt_instances = gt_instances.detach().cpu().numpy() 386 | ''' 387 | plt.figure(1) 388 | plt.imshow(imgs) 389 | plt.figure(2) 390 | plt.imshow(gt_texts) 391 | plt.title('gt_texts') 392 | plt.figure(3) 393 | plt.imshow(gt_kernels) 394 | plt.title('gt_kernels') 395 | plt.figure(4) 396 | plt.imshow(training_masks) 397 | plt.title('training_masks') 398 | plt.figure(5) 399 | plt.imshow(gt_instances) 400 | plt.title('gt_instances') 401 | plt.show() 402 | pdb.set_trace() 403 | ''' 404 | 405 | test_dataset = PAN_CTW(split='test', 406 | is_transform=False, 407 | img_size=None, 408 | short_size=640, 409 | kernel_scale=0.7, 410 | report_speed=False) 411 | 412 | for i, data in enumerate(test_dataset): 413 | # convert to numpy and plot 414 | print("Process image index:", i) 415 | imgs = data['imgs'] 416 | img_metas = data['img_metas'] 417 | print(img_metas) 418 | ''' 419 | imgs = imgs.permute(1,2,0).detach().cpu().numpy() 420 | plt.imshow(imgs) 421 | plt.show() 422 | pdb.set_trace() 423 | ''' -------------------------------------------------------------------------------- /dataset/ic15.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to build data loader for ICDAR-2015 dataset. 
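A minimal usage sketch (assuming the default data paths below; with_rec=True
additionally returns the recognition targets gt_words / word_masks):

    from dataset.ic15 import PAN_IC15
    train_set = PAN_IC15(split='train', is_transform=True, img_size=736,
                         short_size=736, kernel_scale=0.5, with_rec=True)
    sample = train_set[0]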
3 | ''' 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils import data 8 | import cv2 9 | import random 10 | import torchvision.transforms as transforms 11 | import torch 12 | import pyclipper 13 | import math 14 | import string 15 | import os 16 | import matplotlib.pyplot as plt 17 | import pdb 18 | 19 | ic15_root_dir = './data/ICDAR2015/Challenge4/' 20 | ic15_train_data_dir = ic15_root_dir + 'ch4_training_images/' 21 | ic15_train_gt_dir = ic15_root_dir + 'ch4_training_localization_transcription_gt/' 22 | ic15_test_data_dir = ic15_root_dir + 'ch4_test_images/' 23 | ic15_test_gt_dir = ic15_root_dir + 'Challenge4_Test_Task1_GT/' 24 | 25 | def PolyArea(x,y): 26 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) 27 | 28 | def get_img(img_path): 29 | try: 30 | img = cv2.imread(img_path) 31 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 32 | except Exception as e: 33 | print('Cannot read image: %s.' % img_path) 34 | raise 35 | return img 36 | 37 | 38 | def get_ann(img, gt_path): 39 | h, w = img.shape[0:2] 40 | #lines = mmcv.list_from_file(gt_path) # replaced 41 | with open(gt_path, "r") as file: 42 | lines = file.readlines() 43 | bboxes = [] 44 | words = [] 45 | for line in lines: 46 | line = line.encode('utf-8').decode('utf-8-sig') 47 | line = line.replace('\xef\xbb\xbf', '') 48 | gt = line.split(',') 49 | word = gt[8].replace('\r', '').replace('\n', '') 50 | if word[0] == '#': 51 | words.append('###') 52 | else: 53 | words.append(word) 54 | 55 | bbox = [int(gt[i]) for i in range(8)] 56 | bbox = np.array(bbox) / ([w * 1.0, h * 1.0] * 4) 57 | bboxes.append(bbox) 58 | return np.array(bboxes), words 59 | 60 | 61 | def random_horizontal_flip(imgs): 62 | if random.random() < 0.5: 63 | for i in range(len(imgs)): 64 | imgs[i] = np.flip(imgs[i], axis=1).copy() 65 | return imgs 66 | 67 | 68 | def random_rotate(imgs): 69 | max_angle = 10 70 | angle = random.random() * 2 * max_angle - max_angle 71 | for i in range(len(imgs)): 72 | img = imgs[i] 73 | w, h = img.shape[:2] 74 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 75 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 76 | imgs[i] = img_rotation 77 | return imgs 78 | 79 | 80 | def scale_aligned(img, h_scale, w_scale): 81 | h, w = img.shape[0:2] 82 | h = int(h * h_scale + 0.5) 83 | w = int(w * w_scale + 0.5) 84 | if h % 32 != 0: 85 | h = h + (32 - h % 32) 86 | if w % 32 != 0: 87 | w = w + (32 - w % 32) 88 | img = cv2.resize(img, dsize=(w, h)) 89 | return img 90 | 91 | 92 | def scale_aligned_short(img, short_size=736): 93 | h, w = img.shape[0:2] 94 | scale = short_size * 1.0 / min(h, w) 95 | h = int(h * scale + 0.5) 96 | w = int(w * scale + 0.5) 97 | if h % 32 != 0: 98 | h = h + (32 - h % 32) 99 | if w % 32 != 0: 100 | w = w + (32 - w % 32) 101 | img = cv2.resize(img, dsize=(w, h)) 102 | return img 103 | 104 | 105 | def random_scale(img, short_size=736): 106 | h, w = img.shape[0:2] 107 | 108 | scale = np.random.choice(np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])) 109 | scale = (scale * short_size) / min(h, w) 110 | 111 | aspect = np.random.choice(np.array([0.9, 0.95, 1.0, 1.05, 1.1])) 112 | h_scale = scale * math.sqrt(aspect) 113 | w_scale = scale / math.sqrt(aspect) 114 | 115 | img = scale_aligned(img, h_scale, w_scale) 116 | return img 117 | 118 | 119 | def random_crop_padding(imgs, target_size): 120 | h, w = imgs[0].shape[0:2] 121 | t_w, t_h = target_size 122 | p_w, p_h = target_size 123 | if w == t_w and h == t_h: 124 | return imgs 125 | 126 | t_h = t_h 
if t_h < h else h 127 | t_w = t_w if t_w < w else w 128 | 129 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 130 | # make sure to crop the text region 131 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 132 | tl[tl < 0] = 0 133 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 134 | br[br < 0] = 0 135 | br[0] = min(br[0], h - t_h) 136 | br[1] = min(br[1], w - t_w) 137 | 138 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 139 | j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 140 | else: 141 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 142 | j = random.randint(0, w - t_w) if w - t_w > 0 else 0 143 | 144 | n_imgs = [] 145 | for idx in range(len(imgs)): 146 | if len(imgs[idx].shape) == 3: 147 | s3_length = int(imgs[idx].shape[-1]) 148 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 149 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 150 | value=tuple(0 for i in range(s3_length))) 151 | else: 152 | img = imgs[idx][i:i + t_h, j:j + t_w] 153 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 154 | n_imgs.append(img_p) 155 | return n_imgs 156 | 157 | 158 | def update_word_mask(instance, instance_before_crop, word_mask): 159 | labels = np.unique(instance) 160 | 161 | for label in labels: 162 | if label == 0: 163 | continue 164 | ind = instance == label 165 | if np.sum(ind) == 0: 166 | word_mask[label] = 0 167 | continue 168 | ind_before_crop = instance_before_crop == label 169 | # print(np.sum(ind), np.sum(ind_before_crop)) 170 | if float(np.sum(ind)) / np.sum(ind_before_crop) > 0.9: 171 | continue 172 | word_mask[label] = 0 173 | 174 | return word_mask 175 | 176 | 177 | def dist(a, b): 178 | return np.linalg.norm((a - b), ord=2, axis=0) 179 | 180 | 181 | def perimeter(bbox): 182 | peri = 0.0 183 | for i in range(bbox.shape[0]): 184 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 185 | return peri 186 | 187 | 188 | def shrink(bboxes, rate, max_shr=20): 189 | rate = rate * rate 190 | shrinked_bboxes = [] 191 | for bbox in bboxes: 192 | x = bbox[:,0] 193 | y = bbox[:,1] 194 | area = PolyArea(x,y) 195 | #area = plg.Polygon(bbox).area() 196 | peri = perimeter(bbox) 197 | 198 | try: 199 | pco = pyclipper.PyclipperOffset() 200 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 201 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 202 | 203 | shrinked_bbox = pco.Execute(-offset) 204 | if len(shrinked_bbox) == 0: 205 | shrinked_bboxes.append(bbox) 206 | continue 207 | 208 | shrinked_bbox = np.array(shrinked_bbox)[0] 209 | if shrinked_bbox.shape[0] <= 2: 210 | shrinked_bboxes.append(bbox) 211 | continue 212 | 213 | shrinked_bboxes.append(shrinked_bbox) 214 | except Exception as e: 215 | print('area:', area, 'peri:', peri) 216 | shrinked_bboxes.append(bbox) 217 | 218 | return shrinked_bboxes 219 | 220 | 221 | def get_vocabulary(voc_type, EOS='EOS', PADDING='PAD', UNKNOWN='UNK'): 222 | if voc_type == 'LOWERCASE': 223 | voc = list(string.digits + string.ascii_lowercase) 224 | elif voc_type == 'ALLCASES': 225 | voc = list(string.digits + string.ascii_letters) 226 | elif voc_type == 'ALLCASES_SYMBOLS': 227 | voc = list(string.printable[:-6]) 228 | else: 229 | raise KeyError('voc_type must be one of "LOWERCASE", "ALLCASES", "ALLCASES_SYMBOLS"') 230 | 231 | # update the voc with specifical chars 232 | voc.append(EOS) 233 | voc.append(PADDING) 234 | voc.append(UNKNOWN) 235 | 236 | char2id = dict(zip(voc, 
range(len(voc)))) 237 | id2char = dict(zip(range(len(voc)), voc)) 238 | 239 | return voc, char2id, id2char 240 | 241 | 242 | class PAN_IC15(data.Dataset): 243 | def __init__(self, 244 | split='train', 245 | is_transform=False, 246 | img_size=None, 247 | short_size=736, 248 | kernel_scale=0.5, 249 | with_rec=False): 250 | self.split = split 251 | self.is_transform = is_transform 252 | 253 | self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size) 254 | self.kernel_scale = kernel_scale 255 | self.short_size = short_size 256 | self.with_rec = with_rec 257 | 258 | if split == 'train': 259 | data_dirs = [ic15_train_data_dir] 260 | gt_dirs = [ic15_train_gt_dir] 261 | elif split == 'test': 262 | data_dirs = [ic15_test_data_dir] 263 | gt_dirs = [ic15_test_gt_dir] 264 | else: 265 | print('Error: split must be train or test!') 266 | raise 267 | 268 | self.img_paths = [] 269 | self.gt_paths = [] 270 | 271 | for data_dir, gt_dir in zip(data_dirs, gt_dirs): 272 | #img_names = [img_name for img_name in mmcv.utils.scandir(data_dir, '.jpg')] 273 | #img_names.extend([img_name for img_name in mmcv.utils.scandir(data_dir, '.png')]) 274 | img_names = [img_name for img_name in os.listdir(data_dir) if img_name.endswith('.jpg') or img_name.endswith('.JPG')] 275 | 276 | img_paths = [] 277 | gt_paths = [] 278 | for idx, img_name in enumerate(img_names): 279 | img_path = data_dir + img_name 280 | img_paths.append(img_path) 281 | 282 | gt_name = 'gt_' + img_name.split('.')[0] + '.txt' 283 | gt_path = gt_dir + gt_name 284 | gt_paths.append(gt_path) 285 | 286 | self.img_paths.extend(img_paths) 287 | self.gt_paths.extend(gt_paths) 288 | 289 | ''' 290 | if report_speed: 291 | target_size = 3000 292 | extend_scale = (target_size + len(self.img_paths) - 1) // len(self.img_paths) 293 | self.img_paths = (self.img_paths * extend_scale)[:target_size] 294 | self.gt_paths = (self.gt_paths * extend_scale)[:target_size] 295 | ''' 296 | 297 | self.voc, self.char2id, self.id2char = get_vocabulary('LOWERCASE') 298 | self.max_word_num = 200 299 | self.max_word_len = 32 300 | 301 | def __len__(self): 302 | return len(self.img_paths) 303 | 304 | def prepare_train_data(self, index): 305 | img_path = self.img_paths[index] 306 | gt_path = self.gt_paths[index] 307 | 308 | img = get_img(img_path) 309 | bboxes, words = get_ann(img, gt_path) 310 | 311 | if bboxes.shape[0] > self.max_word_num: 312 | bboxes = bboxes[:self.max_word_num] 313 | words = words[:self.max_word_num] 314 | 315 | gt_words = np.full((self.max_word_num + 1, self.max_word_len), self.char2id['PAD'], dtype=np.int32) 316 | word_mask = np.zeros((self.max_word_num + 1,), dtype=np.int32) 317 | for i, word in enumerate(words): 318 | if word == '###': 319 | continue 320 | word = word.lower() 321 | gt_word = np.full((self.max_word_len,), self.char2id['PAD'], dtype=np.int) 322 | for j, char in enumerate(word): 323 | if j > self.max_word_len - 1: 324 | break 325 | if char in self.char2id: 326 | gt_word[j] = self.char2id[char] 327 | else: 328 | gt_word[j] = self.char2id['UNK'] 329 | if len(word) > self.max_word_len - 1: 330 | gt_word[-1] = self.char2id['EOS'] 331 | else: 332 | gt_word[len(word)] = self.char2id['EOS'] 333 | gt_words[i + 1] = gt_word 334 | word_mask[i + 1] = 1 335 | 336 | if self.is_transform: 337 | img = random_scale(img, self.short_size) 338 | 339 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 340 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 341 | if bboxes.shape[0] > 0: 342 | bboxes = 
np.reshape(bboxes * ([img.shape[1], img.shape[0]] * 4), 343 | (bboxes.shape[0], -1, 2)).astype('int32') 344 | for i in range(bboxes.shape[0]): 345 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 346 | if words[i] == '###': 347 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 348 | 349 | gt_kernels = [] 350 | for rate in [self.kernel_scale]: 351 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 352 | kernel_bboxes = shrink(bboxes, rate) 353 | for i in range(bboxes.shape[0]): 354 | cv2.drawContours(gt_kernel, [kernel_bboxes[i]], -1, 1, -1) 355 | gt_kernels.append(gt_kernel) 356 | 357 | if self.is_transform: 358 | imgs = [img, gt_instance, training_mask] 359 | imgs.extend(gt_kernels) 360 | 361 | imgs = random_horizontal_flip(imgs) 362 | imgs = random_rotate(imgs) 363 | gt_instance_before_crop = imgs[1].copy() 364 | imgs = random_crop_padding(imgs, self.img_size) 365 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 366 | word_mask = update_word_mask(gt_instance, gt_instance_before_crop, word_mask) 367 | 368 | gt_text = gt_instance.copy() 369 | gt_text[gt_text > 0] = 1 370 | gt_kernels = np.array(gt_kernels) 371 | 372 | max_instance = np.max(gt_instance) 373 | gt_bboxes = np.zeros((self.max_word_num + 1, 4), dtype=np.int32) 374 | for i in range(1, max_instance + 1): 375 | ind = gt_instance == i 376 | if np.sum(ind) == 0: 377 | continue 378 | points = np.array(np.where(ind)).transpose((1, 0)) 379 | tl = np.min(points, axis=0) 380 | br = np.max(points, axis=0) + 1 381 | gt_bboxes[i] = (tl[0], tl[1], br[0], br[1]) 382 | 383 | if self.is_transform: 384 | img = Image.fromarray(img) 385 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) 386 | 387 | img = transforms.ToTensor()(img) 388 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 389 | gt_text = torch.from_numpy(gt_text).long() 390 | gt_kernels = torch.from_numpy(gt_kernels).long() 391 | training_mask = torch.from_numpy(training_mask).long() 392 | gt_instance = torch.from_numpy(gt_instance).long() 393 | gt_bboxes = torch.from_numpy(gt_bboxes).long() 394 | gt_words = torch.from_numpy(gt_words).long() 395 | word_mask = torch.from_numpy(word_mask).long() 396 | 397 | data = dict( 398 | imgs=img, 399 | gt_texts=gt_text, 400 | gt_kernels=gt_kernels, 401 | training_masks=training_mask, 402 | gt_instances=gt_instance, 403 | gt_bboxes=gt_bboxes, 404 | ) 405 | if self.with_rec: 406 | data.update(dict( 407 | gt_words=gt_words, 408 | word_masks=word_mask 409 | )) 410 | 411 | return data 412 | 413 | def prepare_test_data(self, index): 414 | img_path = self.img_paths[index] 415 | 416 | img = get_img(img_path) 417 | img_meta = dict( 418 | org_img_size=np.array(img.shape[:2]) 419 | ) 420 | 421 | img = scale_aligned_short(img, self.short_size) 422 | img_meta.update(dict( 423 | img_size=np.array(img.shape[:2]) 424 | )) 425 | 426 | img = transforms.ToTensor()(img) 427 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 428 | 429 | data = dict( 430 | imgs=img, 431 | img_metas=img_meta 432 | ) 433 | 434 | return data 435 | 436 | def __getitem__(self, index): 437 | if self.split == 'train': 438 | return self.prepare_train_data(index) 439 | elif self.split == 'test': 440 | return self.prepare_test_data(index) 441 | 442 | # unit testing 443 | if __name__ == '__main__': 444 | 445 | train_dataset = PAN_IC15(split='train', 446 | is_transform=False, 447 | img_size=None, 448 | short_size=736, 449 | kernel_scale=0.5, 
450 | with_rec=True) 451 | 452 | for i, data in enumerate(train_dataset): 453 | # convert to numpy and plot 454 | print("Process image index:", i) 455 | imgs = data['imgs'] 456 | gt_texts = data['gt_texts'] 457 | gt_kernels = data['gt_kernels'] 458 | training_masks = data['training_masks'] 459 | gt_instances = data['gt_instances'] 460 | gt_bboxes = data['gt_bboxes'] 461 | gt_words = data['gt_words'] 462 | word_masks = data['word_masks'] 463 | imgs = imgs.permute(1,2,0).detach().cpu().numpy() 464 | gt_texts = gt_texts.detach().cpu().numpy() 465 | gt_kernels = gt_kernels.detach().cpu().numpy()[0] 466 | training_masks = training_masks.detach().cpu().numpy() 467 | gt_instances = gt_instances.detach().cpu().numpy() 468 | gt_words = gt_words.detach().cpu().numpy() 469 | word_masks = word_masks.detach().cpu().numpy() 470 | print(gt_words) 471 | print(word_masks) 472 | ''' 473 | plt.figure(1) 474 | plt.imshow(imgs) 475 | plt.figure(2) 476 | plt.imshow(gt_texts) 477 | plt.title('gt_texts') 478 | plt.figure(3) 479 | plt.imshow(gt_kernels) 480 | plt.title('gt_kernels') 481 | plt.figure(4) 482 | plt.imshow(training_masks) 483 | plt.title('training_masks') 484 | plt.figure(5) 485 | plt.imshow(gt_instances) 486 | plt.title('gt_instances') 487 | plt.show() 488 | pdb.set_trace() 489 | ''' 490 | test_dataset = PAN_IC15(split='test', 491 | is_transform=False, 492 | img_size=None, 493 | short_size=736, 494 | kernel_scale=0.5, 495 | with_rec=False) 496 | 497 | for i, data in enumerate(test_dataset): 498 | # convert to numpy and plot 499 | print("Process image index:", i) 500 | imgs = data['imgs'] 501 | img_metas = data['img_metas'] 502 | print(img_metas) 503 | imgs = imgs.permute(1,2,0).detach().cpu().numpy() 504 | plt.imshow(imgs) 505 | plt.show() 506 | pdb.set_trace() -------------------------------------------------------------------------------- /dataset/msra.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to build data loader for MSRA-TD500 dataset. 
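Each line of an MSRA-TD500 .gt file is parsed below as
'index difficulty x y w h angle', where (x, y) is the top-left corner and the
angle is in radians (the first two fields are unused here); get_ann() converts
each entry into a rotated rectangle via cv2.boxPoints. A minimal usage sketch
(assuming the default data paths below):

    from dataset.msra import PAN_MSRA
    train_set = PAN_MSRA(split='train', is_transform=True, img_size=736,
                         short_size=736, kernel_scale=0.5)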
3 | ''' 4 | 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | from torch.utils import data 9 | import cv2 10 | import random 11 | import torchvision.transforms as transforms 12 | import torch 13 | import pyclipper 14 | import math 15 | import pdb 16 | import matplotlib.pyplot as plt 17 | 18 | msra_root_dir = './data/MSRA-TD500/' 19 | msra_train_data_dir = msra_root_dir + 'train/' 20 | msra_train_gt_dir = msra_root_dir + 'train/' 21 | msra_test_data_dir = msra_root_dir + 'test/' 22 | msra_test_gt_dir = msra_root_dir + 'test/' 23 | 24 | def PolyArea(x,y): 25 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) 26 | 27 | def get_img(img_path): 28 | try: 29 | img = cv2.imread(img_path) 30 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 31 | except Exception as e: 32 | print(img_path) 33 | raise 34 | return img 35 | 36 | 37 | def get_ann(img, gt_path): 38 | h, w = img.shape[0:2] 39 | #lines = mmcv.list_from_file(gt_path) # replaced 40 | with open(gt_path, "r") as file: 41 | lines = file.readlines() 42 | bboxes = [] 43 | words = [] 44 | for line in lines: 45 | line = line.encode('utf-8').decode('utf-8-sig') 46 | line = line.replace('\xef\xbb\xbf', '') 47 | 48 | gt = line.split(' ') 49 | 50 | w_ = np.float(gt[4]) 51 | h_ = np.float(gt[5]) 52 | x1 = np.float(gt[2]) + w_ / 2.0 53 | y1 = np.float(gt[3]) + h_ / 2.0 54 | theta = np.float(gt[6]) / math.pi * 180 55 | 56 | bbox = cv2.boxPoints(((x1, y1), (w_, h_), theta)) 57 | bbox = bbox.reshape(-1) / ([w * 1.0, h * 1.0] * 4) 58 | 59 | bboxes.append(bbox) 60 | words.append('???') 61 | return np.array(bboxes), words 62 | 63 | 64 | def random_horizontal_flip(imgs): 65 | if random.random() < 0.5: 66 | for i in range(len(imgs)): 67 | imgs[i] = np.flip(imgs[i], axis=1).copy() 68 | return imgs 69 | 70 | 71 | def random_rotate(imgs): 72 | max_angle = 10 73 | angle = random.random() * 2 * max_angle - max_angle 74 | for i in range(len(imgs)): 75 | img = imgs[i] 76 | w, h = img.shape[:2] 77 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 78 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 79 | imgs[i] = img_rotation 80 | return imgs 81 | 82 | 83 | def scale_aligned(img, h_scale, w_scale): 84 | h, w = img.shape[0:2] 85 | h = int(h * h_scale + 0.5) 86 | w = int(w * w_scale + 0.5) 87 | if h % 32 != 0: 88 | h = h + (32 - h % 32) 89 | if w % 32 != 0: 90 | w = w + (32 - w % 32) 91 | img = cv2.resize(img, dsize=(w, h)) 92 | return img 93 | 94 | 95 | def scale_aligned_short(img, short_size=736): 96 | h, w = img.shape[0:2] 97 | scale = short_size * 1.0 / min(h, w) 98 | h = int(h * scale + 0.5) 99 | w = int(w * scale + 0.5) 100 | if h % 32 != 0: 101 | h = h + (32 - h % 32) 102 | if w % 32 != 0: 103 | w = w + (32 - w % 32) 104 | img = cv2.resize(img, dsize=(w, h)) 105 | return img 106 | 107 | 108 | def random_scale(img, short_size=736): 109 | h, w = img.shape[0:2] 110 | 111 | scale = np.random.choice(np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])) 112 | scale = (scale * short_size) / min(h, w) 113 | 114 | aspect = np.random.choice(np.array([0.9, 0.95, 1.0, 1.05, 1.1])) 115 | h_scale = scale * math.sqrt(aspect) 116 | w_scale = scale / math.sqrt(aspect) 117 | 118 | img = scale_aligned(img, h_scale, w_scale) 119 | return img 120 | 121 | 122 | def random_crop_padding(imgs, target_size): 123 | h, w = imgs[0].shape[0:2] 124 | t_w, t_h = target_size 125 | p_w, p_h = target_size 126 | if w == t_w and h == t_h: 127 | return imgs 128 | 129 | t_h = t_h if t_h < h else h 130 | t_w = t_w if t_w < w 
else w 131 | 132 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 133 | # make sure to crop the text region 134 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 135 | tl[tl < 0] = 0 136 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 137 | br[br < 0] = 0 138 | br[0] = min(br[0], h - t_h) 139 | br[1] = min(br[1], w - t_w) 140 | 141 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 142 | j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 143 | else: 144 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 145 | j = random.randint(0, w - t_w) if w - t_w > 0 else 0 146 | 147 | n_imgs = [] 148 | for idx in range(len(imgs)): 149 | if len(imgs[idx].shape) == 3: 150 | s3_length = int(imgs[idx].shape[-1]) 151 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 152 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 153 | value=tuple(0 for i in range(s3_length))) 154 | else: 155 | img = imgs[idx][i:i + t_h, j:j + t_w] 156 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 157 | n_imgs.append(img_p) 158 | return n_imgs 159 | 160 | 161 | def dist(a, b): 162 | return np.linalg.norm((a - b), ord=2, axis=0) 163 | 164 | 165 | def perimeter(bbox): 166 | peri = 0.0 167 | for i in range(bbox.shape[0]): 168 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 169 | return peri 170 | 171 | 172 | def shrink(bboxes, rate, max_shr=20): 173 | rate = rate * rate 174 | shrinked_bboxes = [] 175 | for bbox in bboxes: 176 | x = bbox[:,0] 177 | y = bbox[:,1] 178 | area = PolyArea(x,y) 179 | #area = plg.Polygon(bbox).area() # replaced 180 | peri = perimeter(bbox) 181 | 182 | try: 183 | pco = pyclipper.PyclipperOffset() 184 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 185 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 186 | 187 | shrinked_bbox = pco.Execute(-offset) 188 | if len(shrinked_bbox) == 0: 189 | shrinked_bboxes.append(bbox) 190 | continue 191 | 192 | shrinked_bbox = np.array(shrinked_bbox)[0] 193 | if shrinked_bbox.shape[0] <= 2: 194 | shrinked_bboxes.append(bbox) 195 | continue 196 | 197 | shrinked_bboxes.append(shrinked_bbox) 198 | except Exception as e: 199 | print('area:', area, 'peri:', peri) 200 | shrinked_bboxes.append(bbox) 201 | 202 | return shrinked_bboxes 203 | 204 | 205 | class PAN_MSRA(data.Dataset): 206 | def __init__(self, 207 | split='train', 208 | is_transform=False, 209 | img_size=None, 210 | short_size=736, 211 | kernel_scale=0.5, 212 | report_speed=False): 213 | self.split = split 214 | self.is_transform = is_transform 215 | 216 | self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size) 217 | self.kernel_scale = kernel_scale 218 | self.short_size = short_size 219 | 220 | if split == 'train': 221 | data_dirs = [msra_train_data_dir] 222 | gt_dirs = [msra_train_gt_dir] 223 | elif split == 'test': 224 | data_dirs = [msra_test_data_dir] 225 | gt_dirs = [msra_test_gt_dir] 226 | else: 227 | print('Error: split must be train or test!') 228 | raise 229 | 230 | self.img_paths = [] 231 | self.gt_paths = [] 232 | 233 | for data_dir, gt_dir in zip(data_dirs, gt_dirs): 234 | #img_names = [img_name for img_name in mmcv.utils.scandir(data_dir) if img_name.endswith('.JPG')] 235 | #img_names.extend([img_name for img_name in mmcv.utils.scandir(data_dir) if img_name.endswith('.jpg')]) 236 | img_names = [img_name for img_name in os.listdir(data_dir) if img_name.endswith('.jpg') or 
img_name.endswith('.JPG')] 237 | 238 | img_paths = [] 239 | gt_paths = [] 240 | for idx, img_name in enumerate(img_names): 241 | img_path = data_dir + img_name 242 | img_paths.append(img_path) 243 | 244 | gt_name = img_name.split('.')[0] + '.gt' 245 | gt_path = gt_dir + gt_name 246 | gt_paths.append(gt_path) 247 | 248 | self.img_paths.extend(img_paths) 249 | self.gt_paths.extend(gt_paths) 250 | 251 | ''' 252 | if report_speed: 253 | target_size = 3000 254 | data_size = len(self.img_paths) 255 | extend_scale = (target_size + data_size - 1) // data_size 256 | self.img_paths = (self.img_paths * extend_scale)[:target_size] 257 | self.gt_paths = (self.gt_paths * extend_scale)[:target_size] 258 | ''' 259 | 260 | self.max_word_num = 200 261 | 262 | def __len__(self): 263 | return len(self.img_paths) 264 | 265 | def prepare_train_data(self, index): 266 | img_path = self.img_paths[index] 267 | gt_path = self.gt_paths[index] 268 | 269 | img = get_img(img_path) 270 | bboxes, words = get_ann(img, gt_path) 271 | 272 | if bboxes.shape[0] > self.max_word_num: 273 | bboxes = bboxes[:self.max_word_num] 274 | 275 | if self.is_transform: 276 | img = random_scale(img, self.short_size) 277 | 278 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 279 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 280 | if bboxes.shape[0] > 0: 281 | bboxes = np.reshape(bboxes * ([img.shape[1], img.shape[0]] * 4), 282 | (bboxes.shape[0], -1, 2)).astype('int32') 283 | for i in range(bboxes.shape[0]): 284 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 285 | if words[i] == '###': 286 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 287 | 288 | gt_kernels = [] 289 | for rate in [self.kernel_scale]: 290 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 291 | kernel_bboxes = shrink(bboxes, rate) 292 | for i in range(bboxes.shape[0]): 293 | cv2.drawContours(gt_kernel, [kernel_bboxes[i]], -1, 1, -1) 294 | gt_kernels.append(gt_kernel) 295 | 296 | if self.is_transform: 297 | imgs = [img, gt_instance, training_mask] 298 | imgs.extend(gt_kernels) 299 | 300 | imgs = random_horizontal_flip(imgs) 301 | imgs = random_rotate(imgs) 302 | imgs = random_crop_padding(imgs, self.img_size) 303 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 304 | 305 | gt_text = gt_instance.copy() 306 | gt_text[gt_text > 0] = 1 307 | gt_kernels = np.array(gt_kernels) 308 | 309 | max_instance = np.max(gt_instance) 310 | gt_bboxes = np.zeros((self.max_word_num + 1, 4), dtype=np.int32) 311 | for i in range(1, max_instance + 1): 312 | ind = gt_instance == i 313 | if np.sum(ind) == 0: 314 | continue 315 | points = np.array(np.where(ind)).transpose((1, 0)) 316 | tl = np.min(points, axis=0) 317 | br = np.max(points, axis=0) + 1 318 | gt_bboxes[i] = (tl[0], tl[1], br[0], br[1]) 319 | 320 | if self.is_transform: 321 | img = Image.fromarray(img) 322 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) 323 | 324 | img = transforms.ToTensor()(img) 325 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 326 | 327 | gt_text = torch.from_numpy(gt_text).long() 328 | gt_kernels = torch.from_numpy(gt_kernels).long() 329 | training_mask = torch.from_numpy(training_mask).long() 330 | gt_instance = torch.from_numpy(gt_instance).long() 331 | gt_bboxes = torch.from_numpy(gt_bboxes).long() 332 | 333 | data = dict( 334 | imgs=img, 335 | gt_texts=gt_text, 336 | gt_kernels=gt_kernels, 337 | training_masks=training_mask, 338 | gt_instances=gt_instance, 339 | 
gt_bboxes=gt_bboxes, 340 | ) 341 | 342 | return data 343 | 344 | def prepare_test_data(self, index): 345 | img_path = self.img_paths[index] 346 | 347 | img = get_img(img_path) 348 | img_meta = dict( 349 | org_img_size=np.array(img.shape[:2]) 350 | ) 351 | 352 | img = scale_aligned_short(img, self.short_size) 353 | img_meta.update(dict( 354 | img_size=np.array(img.shape[:2]) 355 | )) 356 | 357 | img = transforms.ToTensor()(img) 358 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 359 | 360 | data = dict( 361 | imgs=img, 362 | img_metas=img_meta 363 | ) 364 | 365 | return data 366 | 367 | def __getitem__(self, index): 368 | if self.split == 'train': 369 | return self.prepare_train_data(index) 370 | elif self.split == 'test': 371 | return self.prepare_test_data(index) 372 | 373 | # unit testing 374 | if __name__ == '__main__': 375 | 376 | train_dataset = PAN_MSRA(split='train', 377 | is_transform=False, 378 | img_size=None, 379 | short_size=736, 380 | kernel_scale=0.5, 381 | report_speed=False) 382 | 383 | for i, data in enumerate(train_dataset): 384 | # convert to numpy and plot 385 | print("Process image index:", i) 386 | imgs = data['imgs'] 387 | gt_texts = data['gt_texts'] 388 | gt_kernels = data['gt_kernels'] 389 | training_masks = data['training_masks'] 390 | gt_instances = data['gt_instances'] 391 | gt_bboxes = data['gt_bboxes'] 392 | imgs = imgs.permute(1,2,0).detach().cpu().numpy() 393 | gt_texts = gt_texts.detach().cpu().numpy() 394 | gt_kernels = gt_kernels.detach().cpu().numpy()[0] 395 | training_masks = training_masks.detach().cpu().numpy() 396 | gt_instances = gt_instances.detach().cpu().numpy() 397 | plt.figure(1) 398 | plt.imshow(imgs) 399 | plt.figure(2) 400 | plt.imshow(gt_texts) 401 | plt.title('gt_texts') 402 | plt.figure(3) 403 | plt.imshow(gt_kernels) 404 | plt.title('gt_kernels') 405 | plt.figure(4) 406 | plt.imshow(training_masks) 407 | plt.title('training_masks') 408 | plt.figure(5) 409 | plt.imshow(gt_instances) 410 | plt.title('gt_instances') 411 | plt.show() 412 | pdb.set_trace() 413 | 414 | test_dataset = PAN_MSRA(split='test', 415 | is_transform=False, 416 | img_size=None, 417 | short_size=736, 418 | kernel_scale=0.5, 419 | report_speed=False) 420 | 421 | for i, data in enumerate(test_dataset): 422 | # convert to numpy and plot 423 | print("Process image index:", i) 424 | imgs = data['imgs'] 425 | img_metas = data['img_metas'] 426 | print(img_metas) 427 | imgs = imgs.permute(1,2,0).detach().cpu().numpy() 428 | ''' 429 | plt.imshow(imgs) 430 | plt.show() 431 | pdb.set_trace() 432 | ''' -------------------------------------------------------------------------------- /dataset/synthtext.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to build data loader for synthtext dataset. 
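SynthText ships a single gt.mat whose 'imnames', 'wordBB' and 'txt' cell arrays
are loaded once in __init__ below. A minimal usage sketch (assuming gt.mat sits
at ./data/SynthText/gt.mat):

    from dataset.synthtext import PAN_Synth
    train_set = PAN_Synth(is_transform=True, img_size=736,
                          short_size=736, kernel_scale=0.5, with_rec=True)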
3 | ''' 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils import data 8 | import cv2 9 | import random 10 | import torchvision.transforms as transforms 11 | import torch 12 | import pyclipper 13 | import math 14 | import string 15 | import scipy.io as scio 16 | import matplotlib.pyplot as plt 17 | import pdb 18 | 19 | synth_root_dir = './data/SynthText/' 20 | synth_train_data_dir = synth_root_dir 21 | synth_train_gt_path = synth_root_dir + 'gt.mat' 22 | 23 | 24 | def PolyArea(x,y): 25 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) 26 | 27 | def get_img(img_path): 28 | try: 29 | img = cv2.imread(img_path) 30 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 31 | except Exception as e: 32 | print(img_path) 33 | raise 34 | return img 35 | 36 | 37 | def get_ann(img, gts, texts, index): 38 | bboxes = np.array(gts[index]) 39 | bboxes = np.reshape(bboxes, (bboxes.shape[0], bboxes.shape[1], -1)) 40 | bboxes = bboxes.transpose(2, 1, 0) 41 | bboxes = np.reshape(bboxes, (bboxes.shape[0], -1)) / ([img.shape[1], img.shape[0]] * 4) 42 | 43 | words = [] 44 | for text in texts[index]: 45 | text = text.replace('\n', ' ').replace('\r', ' ') 46 | words.extend([w for w in text.split(' ') if len(w) > 0]) 47 | 48 | return bboxes, words 49 | 50 | 51 | def random_horizontal_flip(imgs): 52 | if random.random() < 0.5: 53 | for i in range(len(imgs)): 54 | imgs[i] = np.flip(imgs[i], axis=1).copy() 55 | return imgs 56 | 57 | 58 | def random_rotate(imgs): 59 | max_angle = 10 60 | angle = random.random() * 2 * max_angle - max_angle 61 | for i in range(len(imgs)): 62 | img = imgs[i] 63 | w, h = img.shape[:2] 64 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 65 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 66 | imgs[i] = img_rotation 67 | return imgs 68 | 69 | 70 | def scale_aligned(img, h_scale, w_scale): 71 | h, w = img.shape[0:2] 72 | h = int(h * h_scale + 0.5) 73 | w = int(w * w_scale + 0.5) 74 | if h % 32 != 0: 75 | h = h + (32 - h % 32) 76 | if w % 32 != 0: 77 | w = w + (32 - w % 32) 78 | img = cv2.resize(img, dsize=(w, h)) 79 | return img 80 | 81 | 82 | def random_scale(img, short_size=736): 83 | h, w = img.shape[0:2] 84 | 85 | scale = np.random.choice(np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])) 86 | scale = (scale * short_size) / min(h, w) 87 | 88 | aspect = np.random.choice(np.array([0.9, 0.95, 1.0, 1.05, 1.1])) 89 | h_scale = scale * math.sqrt(aspect) 90 | w_scale = scale / math.sqrt(aspect) 91 | 92 | img = scale_aligned(img, h_scale, w_scale) 93 | return img 94 | 95 | 96 | def random_crop_padding(imgs, target_size): 97 | h, w = imgs[0].shape[0:2] 98 | t_w, t_h = target_size 99 | p_w, p_h = target_size 100 | if w == t_w and h == t_h: 101 | return imgs 102 | 103 | t_h = t_h if t_h < h else h 104 | t_w = t_w if t_w < w else w 105 | 106 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 107 | # make sure to crop the text region 108 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 109 | tl[tl < 0] = 0 110 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 111 | br[br < 0] = 0 112 | br[0] = min(br[0], h - t_h) 113 | br[1] = min(br[1], w - t_w) 114 | 115 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 116 | j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 117 | else: 118 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 119 | j = random.randint(0, w - t_w) if w - t_w > 0 else 0 120 | 121 | n_imgs = [] 122 | for idx in range(len(imgs)): 123 | if len(imgs[idx].shape) == 
3: 124 | s3_length = int(imgs[idx].shape[-1]) 125 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 126 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 127 | value=tuple(0 for i in range(s3_length))) 128 | else: 129 | img = imgs[idx][i:i + t_h, j:j + t_w] 130 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 131 | n_imgs.append(img_p) 132 | return n_imgs 133 | 134 | 135 | def update_word_mask(instance, instance_before_crop, word_mask): 136 | labels = np.unique(instance) 137 | 138 | for label in labels: 139 | if label == 0: 140 | continue 141 | ind = instance == label 142 | if np.sum(ind) == 0: 143 | word_mask[label] = 0 144 | continue 145 | ind_before_crop = instance_before_crop == label 146 | # print(np.sum(ind), np.sum(ind_before_crop)) 147 | if float(np.sum(ind)) / np.sum(ind_before_crop) > 0.9: 148 | continue 149 | word_mask[label] = 0 150 | 151 | return word_mask 152 | 153 | 154 | def dist(a, b): 155 | return np.linalg.norm((a - b), ord=2, axis=0) 156 | 157 | 158 | def perimeter(bbox): 159 | peri = 0.0 160 | for i in range(bbox.shape[0]): 161 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 162 | return peri 163 | 164 | 165 | def shrink(bboxes, rate, max_shr=20): 166 | rate = rate * rate 167 | shrinked_bboxes = [] 168 | for bbox in bboxes:# Replace ply.Polygon with simple area calculation function 169 | #area = plg.Polygon(bbox).area() 170 | x = bbox[:,0] 171 | y = bbox[:,1] 172 | area = PolyArea(x,y) 173 | peri = perimeter(bbox) 174 | 175 | try: 176 | pco = pyclipper.PyclipperOffset() 177 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 178 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 179 | 180 | shrinked_bbox = pco.Execute(-offset) 181 | if len(shrinked_bbox) == 0: 182 | shrinked_bboxes.append(bbox) 183 | continue 184 | 185 | shrinked_bbox = np.array(shrinked_bbox[0]) 186 | if shrinked_bbox.shape[0] <= 2: 187 | shrinked_bboxes.append(bbox) 188 | continue 189 | 190 | shrinked_bboxes.append(shrinked_bbox) 191 | except Exception as e: 192 | print('area:', area, 'peri:', peri) 193 | shrinked_bboxes.append(bbox) 194 | 195 | return shrinked_bboxes 196 | 197 | 198 | def get_vocabulary(voc_type, EOS='EOS', PADDING='PAD', UNKNOWN='UNK'): 199 | if voc_type == 'LOWERCASE': 200 | voc = list(string.digits + string.ascii_lowercase) 201 | elif voc_type == 'ALLCASES': 202 | voc = list(string.digits + string.ascii_letters) 203 | elif voc_type == 'ALLCASES_SYMBOLS': 204 | voc = list(string.printable[:-5]) 205 | else: 206 | raise KeyError('voc_type must be one of "LOWERCASE", "ALLCASES", "ALLCASES_SYMBOLS"') 207 | 208 | # update the voc with specifical chars 209 | voc.append(EOS) 210 | voc.append(PADDING) 211 | voc.append(UNKNOWN) 212 | 213 | char2id = dict(zip(voc, range(len(voc)))) 214 | id2char = dict(zip(range(len(voc)), voc)) 215 | 216 | return voc, char2id, id2char 217 | 218 | 219 | class PAN_Synth(data.Dataset): 220 | def __init__(self, 221 | is_transform=False, 222 | img_size=None, 223 | short_size=736, 224 | kernel_scale=0.5, 225 | with_rec=False): 226 | self.is_transform = is_transform 227 | 228 | self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size) 229 | self.kernel_scale = kernel_scale 230 | self.short_size = short_size 231 | self.with_rec = with_rec 232 | 233 | data = scio.loadmat(synth_train_gt_path) 234 | 235 | self.img_paths = data['imnames'][0] 236 | self.gts = data['wordBB'][0] 237 | 
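# each wordBB entry is a 2 x 4 x N array (x/y rows, 4 corners, N words; 2 x 4
# for a single word); get_ann() reshapes and transposes it to (N, 8) and
# normalizes the coordinates by image width/height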
self.texts = data['txt'][0] 238 | 239 | self.voc, self.char2id, self.id2char = get_vocabulary('LOWERCASE') 240 | self.max_word_num = 200 241 | self.max_word_len = 32 242 | 243 | def __len__(self): 244 | return len(self.img_paths) 245 | 246 | def __getitem__(self, index): 247 | img_path = synth_train_data_dir + self.img_paths[index][0] 248 | img = get_img(img_path) 249 | bboxes, words = get_ann(img, self.gts, self.texts, index) 250 | 251 | if bboxes.shape[0] > self.max_word_num: 252 | bboxes = bboxes[:self.max_word_num] 253 | words = words[:self.max_word_num] 254 | 255 | gt_words = np.full((self.max_word_num + 1, self.max_word_len), self.char2id['PAD'], dtype=np.int32) 256 | word_mask = np.zeros((self.max_word_num + 1,), dtype=np.int32) 257 | for i, word in enumerate(words): 258 | if word == '###': 259 | continue 260 | word = word.lower() 261 | gt_word = np.full((self.max_word_len,), self.char2id['PAD'], dtype=np.int32) 262 | for j, char in enumerate(word): 263 | if j > self.max_word_len - 1: 264 | break 265 | if char in self.char2id: 266 | gt_word[j] = self.char2id[char] 267 | else: 268 | gt_word[j] = self.char2id['UNK'] 269 | if len(word) > self.max_word_len - 1: 270 | gt_word[-1] = self.char2id['EOS'] 271 | else: 272 | gt_word[len(word)] = self.char2id['EOS'] 273 | gt_words[i + 1] = gt_word 274 | word_mask[i + 1] = 1 275 | 276 | if self.is_transform: 277 | img = random_scale(img, self.short_size) 278 | 279 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 280 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 281 | if bboxes.shape[0] > 0: 282 | bboxes = np.reshape(bboxes * ([img.shape[1], img.shape[0]] * 4), 283 | (bboxes.shape[0], -1, 2)).astype('int32') 284 | for i in range(bboxes.shape[0]): 285 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 286 | if words[i] == '###': 287 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 288 | 289 | gt_kernels = [] 290 | for rate in [self.kernel_scale]: 291 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 292 | kernel_bboxes = shrink(bboxes, rate) 293 | for i in range(bboxes.shape[0]): 294 | cv2.drawContours(gt_kernel, [kernel_bboxes[i]], -1, 1, -1) 295 | gt_kernels.append(gt_kernel) 296 | 297 | if self.is_transform: 298 | imgs = [img, gt_instance, training_mask] 299 | imgs.extend(gt_kernels) 300 | 301 | #if not self.with_rec: 302 | imgs = random_horizontal_flip(imgs) 303 | imgs = random_rotate(imgs) 304 | gt_instance_before_crop = imgs[1].copy() 305 | imgs = random_crop_padding(imgs, self.img_size) 306 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 307 | word_mask = update_word_mask(gt_instance, gt_instance_before_crop, word_mask) 308 | 309 | gt_text = gt_instance.copy() 310 | gt_text[gt_text > 0] = 1 311 | gt_kernels = np.array(gt_kernels) 312 | 313 | max_instance = np.max(gt_instance) 314 | gt_bboxes = np.zeros((self.max_word_num + 1, 4), dtype=np.int32) 315 | for i in range(1, max_instance + 1): 316 | ind = gt_instance == i 317 | if np.sum(ind) == 0: 318 | continue 319 | points = np.array(np.where(ind)).transpose((1, 0)) 320 | tl = np.min(points, axis=0) 321 | br = np.max(points, axis=0) + 1 322 | gt_bboxes[i] = (tl[0], tl[1], br[0], br[1]) 323 | 324 | if self.is_transform: 325 | img = Image.fromarray(img) 326 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) 327 | 328 | img = transforms.ToTensor()(img) 329 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 330 | 331 | gt_text = torch.from_numpy(gt_text).long() 332 | 
gt_kernels = torch.from_numpy(gt_kernels).long() 333 | training_mask = torch.from_numpy(training_mask).long() 334 | gt_instance = torch.from_numpy(gt_instance).long() 335 | gt_bboxes = torch.from_numpy(gt_bboxes).long() 336 | gt_words = torch.from_numpy(gt_words).long() 337 | word_mask = torch.from_numpy(word_mask).long() 338 | 339 | data = dict( 340 | imgs=img, 341 | gt_texts=gt_text, 342 | gt_kernels=gt_kernels, 343 | training_masks=training_mask, 344 | gt_instances=gt_instance, 345 | gt_bboxes=gt_bboxes, 346 | ) 347 | if self.with_rec: 348 | data.update(dict( 349 | gt_words=gt_words, 350 | word_masks=word_mask 351 | )) 352 | 353 | return data 354 | 355 | # unit testing 356 | if __name__ == '__main__': 357 | 358 | train_dataset = PAN_Synth(is_transform=False, 359 | img_size=None, 360 | short_size=640, 361 | kernel_scale=0.5, 362 | with_rec=True) 363 | 364 | for i, data in enumerate(train_dataset): 365 | # convert to numpy and plot 366 | print("Process image index:", i) 367 | imgs = data['imgs'] 368 | gt_texts = data['gt_texts'] 369 | gt_kernels = data['gt_kernels'] 370 | training_masks = data['training_masks'] 371 | gt_instances = data['gt_instances'] 372 | gt_bboxes = data['gt_bboxes'] 373 | gt_words = data['gt_words'] 374 | word_masks = data['word_masks'] 375 | imgs = imgs.permute(1,2,0).detach().cpu().numpy() 376 | gt_texts = gt_texts.detach().cpu().numpy() 377 | gt_kernels = gt_kernels.detach().cpu().numpy()[0] 378 | training_masks = training_masks.detach().cpu().numpy() 379 | gt_instances = gt_instances.detach().cpu().numpy() 380 | gt_words = gt_words.detach().cpu().numpy() 381 | word_masks = word_masks.detach().cpu().numpy() 382 | print(gt_words) 383 | print(word_masks) 384 | plt.figure(1) 385 | plt.imshow(imgs) 386 | plt.figure(2) 387 | plt.imshow(gt_texts) 388 | plt.title('gt_texts') 389 | plt.figure(3) 390 | plt.imshow(gt_kernels) 391 | plt.title('gt_kernels') 392 | plt.figure(4) 393 | plt.imshow(training_masks) 394 | plt.title('training_masks') 395 | plt.figure(5) 396 | plt.imshow(gt_instances) 397 | plt.title('gt_instances') 398 | plt.show() 399 | pdb.set_trace() -------------------------------------------------------------------------------- /dataset/testdataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to build data loader for arbitrary test dataset. 
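Each test image is resized so its short side matches `short_size`, and both
sides are then rounded up to multiples of 32 (see scale_aligned_short below)
so the downsampled feature maps divide evenly. For example, a 720x1280 image
with short_size=640 is rescaled to 640x1152: 640 is already a multiple of 32,
while the scaled width 1138 is rounded up to 1152.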
3 | ''' 4 | 5 | import numpy as np 6 | from torch.utils import data 7 | import torchvision.transforms as transforms 8 | import cv2 9 | import torch 10 | import os 11 | 12 | def get_img(img_path): 13 | try: 14 | img = cv2.imread(img_path) 15 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 16 | except Exception as e: 17 | print(img_path) 18 | raise 19 | return img 20 | 21 | def scale_aligned_short(img, short_size=640): 22 | h, w = img.shape[0:2] 23 | scale = short_size * 1.0 / min(h, w) 24 | h = int(h * scale + 0.5) 25 | w = int(w * scale + 0.5) 26 | if h % 32 != 0: 27 | h = h + (32 - h % 32) 28 | if w % 32 != 0: 29 | w = w + (32 - w % 32) 30 | img = cv2.resize(img, dsize=(w, h)) 31 | return img 32 | 33 | class PAN_test(data.Dataset): 34 | def __init__(self, 35 | data_dirs=None, 36 | short_size=640): 37 | 38 | self.short_size = short_size 39 | 40 | self.img_paths = [] 41 | 42 | for data_dir in [data_dirs]: 43 | img_names = os.listdir(data_dir) 44 | 45 | img_paths = [] 46 | for idx, img_name in enumerate(img_names): 47 | img_path = os.path.join(data_dir, img_name) 48 | img_paths.append(img_path) 49 | 50 | self.img_paths.extend(img_paths) 51 | 52 | def __len__(self): 53 | return len(self.img_paths) 54 | 55 | def prepare_test_data(self, index): 56 | img_path = self.img_paths[index] 57 | 58 | img = get_img(img_path) 59 | img_meta = dict( 60 | org_img_size=np.array(img.shape[:2]) 61 | ) 62 | 63 | img = scale_aligned_short(img, self.short_size) 64 | img_meta.update(dict( 65 | img_size=np.array(img.shape[:2]) 66 | )) 67 | 68 | img = transforms.ToTensor()(img) 69 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 70 | 71 | data = dict( 72 | imgs=img, 73 | img_metas=img_meta 74 | ) 75 | 76 | return data 77 | 78 | def __getitem__(self, index): 79 | return self.prepare_test_data(index) -------------------------------------------------------------------------------- /eval/ctw/eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is the evaluation code modified from the originally released code. 3 | 1) Remove dependency on Polygon library 4 | 2) Fix input prediction format to x0,y0,x1,y1,... 
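Polygon overlap is computed here by rasterization: each polygon is filled
into a binary mask of the full image with cv2.fillPoly, and intersection and
union are taken as pixel counts (see get_intersection/get_union below), e.g.
    iou = get_intersection(pred, gt, H, W) / get_union(pred, gt, H, W)
This trades a small quantization error at polygon boundaries for dropping
the Polygon dependency noted in 1).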
5 | '''
6 | 
7 | import file_util
8 | import numpy as np
9 | import cv2
10 | 
11 | project_root = '../../'
12 | 
13 | pred_root = project_root + 'results/submit_ctw'
14 | gt_root = project_root + 'data/CTW1500/test/text_label_circum/'
15 | img_root = project_root + 'data/CTW1500/test/text_image/'
16 | 
17 | def get_pred(path):
18 |     lines = file_util.read_file(path).split('\n')
19 |     bboxes = []
20 |     for line in lines:
21 |         if line == '':
22 |             continue
23 |         bbox = line.split(',')
24 |         if len(bbox) % 2 == 1:
25 |             print(path)
26 |         bbox = [int(x) for x in bbox]
27 |         bboxes.append(bbox)
28 |     return bboxes
29 | 
30 | 
31 | def get_gt(path):
32 |     lines = file_util.read_file(path).split('\n')
33 |     bboxes = []
34 |     for line in lines:
35 |         if line == '':
36 |             continue
37 |         # line = util.str.remove_all(line, '\xef\xbb\xbf')
38 |         # gt = util.str.split(line, ',')
39 |         gt = line.split(',')
40 | 
41 |         x1 = int(gt[0])
42 |         y1 = int(gt[1])
43 | 
44 |         bbox = [int(gt[i]) for i in range(4, 32)]
45 |         bbox = np.asarray(bbox) + ([x1, y1] * 14)
46 | 
47 |         bboxes.append(bbox)
48 |     return bboxes
49 | 
50 | 
51 | def get_union(pD, pG, H, W):
52 |     # replace original polygon library by opencv function
53 |     blank = np.zeros((H, W))
54 |     image1 = cv2.fillPoly(blank.copy(), [pD], 1)
55 |     image2 = cv2.fillPoly(blank.copy(), [pG], 1)
56 |     areaA = np.sum(image1)
57 |     areaB = np.sum(image2)
58 | 
59 |     return areaA + areaB - get_intersection(pD, pG, H, W)
60 | 
61 | def get_intersection(pD, pG, H, W):
62 |     # replace original polygon library by opencv function
63 |     blank = np.zeros((H, W))
64 |     image1 = cv2.fillPoly(blank.copy(), [pD], 1)
65 |     image2 = cv2.fillPoly(blank.copy(), [pG], 1)
66 |     intersection = np.logical_and(image1, image2)
67 | 
68 |     return np.sum(intersection)
69 | 
70 | if __name__ == '__main__':
71 |     th = 0.5
72 |     pred_list = file_util.read_dir(pred_root)
73 | 
74 |     tp, fp, npos = 0, 0, 0
75 | 
76 |     for pred_path in pred_list:
77 |         print("evaluating prediction path:", pred_path)
78 |         preds = get_pred(pred_path)
79 |         gt_path = gt_root + pred_path.split('/')[-1]
80 |         img = cv2.imread(img_root + pred_path.split('/')[-1][:-4] + '.jpg')
81 |         H, W, _ = img.shape
82 |         gts = get_gt(gt_path)
83 |         npos += len(gts)
84 | 
85 |         cover = set()
86 |         for pred_id, pred in enumerate(preds):
87 |             pred = np.array(pred)
88 |             pred = pred.reshape(pred.shape[0] // 2, 2)
89 | 
90 |             #pred_p = plg.Polygon(pred)
91 | 
92 |             flag = False
93 |             for gt_id, gt in enumerate(gts):
94 |                 gt = np.array(gt)
95 |                 gt = gt.reshape(gt.shape[0] // 2, 2)
96 |                 #gt_p = plg.Polygon(gt)
97 | 
98 |                 union = get_union(pred, gt, H, W)
99 |                 inter = get_intersection(pred, gt, H, W)
100 | 
101 |                 if inter * 1.0 / union >= th:
102 |                     if gt_id not in cover:
103 |                         flag = True
104 |                         cover.add(gt_id)
105 |             if flag:
106 |                 tp += 1.0
107 |             else:
108 |                 fp += 1.0
109 | 
110 |     # print tp, fp, npos
111 |     precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
112 |     recall = tp / npos if npos > 0 else 0.0
113 |     hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
114 | 
115 |     print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean))
116 | 
--------------------------------------------------------------------------------
/eval/ctw/file_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | def read_dir(root):
4 |     file_path_list = []
5 |     for file_path, dirs, files in os.walk(root):
6 |         for file in files:
7 |             file_path_list.append(os.path.join(file_path, file).replace('\\', '/'))
8 |     file_path_list.sort()
9 |     return file_path_list
10 | 
11 | def
read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 | def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /eval/ic15/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/eval/ic15/gt.zip -------------------------------------------------------------------------------- /eval/ic15/rrc_evaluation_funcs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | #encoding: UTF-8 3 | import json 4 | import sys;sys.path.append('./') 5 | import zipfile 6 | import re 7 | import sys 8 | import os 9 | import codecs 10 | import importlib 11 | from StringIO import StringIO 12 | 13 | def print_help(): 14 | sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) 15 | sys.exit(2) 16 | 17 | 18 | def load_zip_file_keys(file,fileNameRegExp=''): 19 | """ 20 | Returns an array with the entries of the ZIP file that match with the regular expression. 21 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 22 | """ 23 | try: 24 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 25 | except : 26 | raise Exception('Error loading the ZIP archive.') 27 | 28 | pairs = [] 29 | 30 | for name in archive.namelist(): 31 | addFile = True 32 | keyName = name 33 | if fileNameRegExp!="": 34 | m = re.match(fileNameRegExp,name) 35 | if m == None: 36 | addFile = False 37 | else: 38 | if len(m.groups())>0: 39 | keyName = m.group(1) 40 | 41 | if addFile: 42 | pairs.append( keyName ) 43 | 44 | return pairs 45 | 46 | 47 | def load_zip_file(file,fileNameRegExp='',allEntries=False): 48 | """ 49 | Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 
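    For example, load_zip_file('gt.zip', 'gt_img_([0-9]+).txt') returns a
    dict such as {'1': <raw contents of gt_img_1.txt>, ...} -- an
    illustrative call using this repo's IC15 ground-truth naming.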
50 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 51 | allEntries validates that all entries in the ZIP file pass the fileNameRegExp 52 | """ 53 | try: 54 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 55 | except : 56 | raise Exception('Error loading the ZIP archive') 57 | 58 | pairs = [] 59 | for name in archive.namelist(): 60 | addFile = True 61 | keyName = name 62 | if fileNameRegExp!="": 63 | m = re.match(fileNameRegExp,name) 64 | if m == None: 65 | addFile = False 66 | else: 67 | if len(m.groups())>0: 68 | keyName = m.group(1) 69 | 70 | if addFile: 71 | pairs.append( [ keyName , archive.read(name)] ) 72 | else: 73 | if allEntries: 74 | raise Exception('ZIP entry not valid: %s' %name) 75 | 76 | return dict(pairs) 77 | 78 | def decode_utf8(raw): 79 | """ 80 | Returns a Unicode object on success, or None on failure 81 | """ 82 | try: 83 | raw = codecs.decode(raw,'utf-8', 'replace') 84 | #extracts BOM if exists 85 | raw = raw.encode('utf8') 86 | if raw.startswith(codecs.BOM_UTF8): 87 | raw = raw.replace(codecs.BOM_UTF8, '', 1) 88 | return raw.decode('utf-8') 89 | except: 90 | return None 91 | 92 | def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 93 | """ 94 | This function validates that all lines of the file calling the Line validation function for each line 95 | """ 96 | utf8File = decode_utf8(file_contents) 97 | if (utf8File is None) : 98 | raise Exception("The file %s is not UTF-8" %fileName) 99 | 100 | lines = utf8File.split( "\r\n" if CRLF else "\n" ) 101 | for line in lines: 102 | line = line.replace("\r","").replace("\n","") 103 | if(line != ""): 104 | try: 105 | validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 106 | except Exception as e: 107 | raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) 108 | 109 | 110 | 111 | def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): 112 | """ 113 | Validate the format of the line. If the line is not valid an exception will be raised. 114 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 115 | Posible values are: 116 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 117 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 118 | """ 119 | get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 120 | 121 | 122 | def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 123 | """ 124 | Validate the format of the line. If the line is not valid an exception will be raised. 125 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 126 | Posible values are: 127 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 128 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 129 | Returns values from a textline. 
Points , [Confidences], [Transcriptions] 130 | """ 131 | confidence = 0.0 132 | transcription = ""; 133 | points = [] 134 | 135 | numPoints = 4; 136 | 137 | if LTRB: 138 | 139 | numPoints = 4; 140 | 141 | if withTranscription and withConfidence: 142 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 143 | if m == None : 144 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 145 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") 146 | elif withConfidence: 147 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 148 | if m == None : 149 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") 150 | elif withTranscription: 151 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) 152 | if m == None : 153 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") 154 | else: 155 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) 156 | if m == None : 157 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") 158 | 159 | xmin = int(m.group(1)) 160 | ymin = int(m.group(2)) 161 | xmax = int(m.group(3)) 162 | ymax = int(m.group(4)) 163 | if(xmax0 and imHeight>0): 171 | validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); 172 | validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); 173 | 174 | else: 175 | 176 | numPoints = 8; 177 | 178 | # if withTranscription and withConfidence: 179 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 180 | # if m == None : 181 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") 182 | # elif withConfidence: 183 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 184 | # if m == None : 185 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") 186 | # elif withTranscription: 187 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) 188 | # if m == None : 189 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") 190 | # else: 191 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) 192 | # if m == None : 193 | # raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") 194 | 195 | # points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 196 | # print line 197 | nums = line.split(',')[:8] 198 | points = [(float)(nums[i]) for i in range(8)] 199 | 200 | # validate_clockwise_points(points) 201 | 202 | if (imWidth>0 and imHeight>0): 203 | validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); 204 | validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); 205 | validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); 206 | validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); 207 | 208 | 209 | # if withConfidence: 210 | # try: 211 | # confidence = float(m.group(numPoints+1)) 212 | # except ValueError: 213 | # raise Exception("Confidence value must be a float") 214 | 215 | if withTranscription: 216 | # posTranscription = numPoints + (2 if withConfidence else 1) 217 | # transcription = m.group(posTranscription) 218 | transcription = line.split(',')[-1] 219 | m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) 220 | if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters 221 | transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") 222 | 223 | return points,confidence,transcription 224 | 225 | 226 | def validate_point_inside_bounds(x,y,imWidth,imHeight): 227 | if(x<0 or x>imWidth): 228 | raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) 229 | if(y<0 or y>imHeight): 230 | raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) 231 | 232 | def validate_clockwise_points(points): 233 | """ 234 | Validates that the points that the 4 points that dlimite a polygon are in clockwise order. 235 | """ 236 | 237 | if len(points) != 8: 238 | raise Exception("Points list not valid." + str(len(points))) 239 | 240 | point = [ 241 | [int(points[0]) , int(points[1])], 242 | [int(points[2]) , int(points[3])], 243 | [int(points[4]) , int(points[5])], 244 | [int(points[6]) , int(points[7])] 245 | ] 246 | edge = [ 247 | ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), 248 | ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), 249 | ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), 250 | ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) 251 | ] 252 | 253 | summatory = edge[0] + edge[1] + edge[2] + edge[3]; 254 | if summatory>0: 255 | raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") 256 | 257 | def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): 258 | """ 259 | Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats: 260 | xmin,ymin,xmax,ymax,[confidence],[transcription] 261 | x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] 262 | """ 263 | pointsList = [] 264 | transcriptionsList = [] 265 | confidencesList = [] 266 | 267 | lines = content.split( "\r\n" if CRLF else "\n" ) 268 | for line in lines: 269 | line = line.replace("\r","").replace("\n","") 270 | if(line != "") : 271 | points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); 272 | pointsList.append(points) 273 | transcriptionsList.append(transcription) 274 | confidencesList.append(confidence) 275 | 276 | if withConfidence and len(confidencesList)>0 and sort_by_confidences: 277 | import numpy as np 278 | sorted_ind = np.argsort(-np.array(confidencesList)) 279 | confidencesList = [confidencesList[i] for i in sorted_ind] 280 | pointsList = [pointsList[i] for i in sorted_ind] 281 | transcriptionsList = [transcriptionsList[i] for i in sorted_ind] 282 | 283 | return pointsList,confidencesList,transcriptionsList 284 | 285 | def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): 286 | """ 287 | This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 288 | Params: 289 | p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 290 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 291 | validate_data_fn: points to a method that validates the corrct format of the submission 292 | evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results 293 | """ 294 | 295 | if (p == None): 296 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 297 | if(len(sys.argv)<3): 298 | print_help() 299 | 300 | evalParams = default_evaluation_params_fn() 301 | if 'p' in p.keys(): 302 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 303 | 304 | resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} 305 | try: 306 | validate_data_fn(p['g'], p['s'], evalParams) 307 | evalData = evaluate_method_fn(p['g'], p['s'], evalParams) 308 | resDict.update(evalData) 309 | 310 | except Exception, e: 311 | resDict['Message']= str(e) 312 | resDict['calculated']=False 313 | 314 | if 'o' in p: 315 | if not os.path.exists(p['o']): 316 | os.makedirs(p['o']) 317 | 318 | resultsOutputname = p['o'] + '/results.zip' 319 | outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) 320 | 321 | del resDict['per_sample'] 322 | if 'output_items' in resDict.keys(): 323 | del resDict['output_items'] 324 | 325 | outZip.writestr('method.json',json.dumps(resDict)) 326 | 327 | if not resDict['calculated']: 328 | if show_result: 329 | sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') 330 | if 'o' in p: 331 | outZip.close() 332 | return resDict 333 | 334 | if 'o' in p: 335 | if per_sample == True: 336 | for k,v in evalData['per_sample'].iteritems(): 337 | outZip.writestr( k + '.json',json.dumps(v)) 338 | 339 | if 'output_items' in evalData.keys(): 340 | for k, v in evalData['output_items'].iteritems(): 341 | outZip.writestr( k,v) 342 | 343 | outZip.close() 344 | 345 | if show_result: 346 | sys.stdout.write("Calculated!") 347 | sys.stdout.write(json.dumps(resDict['method'])) 348 | 349 | return resDict 350 | 351 | 352 | def 
main_validation(default_evaluation_params_fn,validate_data_fn): 353 | """ 354 | This process validates a method 355 | Params: 356 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 357 | validate_data_fn: points to a method that validates the corrct format of the submission 358 | """ 359 | try: 360 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 361 | evalParams = default_evaluation_params_fn() 362 | if 'p' in p.keys(): 363 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 364 | 365 | validate_data_fn(p['g'], p['s'], evalParams) 366 | print 'SUCCESS' 367 | sys.exit(0) 368 | except Exception as e: 369 | print str(e) 370 | sys.exit(101) -------------------------------------------------------------------------------- /eval/ic15/rrc_evaluation_funcs_v1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | #encoding: UTF-8 3 | import json 4 | import sys;sys.path.append('./') 5 | import zipfile 6 | import re 7 | import sys 8 | import os 9 | import codecs 10 | import importlib 11 | from StringIO import StringIO 12 | 13 | def print_help(): 14 | sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) 15 | sys.exit(2) 16 | 17 | 18 | def load_zip_file_keys(file,fileNameRegExp=''): 19 | """ 20 | Returns an array with the entries of the ZIP file that match with the regular expression. 21 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 22 | """ 23 | try: 24 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 25 | except : 26 | raise Exception('Error loading the ZIP archive.') 27 | 28 | pairs = [] 29 | 30 | for name in archive.namelist(): 31 | addFile = True 32 | keyName = name 33 | if fileNameRegExp!="": 34 | m = re.match(fileNameRegExp,name) 35 | if m == None: 36 | addFile = False 37 | else: 38 | if len(m.groups())>0: 39 | keyName = m.group(1) 40 | 41 | if addFile: 42 | pairs.append( keyName ) 43 | 44 | return pairs 45 | 46 | 47 | def load_zip_file(file,fileNameRegExp='',allEntries=False): 48 | """ 49 | Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 
50 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 51 | allEntries validates that all entries in the ZIP file pass the fileNameRegExp 52 | """ 53 | try: 54 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 55 | except : 56 | raise Exception('Error loading the ZIP archive') 57 | 58 | pairs = [] 59 | for name in archive.namelist(): 60 | addFile = True 61 | keyName = name 62 | if fileNameRegExp!="": 63 | m = re.match(fileNameRegExp,name) 64 | if m == None: 65 | addFile = False 66 | else: 67 | if len(m.groups())>0: 68 | keyName = m.group(1) 69 | 70 | if addFile: 71 | pairs.append( [ keyName , archive.read(name)] ) 72 | else: 73 | if allEntries: 74 | raise Exception('ZIP entry not valid: %s' %name) 75 | 76 | return dict(pairs) 77 | 78 | def decode_utf8(raw): 79 | """ 80 | Returns a Unicode object on success, or None on failure 81 | """ 82 | try: 83 | raw = codecs.decode(raw,'utf-8', 'replace') 84 | #extracts BOM if exists 85 | raw = raw.encode('utf8') 86 | if raw.startswith(codecs.BOM_UTF8): 87 | raw = raw.replace(codecs.BOM_UTF8, '', 1) 88 | return raw.decode('utf-8') 89 | except: 90 | return None 91 | 92 | def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 93 | """ 94 | This function validates that all lines of the file calling the Line validation function for each line 95 | """ 96 | utf8File = decode_utf8(file_contents) 97 | if (utf8File is None) : 98 | raise Exception("The file %s is not UTF-8" %fileName) 99 | 100 | lines = utf8File.split( "\r\n" if CRLF else "\n" ) 101 | for line in lines: 102 | line = line.replace("\r","").replace("\n","") 103 | if(line != ""): 104 | try: 105 | validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 106 | except Exception as e: 107 | raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) 108 | 109 | 110 | 111 | def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): 112 | """ 113 | Validate the format of the line. If the line is not valid an exception will be raised. 114 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 115 | Posible values are: 116 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 117 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 118 | """ 119 | get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 120 | 121 | 122 | def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 123 | """ 124 | Validate the format of the line. If the line is not valid an exception will be raised. 125 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 126 | Posible values are: 127 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 128 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 129 | Returns values from a textline. 
Points , [Confidences], [Transcriptions] 130 | """ 131 | confidence = 0.0 132 | transcription = ""; 133 | points = [] 134 | 135 | numPoints = 4; 136 | 137 | if LTRB: 138 | 139 | numPoints = 4; 140 | 141 | if withTranscription and withConfidence: 142 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 143 | if m == None : 144 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 145 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") 146 | elif withConfidence: 147 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 148 | if m == None : 149 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") 150 | elif withTranscription: 151 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) 152 | if m == None : 153 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") 154 | else: 155 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) 156 | if m == None : 157 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") 158 | 159 | xmin = int(m.group(1)) 160 | ymin = int(m.group(2)) 161 | xmax = int(m.group(3)) 162 | ymax = int(m.group(4)) 163 | if(xmax0 and imHeight>0): 171 | validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); 172 | validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); 173 | 174 | else: 175 | 176 | numPoints = 8; 177 | 178 | if withTranscription and withConfidence: 179 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 180 | if m == None : 181 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") 182 | elif withConfidence: 183 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 184 | if m == None : 185 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") 186 | elif withTranscription: 187 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) 188 | if m == None : 189 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") 190 | else: 191 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) 192 | if m == None : 193 | raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") 194 | 195 | points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 196 | 197 | validate_clockwise_points(points) 198 | 199 | if (imWidth>0 and imHeight>0): 200 | validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); 201 | validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); 202 | validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); 203 | validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); 204 | 205 | 206 | if withConfidence: 207 | try: 208 | confidence = float(m.group(numPoints+1)) 209 | except ValueError: 210 | raise Exception("Confidence value must be a float") 211 | 212 | if withTranscription: 213 | posTranscription = numPoints + (2 if withConfidence else 1) 214 | transcription = m.group(posTranscription) 215 | m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) 216 | if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters 217 | transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") 218 | 219 | return points,confidence,transcription 220 | 221 | 222 | def validate_point_inside_bounds(x,y,imWidth,imHeight): 223 | if(x<0 or x>imWidth): 224 | raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) 225 | if(y<0 or y>imHeight): 226 | raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) 227 | 228 | def validate_clockwise_points(points): 229 | """ 230 | Validates that the points that the 4 points that dlimite a polygon are in clockwise order. 231 | """ 232 | 233 | if len(points) != 8: 234 | raise Exception("Points list not valid." + str(len(points))) 235 | 236 | point = [ 237 | [int(points[0]) , int(points[1])], 238 | [int(points[2]) , int(points[3])], 239 | [int(points[4]) , int(points[5])], 240 | [int(points[6]) , int(points[7])] 241 | ] 242 | edge = [ 243 | ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), 244 | ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), 245 | ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), 246 | ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) 247 | ] 248 | 249 | summatory = edge[0] + edge[1] + edge[2] + edge[3]; 250 | if summatory>0: 251 | raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") 252 | 253 | def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): 254 | """ 255 | Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats: 256 | xmin,ymin,xmax,ymax,[confidence],[transcription] 257 | x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] 258 | """ 259 | pointsList = [] 260 | transcriptionsList = [] 261 | confidencesList = [] 262 | 263 | lines = content.split( "\r\n" if CRLF else "\n" ) 264 | for line in lines: 265 | line = line.replace("\r","").replace("\n","") 266 | if(line != "") : 267 | points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); 268 | pointsList.append(points) 269 | transcriptionsList.append(transcription) 270 | confidencesList.append(confidence) 271 | 272 | if withConfidence and len(confidencesList)>0 and sort_by_confidences: 273 | import numpy as np 274 | sorted_ind = np.argsort(-np.array(confidencesList)) 275 | confidencesList = [confidencesList[i] for i in sorted_ind] 276 | pointsList = [pointsList[i] for i in sorted_ind] 277 | transcriptionsList = [transcriptionsList[i] for i in sorted_ind] 278 | 279 | return pointsList,confidencesList,transcriptionsList 280 | 281 | def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): 282 | """ 283 | This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 284 | Params: 285 | p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 286 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 287 | validate_data_fn: points to a method that validates the corrct format of the submission 288 | evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results 289 | """ 290 | 291 | if (p == None): 292 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 293 | if(len(sys.argv)<3): 294 | print_help() 295 | 296 | evalParams = default_evaluation_params_fn() 297 | if 'p' in p.keys(): 298 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 299 | 300 | resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} 301 | try: 302 | validate_data_fn(p['g'], p['s'], evalParams) 303 | evalData = evaluate_method_fn(p['g'], p['s'], evalParams) 304 | resDict.update(evalData) 305 | 306 | except Exception, e: 307 | resDict['Message']= str(e) 308 | resDict['calculated']=False 309 | 310 | if 'o' in p: 311 | if not os.path.exists(p['o']): 312 | os.makedirs(p['o']) 313 | 314 | resultsOutputname = p['o'] + '/results.zip' 315 | outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) 316 | 317 | del resDict['per_sample'] 318 | if 'output_items' in resDict.keys(): 319 | del resDict['output_items'] 320 | 321 | outZip.writestr('method.json',json.dumps(resDict)) 322 | 323 | if not resDict['calculated']: 324 | if show_result: 325 | sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') 326 | if 'o' in p: 327 | outZip.close() 328 | return resDict 329 | 330 | if 'o' in p: 331 | if per_sample == True: 332 | for k,v in evalData['per_sample'].iteritems(): 333 | outZip.writestr( k + '.json',json.dumps(v)) 334 | 335 | if 'output_items' in evalData.keys(): 336 | for k, v in evalData['output_items'].iteritems(): 337 | outZip.writestr( k,v) 338 | 339 | outZip.close() 340 | 341 | if show_result: 342 | sys.stdout.write("Calculated!") 343 | sys.stdout.write(json.dumps(resDict['method'])) 344 | 345 | return resDict 346 | 347 | 348 | def 
main_validation(default_evaluation_params_fn,validate_data_fn): 349 | """ 350 | This process validates a method 351 | Params: 352 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 353 | validate_data_fn: points to a method that validates the corrct format of the submission 354 | """ 355 | try: 356 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 357 | evalParams = default_evaluation_params_fn() 358 | if 'p' in p.keys(): 359 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 360 | 361 | validate_data_fn(p['g'], p['s'], evalParams) 362 | print 'SUCCESS' 363 | sys.exit(0) 364 | except Exception as e: 365 | print str(e) 366 | sys.exit(101) -------------------------------------------------------------------------------- /eval/ic15/rrc_evaluation_funcs_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | #encoding: UTF-8 3 | import json 4 | import sys;sys.path.append('./') 5 | import zipfile 6 | import re 7 | import sys 8 | import os 9 | import codecs 10 | import importlib 11 | from StringIO import StringIO 12 | 13 | def print_help(): 14 | sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) 15 | sys.exit(2) 16 | 17 | 18 | def load_zip_file_keys(file,fileNameRegExp=''): 19 | """ 20 | Returns an array with the entries of the ZIP file that match with the regular expression. 21 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 22 | """ 23 | try: 24 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 25 | except : 26 | raise Exception('Error loading the ZIP archive.') 27 | 28 | pairs = [] 29 | 30 | for name in archive.namelist(): 31 | addFile = True 32 | keyName = name 33 | if fileNameRegExp!="": 34 | m = re.match(fileNameRegExp,name) 35 | if m == None: 36 | addFile = False 37 | else: 38 | if len(m.groups())>0: 39 | keyName = m.group(1) 40 | 41 | if addFile: 42 | pairs.append( keyName ) 43 | 44 | return pairs 45 | 46 | 47 | def load_zip_file(file,fileNameRegExp='',allEntries=False): 48 | """ 49 | Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 
50 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 51 | allEntries validates that all entries in the ZIP file pass the fileNameRegExp 52 | """ 53 | try: 54 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 55 | except : 56 | raise Exception('Error loading the ZIP archive') 57 | 58 | pairs = [] 59 | for name in archive.namelist(): 60 | addFile = True 61 | keyName = name 62 | if fileNameRegExp!="": 63 | m = re.match(fileNameRegExp,name) 64 | if m == None: 65 | addFile = False 66 | else: 67 | if len(m.groups())>0: 68 | keyName = m.group(1) 69 | 70 | if addFile: 71 | pairs.append( [ keyName , archive.read(name)] ) 72 | else: 73 | if allEntries: 74 | raise Exception('ZIP entry not valid: %s' %name) 75 | 76 | return dict(pairs) 77 | 78 | def decode_utf8(raw): 79 | """ 80 | Returns a Unicode object on success, or None on failure 81 | """ 82 | try: 83 | raw = codecs.decode(raw,'utf-8', 'replace') 84 | #extracts BOM if exists 85 | raw = raw.encode('utf8') 86 | if raw.startswith(codecs.BOM_UTF8): 87 | raw = raw.replace(codecs.BOM_UTF8, '', 1) 88 | return raw.decode('utf-8') 89 | except: 90 | return None 91 | 92 | def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 93 | """ 94 | This function validates that all lines of the file calling the Line validation function for each line 95 | """ 96 | utf8File = decode_utf8(file_contents) 97 | if (utf8File is None) : 98 | raise Exception("The file %s is not UTF-8" %fileName) 99 | 100 | lines = utf8File.split( "\r\n" if CRLF else "\n" ) 101 | for line in lines: 102 | line = line.replace("\r","").replace("\n","") 103 | if(line != ""): 104 | try: 105 | validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 106 | except Exception as e: 107 | raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) 108 | 109 | 110 | 111 | def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): 112 | """ 113 | Validate the format of the line. If the line is not valid an exception will be raised. 114 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 115 | Posible values are: 116 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 117 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 118 | """ 119 | get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 120 | 121 | 122 | def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 123 | """ 124 | Validate the format of the line. If the line is not valid an exception will be raised. 125 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 126 | Posible values are: 127 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 128 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 129 | Returns values from a textline. 
Points , [Confidences], [Transcriptions] 130 | """ 131 | confidence = 0.0 132 | transcription = ""; 133 | points = [] 134 | 135 | numPoints = 4; 136 | 137 | if LTRB: 138 | 139 | numPoints = 4; 140 | 141 | if withTranscription and withConfidence: 142 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 143 | if m == None : 144 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 145 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") 146 | elif withConfidence: 147 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 148 | if m == None : 149 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") 150 | elif withTranscription: 151 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) 152 | if m == None : 153 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") 154 | else: 155 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) 156 | if m == None : 157 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") 158 | 159 | xmin = int(m.group(1)) 160 | ymin = int(m.group(2)) 161 | xmax = int(m.group(3)) 162 | ymax = int(m.group(4)) 163 | if(xmax0 and imHeight>0): 171 | validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); 172 | validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); 173 | 174 | else: 175 | 176 | numPoints = 8; 177 | 178 | # if withTranscription and withConfidence: 179 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 180 | # if m == None : 181 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") 182 | # elif withConfidence: 183 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 184 | # if m == None : 185 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") 186 | # elif withTranscription: 187 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) 188 | # if m == None : 189 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") 190 | # else: 191 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) 192 | # if m == None : 193 | # raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") 194 | 195 | # points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 196 | # print line 197 | nums = line.split(',')[:8] 198 | points = [(float)(nums[i]) for i in range(8)] 199 | 200 | # validate_clockwise_points(points) 201 | 202 | if (imWidth>0 and imHeight>0): 203 | validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); 204 | validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); 205 | validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); 206 | validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); 207 | 208 | 209 | # if withConfidence: 210 | # try: 211 | # confidence = float(m.group(numPoints+1)) 212 | # except ValueError: 213 | # raise Exception("Confidence value must be a float") 214 | 215 | # if withTranscription: 216 | # posTranscription = numPoints + (2 if withConfidence else 1) 217 | # transcription = m.group(posTranscription) 218 | # m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) 219 | # if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters 220 | # transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") 221 | 222 | return points,confidence,transcription 223 | 224 | 225 | def validate_point_inside_bounds(x,y,imWidth,imHeight): 226 | if(x<0 or x>imWidth): 227 | raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) 228 | if(y<0 or y>imHeight): 229 | raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) 230 | 231 | def validate_clockwise_points(points): 232 | """ 233 | Validates that the points that the 4 points that dlimite a polygon are in clockwise order. 234 | """ 235 | 236 | if len(points) != 8: 237 | raise Exception("Points list not valid." + str(len(points))) 238 | 239 | point = [ 240 | [int(points[0]) , int(points[1])], 241 | [int(points[2]) , int(points[3])], 242 | [int(points[4]) , int(points[5])], 243 | [int(points[6]) , int(points[7])] 244 | ] 245 | edge = [ 246 | ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), 247 | ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), 248 | ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), 249 | ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) 250 | ] 251 | 252 | summatory = edge[0] + edge[1] + edge[2] + edge[3]; 253 | if summatory>0: 254 | raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") 255 | 256 | def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): 257 | """ 258 | Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats: 259 | xmin,ymin,xmax,ymax,[confidence],[transcription] 260 | x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] 261 | """ 262 | pointsList = [] 263 | transcriptionsList = [] 264 | confidencesList = [] 265 | 266 | lines = content.split( "\r\n" if CRLF else "\n" ) 267 | for line in lines: 268 | line = line.replace("\r","").replace("\n","") 269 | if(line != "") : 270 | points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); 271 | pointsList.append(points) 272 | transcriptionsList.append(transcription) 273 | confidencesList.append(confidence) 274 | 275 | if withConfidence and len(confidencesList)>0 and sort_by_confidences: 276 | import numpy as np 277 | sorted_ind = np.argsort(-np.array(confidencesList)) 278 | confidencesList = [confidencesList[i] for i in sorted_ind] 279 | pointsList = [pointsList[i] for i in sorted_ind] 280 | transcriptionsList = [transcriptionsList[i] for i in sorted_ind] 281 | 282 | return pointsList,confidencesList,transcriptionsList 283 | 284 | def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): 285 | """ 286 | This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 287 | Params: 288 | p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 289 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 290 | validate_data_fn: points to a method that validates the corrct format of the submission 291 | evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results 292 | """ 293 | 294 | if (p == None): 295 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 296 | if(len(sys.argv)<3): 297 | print_help() 298 | 299 | evalParams = default_evaluation_params_fn() 300 | if 'p' in p.keys(): 301 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 302 | 303 | resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} 304 | try: 305 | validate_data_fn(p['g'], p['s'], evalParams) 306 | evalData = evaluate_method_fn(p['g'], p['s'], evalParams) 307 | resDict.update(evalData) 308 | 309 | except Exception, e: 310 | resDict['Message']= str(e) 311 | resDict['calculated']=False 312 | 313 | if 'o' in p: 314 | if not os.path.exists(p['o']): 315 | os.makedirs(p['o']) 316 | 317 | resultsOutputname = p['o'] + '/results.zip' 318 | outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) 319 | 320 | del resDict['per_sample'] 321 | if 'output_items' in resDict.keys(): 322 | del resDict['output_items'] 323 | 324 | outZip.writestr('method.json',json.dumps(resDict)) 325 | 326 | if not resDict['calculated']: 327 | if show_result: 328 | sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') 329 | if 'o' in p: 330 | outZip.close() 331 | return resDict 332 | 333 | if 'o' in p: 334 | if per_sample == True: 335 | for k,v in evalData['per_sample'].iteritems(): 336 | outZip.writestr( k + '.json',json.dumps(v)) 337 | 338 | if 'output_items' in evalData.keys(): 339 | for k, v in evalData['output_items'].iteritems(): 340 | outZip.writestr( k,v) 341 | 342 | outZip.close() 343 | 344 | if show_result: 345 | sys.stdout.write("Calculated!") 346 | sys.stdout.write(json.dumps(resDict['method'])) 347 | 348 | return resDict 349 | 350 | 351 | def 
main_validation(default_evaluation_params_fn,validate_data_fn): 352 | """ 353 | This process validates a method 354 | Params: 355 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 356 | validate_data_fn: points to a method that validates the corrct format of the submission 357 | """ 358 | try: 359 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 360 | evalParams = default_evaluation_params_fn() 361 | if 'p' in p.keys(): 362 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 363 | 364 | validate_data_fn(p['g'], p['s'], evalParams) 365 | print 'SUCCESS' 366 | sys.exit(0) 367 | except Exception as e: 368 | print str(e) 369 | sys.exit(101) -------------------------------------------------------------------------------- /eval/ic15/script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from collections import namedtuple 4 | import rrc_evaluation_funcs 5 | import importlib 6 | 7 | def evaluation_imports(): 8 | """ 9 | evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. 10 | """ 11 | return { 12 | 'Polygon':'plg', 13 | 'numpy':'np' 14 | } 15 | 16 | def default_evaluation_params(): 17 | """ 18 | default_evaluation_params: Default parameters to use for the validation and evaluation. 19 | """ 20 | return { 21 | 'IOU_CONSTRAINT' :0.5, 22 | 'AREA_PRECISION_CONSTRAINT' :0.5, 23 | 'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt', 24 | 'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt', 25 | 'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) 26 | 'CRLF':False, # Lines are delimited by Windows CRLF format 27 | 'CONFIDENCES':False, #Detections must include confidence value. AP will be calculated 28 | 'PER_SAMPLE_RESULTS':True #Generate per sample results and produce data for visualization 29 | } 30 | 31 | def validate_data(gtFilePath, submFilePath,evaluationParams): 32 | """ 33 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents). 34 | Validates also that there are no missing files in the folder. 35 | If some error detected, the method raises the error 36 | """ 37 | gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) 38 | 39 | subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) 40 | 41 | #Validate format of GroundTruth 42 | for k in gt: 43 | rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True) 44 | 45 | #Validate format of results 46 | for k in subm: 47 | if (k in gt) == False : 48 | raise Exception("The sample %s not present in GT" %k) 49 | 50 | rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],False,evaluationParams['CONFIDENCES']) 51 | 52 | 53 | def evaluate_method(gtFilePath, submFilePath, evaluationParams): 54 | """ 55 | Method evaluate_method: evaluate method and returns the results 56 | Results. Dictionary with the following values: 57 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } 58 | - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } 59 | """ 60 | 61 | for module,alias in evaluation_imports().iteritems(): 62 | globals()[alias] = importlib.import_module(module) 63 | 64 | def polygon_from_points(points): 65 | """ 66 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 67 | """ 68 | # resBoxes=np.empty([1,8],dtype='int32') 69 | # resBoxes[0,0]=int(points[0]) 70 | # resBoxes[0,4]=int(points[1]) 71 | # resBoxes[0,1]=int(points[2]) 72 | # resBoxes[0,5]=int(points[3]) 73 | # resBoxes[0,2]=int(points[4]) 74 | # resBoxes[0,6]=int(points[5]) 75 | # resBoxes[0,3]=int(points[6]) 76 | # resBoxes[0,7]=int(points[7]) 77 | # pointMat = resBoxes[0].reshape([2,4]).T 78 | # return plg.Polygon( pointMat) 79 | 80 | p = np.array(points) 81 | p = p.reshape(p.shape[0]//2, 2) 82 | p = plg.Polygon(p) 83 | return p 84 | 85 | def rectangle_to_polygon(rect): 86 | resBoxes=np.empty([1,8],dtype='int32') 87 | resBoxes[0,0]=int(rect.xmin) 88 | resBoxes[0,4]=int(rect.ymax) 89 | resBoxes[0,1]=int(rect.xmin) 90 | resBoxes[0,5]=int(rect.ymin) 91 | resBoxes[0,2]=int(rect.xmax) 92 | resBoxes[0,6]=int(rect.ymin) 93 | resBoxes[0,3]=int(rect.xmax) 94 | resBoxes[0,7]=int(rect.ymax) 95 | 96 | pointMat = resBoxes[0].reshape([2,4]).T 97 | 98 | return plg.Polygon( pointMat) 99 | 100 | def rectangle_to_points(rect): 101 | points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)] 102 | return points 103 | 104 | def get_union(pD,pG): 105 | areaA = pD.area(); 106 | areaB = pG.area(); 107 | return areaA + areaB - get_intersection(pD, pG); 108 | 109 | def get_intersection_over_union(pD,pG): 110 | try: 111 | return get_intersection(pD, pG) / get_union(pD, pG); 112 | except: 113 | return 0 114 | 115 | def get_intersection(pD,pG): 116 | pInt = pD & pG 117 | if len(pInt) == 0: 118 | return 0 119 | return pInt.area() 120 | 121 | def compute_ap(confList, matchList,numGtCare): 122 | correct = 0 123 | AP = 0 124 | if len(confList)>0: 125 | confList = np.array(confList) 126 | matchList = np.array(matchList) 127 | sorted_ind = np.argsort(-confList) 128 | confList = confList[sorted_ind] 129 | matchList = matchList[sorted_ind] 130 | for n in range(len(confList)): 131 | match = matchList[n] 132 | if match: 133 | correct += 1 134 | AP += float(correct)/(n + 1) 135 | 136 | if numGtCare>0: 137 | AP /= numGtCare 138 | 139 | return AP 140 | 141 | perSampleMetrics = {} 142 | 143 | matchedSum = 0 144 | 145 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') 146 | 147 | gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) 148 | subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) 149 | 150 | numGlobalCareGt = 0; 151 | numGlobalCareDet = 0; 152 | 153 | arrGlobalConfidences = []; 154 | arrGlobalMatches = []; 155 | 156 | for resFile in gt: 157 | 158 | gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) 159 | recall = 0 160 | precision = 0 161 | hmean = 0 162 | 163 | detMatched = 0 164 | 165 | iouMat = np.empty([1,1]) 166 | 167 | gtPols = [] 168 | detPols = [] 169 | 170 | gtPolPoints = [] 171 | detPolPoints = [] 172 | 173 | #Array of Ground Truth Polygons' keys marked as don't Care 174 | gtDontCarePolsNum = [] 175 | #Array of Detected Polygons' matched with a don't Care GT 176 | detDontCarePolsNum = [] 177 | 178 | pairs = [] 179 | detMatchedNums = [] 180 | 181 | 
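        # When evaluationParams['CONFIDENCES'] is enabled, the per-sample
        # confidences and match flags collected below feed compute_ap above.
        # A worked example of compute_ap with hypothetical inputs:
        #   confList = [0.9, 0.7, 0.8], matchList = [1, 0, 1], numGtCare = 2
        #   sorted by confidence -> matches [1, 1, 0]
        #   AP = (1/1 + 2/2) / 2 = 1.0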
181 | arrSampleConfidences = []; 182 | arrSampleMatch = []; 183 | sampleAP = 0; 184 | 185 | evaluationLog = "" 186 | 187 | pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False) 188 | for n in range(len(pointsList)): 189 | points = pointsList[n] 190 | transcription = transcriptionsList[n] 191 | dontCare = transcription == "###" 192 | if evaluationParams['LTRB']: 193 | gtRect = Rectangle(*points) 194 | gtPol = rectangle_to_polygon(gtRect) 195 | else: 196 | gtPol = polygon_from_points(points) 197 | gtPols.append(gtPol) 198 | gtPolPoints.append(points) 199 | if dontCare: 200 | gtDontCarePolsNum.append( len(gtPols)-1 ) 201 | 202 | evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n") 203 | 204 | if resFile in subm: 205 | 206 | detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) 207 | def get_pred(file): 208 | lines = file.split('\n') 209 | pointsList = [] 210 | for line in lines: 211 | if line == '': 212 | continue 213 | bbox = line.split(',') 214 | if len(bbox) % 2 == 1: 215 | print('Warning: odd number of coordinates in detection line: %s' % line) 216 | bbox = [int(x) for x in bbox] 217 | pointsList.append(bbox) 218 | return pointsList 219 | 220 | # pointsList,confidencesList,_ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],False,evaluationParams['CONFIDENCES']) 221 | # print(pointsList) 222 | # print(confidencesList) 223 | 224 | pointsList = get_pred(detFile) 225 | confidencesList = [0.0] * len(pointsList) 226 | 227 | for n in range(len(pointsList)): 228 | points = pointsList[n] 229 | 230 | if evaluationParams['LTRB']: 231 | detRect = Rectangle(*points) 232 | detPol = rectangle_to_polygon(detRect) 233 | else: 234 | detPol = polygon_from_points(points) 235 | detPols.append(detPol) 236 | detPolPoints.append(points) 237 | if len(gtDontCarePolsNum)>0 : 238 | for dontCarePol in gtDontCarePolsNum: 239 | dontCarePol = gtPols[dontCarePol] 240 | intersected_area = get_intersection(dontCarePol,detPol) 241 | pdDimensions = detPol.area() 242 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions 243 | if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ): 244 | detDontCarePolsNum.append( len(detPols)-1 ) 245 | break 246 | 247 | evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n") 248 | 249 | if len(gtPols)>0 and len(detPols)>0: 250 | #Calculate IoU and precision matrices 251 | outputShape=[len(gtPols),len(detPols)] 252 | iouMat = np.empty(outputShape) 253 | gtRectMat = np.zeros(len(gtPols),np.int8) 254 | detRectMat = np.zeros(len(detPols),np.int8) 255 | for gtNum in range(len(gtPols)): 256 | for detNum in range(len(detPols)): 257 | pG = gtPols[gtNum] 258 | pD = detPols[detNum] 259 | iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG) 260 | 261 | for gtNum in range(len(gtPols)): 262 | for detNum in range(len(detPols)): 263 | if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum : 264 | if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']: 265 | gtRectMat[gtNum] = 1 266 | detRectMat[detNum] = 1 267 | detMatched += 1 268 | pairs.append({'gt':gtNum,'det':detNum}) 269 | detMatchedNums.append(detNum) 270 | evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" 271 | 272 
| if evaluationParams['CONFIDENCES']: 273 | for detNum in range(len(detPols)): 274 | if detNum not in detDontCarePolsNum : 275 | #we exclude the don't care detections 276 | match = detNum in detMatchedNums 277 | 278 | arrSampleConfidences.append(confidencesList[detNum]) 279 | arrSampleMatch.append(match) 280 | 281 | arrGlobalConfidences.append(confidencesList[detNum]); 282 | arrGlobalMatches.append(match); 283 | 284 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) 285 | numDetCare = (len(detPols) - len(detDontCarePolsNum)) 286 | if numGtCare == 0: 287 | recall = float(1) 288 | precision = float(0) if numDetCare >0 else float(1) 289 | sampleAP = precision 290 | else: 291 | recall = float(detMatched) / numGtCare 292 | precision = 0 if numDetCare==0 else float(detMatched) / numDetCare 293 | if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']: 294 | sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare ) 295 | 296 | hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall) 297 | 298 | matchedSum += detMatched 299 | numGlobalCareGt += numGtCare 300 | numGlobalCareDet += numDetCare 301 | 302 | if evaluationParams['PER_SAMPLE_RESULTS']: 303 | perSampleMetrics[resFile] = { 304 | 'precision':precision, 305 | 'recall':recall, 306 | 'hmean':hmean, 307 | 'pairs':pairs, 308 | 'AP':sampleAP, 309 | 'iouMat':[] if len(detPols)>100 else iouMat.tolist(), 310 | 'gtPolPoints':gtPolPoints, 311 | 'detPolPoints':detPolPoints, 312 | 'gtDontCare':gtDontCarePolsNum, 313 | 'detDontCare':detDontCarePolsNum, 314 | 'evaluationParams': evaluationParams, 315 | 'evaluationLog': evaluationLog 316 | } 317 | 318 | # Compute MAP and MAR 319 | AP = 0 320 | if evaluationParams['CONFIDENCES']: 321 | AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) 322 | 323 | methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt 324 | methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet 325 | methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision) 326 | 327 | methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP } 328 | 329 | resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics} 330 | 331 | return resDict 332 | 333 | 334 | 335 | if __name__=='__main__': 336 | 337 | rrc_evaluation_funcs.main_evaluation(None,default_evaluation_params,validate_data,evaluate_method) 338 | print('') 339 | -------------------------------------------------------------------------------- /eval/ic15/script_self_adapt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Hyperparams') 6 | # parser.add_argument('--gt', nargs='?', type=str, default=None) 7 | parser.add_argument('--pred', nargs='?', type=str, default=None) 8 | args = parser.parse_args() 9 | 10 | output_root = '../outputs/tmp_results/' 11 | pred = mmcv.load(args.pred) 12 | 13 | def write_result_as_txt(image_name, bboxes, path, words=None): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | 17 | file_path = path + 'res_%s.txt'%(image_name) 18 | lines = [] 19 | for i, bbox in enumerate(bboxes): 20 | values = [int(v) for v in bbox] 21 | if words is None: 22 | line = "%d,%d,%d,%d,%d,%d,%d,%d\n"%tuple(values) 23 | lines.append(line) 24 | elif words[i] is 
not None: 25 | line = "%d,%d,%d,%d,%d,%d,%d,%d"%tuple(values) + ",%s\n"%words[i] 26 | lines.append(line) 27 | with open(file_path, 'w') as f: 28 | for line in lines: 29 | f.write(line) 30 | 31 | def eval(thr): 32 | for key in pred: 33 | pred_ = pred[key] 34 | line_num = len(pred_['scores']) 35 | bboxes = [] 36 | # words = [] 37 | for i in range(line_num): 38 | if pred_['scores'][i] < thr: 39 | continue 40 | bboxes.append(pred_['bboxes'][i]) 41 | # words.append(pred_['words'][i]) 42 | 43 | write_result_as_txt(key, bboxes, output_root) 44 | 45 | cmd = 'cd %s;zip -j %s %s/*' % ('../outputs/', 'tmp_results.zip', 'tmp_results') 46 | res_cmd = os.popen(cmd) 47 | res_cmd.read() 48 | 49 | cmd = 'cd ic15 && python2 script.py -g=gt.zip -s=../../outputs/tmp_results.zip && cd ..' 50 | res_cmd = os.popen(cmd) 51 | res_cmd = res_cmd.read() 52 | h_mean = float(res_cmd.split(',')[-2].split(':')[-1]) 53 | return res_cmd, h_mean 54 | 55 | max_h_mean = 0 56 | best_thr = 0 57 | best_res = '' 58 | for i in range(85, 100): 59 | thr = float(i) / 100 60 | # print('Testing thr: %f'%thr) 61 | res, h_mean = eval(thr) 62 | # print(thr, h_mean) 63 | if h_mean > max_h_mean: 64 | max_h_mean = h_mean 65 | best_thr = thr 66 | best_res = res 67 | 68 | print('thr: %f | %s'%(best_thr, best_res)) -------------------------------------------------------------------------------- /eval/msra/eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is the evaluation code modified from the originally released code. 3 | 1) Remove dependency on Polygon library 4 | 2) Fix input prediction format to x0,y0,x1,y1,... 5 | ''' 6 | 7 | import file_util 8 | import numpy as np 9 | import math 10 | import cv2 11 | 12 | project_root = '../../' 13 | 14 | pred_root = project_root + 'results/submit_msra/' 15 | gt_root = project_root + 'data/MSRA-TD500/test/' 16 | 17 | 18 | def get_pred(path): 19 | lines = file_util.read_file(path).split('\n') 20 | bboxes = [] 21 | for line in lines: 22 | if line == '': 23 | continue 24 | bbox = line.split(',') 25 | if len(bbox) % 2 == 1: 26 | print(path) 27 | bbox = [int(x) for x in bbox] 28 | bboxes.append(bbox) 29 | return bboxes 30 | 31 | 32 | def get_gt(path): 33 | lines = file_util.read_file(path).split('\n') 34 | bboxes = [] 35 | tags = [] 36 | for line in lines: 37 | if line == '': 38 | continue 39 | # line = util.str.remove_all(line, '\xef\xbb\xbf') 40 | # gt = util.str.split(line, ' ') 41 | gt = line.split(' ') 42 | 43 | w_ = float(gt[4]) 44 | h_ = float(gt[5]) 45 | x1 = float(gt[2]) + w_ / 2.0 46 | y1 = float(gt[3]) + h_ / 2.0 47 | theta = float(gt[6]) / math.pi * 180 48 | 49 | bbox = cv2.boxPoints(((x1, y1), (w_, h_), theta)) 50 | bbox = bbox.reshape(-1) 51 | 52 | bboxes.append(bbox) 53 | tags.append(int(gt[1])) 54 | return np.array(bboxes), tags 55 | 56 | def get_union(pD, pG, H, W): 57 | # replace original polygon library by opencv function 58 | blank = np.zeros((H, W)) 59 | image1 = cv2.fillPoly(blank.copy(), [pD], 1) 60 | image2 = cv2.fillPoly(blank.copy(), [pG], 1) 61 | areaA = np.sum(image1) 62 | areaB = np.sum(image2) 63 | 64 | return areaA + areaB - get_intersection(pD, pG, H, W) 65 | 66 | def get_intersection(pD, pG, H, W): 67 | # replace original polygon library by opencv function 68 | blank = np.zeros((H, W)) 69 | image1 = cv2.fillPoly(blank.copy(), [pD], 1) 70 | image2 = cv2.fillPoly(blank.copy(), [pG], 1) 71 | intersection = np.logical_and(image1, image2) 72 | 73 | return np.sum(intersection) 74 | 75 | if __name__ == 
'__main__': 76 | th = 0.5 77 | pred_list = file_util.read_dir(pred_root) 78 | 79 | count, tp, fp, tn, ta = 0, 0, 0, 0, 0 80 | for pred_path in pred_list: 81 | count = count + 1 82 | preds = get_pred(pred_path) 83 | gt_path = gt_root + pred_path.split('/')[-1].split('.')[0] + '.gt' 84 | img = cv2.imread(gt_root + pred_path.split('/')[-1][:-4] + '.jpg') 85 | H, W, _ = img.shape 86 | gts, tags = get_gt(gt_path) 87 | 88 | ta = ta + len(preds) 89 | for gt, tag in zip(gts, tags): 90 | gt = np.array(gt) 91 | gt = gt.reshape(gt.shape[0] // 2, 2) 92 | #gt_p = plg.Polygon(gt) 93 | difficult = tag 94 | flag = 0 95 | for pred in preds: 96 | pred = np.array(pred) 97 | pred = pred.reshape(pred.shape[0] // 2, 2) 98 | #pred_p = plg.Polygon(pred) 99 | 100 | union = get_union(pred, gt, H, W) 101 | inter = get_intersection(pred, gt, H, W) 102 | iou = float(inter) / union 103 | if iou >= th: 104 | flag = 1 105 | tp = tp + 1 106 | break 107 | 108 | if flag == 0 and difficult == 0: 109 | fp = fp + 1 110 | 111 | recall = float(tp) / (tp + fp) 112 | precision = float(tp) / ta 113 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 114 | 115 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 116 | -------------------------------------------------------------------------------- /eval/msra/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def read_dir(root): 4 | file_path_list = [] 5 | for file_path, dirs, files in os.walk(root): 6 | for file in files: 7 | file_path_list.append(os.path.join(file_path, file).replace('\\', '/')) 8 | file_path_list.sort() 9 | return file_path_list 10 | 11 | def read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 | def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /eval/totaltext/Deteval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is the evaluation code modified from the originally released code. 3 | 1) Remove dependency on Polygon library 4 | 2) Fix input prediction format to x0,y0,x1,y1,... 5 | ''' 6 | 7 | from os import listdir 8 | from scipy import io 9 | import numpy as np 10 | import cv2 11 | import pdb 12 | 13 | """ 14 | Input format: x0,y0, ..... xn,yn. 
Each detection is separated by the end of line token ('\n') 15 | """ 16 | project_root = '../../' 17 | 18 | input_dir = project_root + 'results/submit_tt/' 19 | gt_dir = project_root + 'data/totaltext/Groundtruth/Polygon/Test/' 20 | fid_path = project_root + 'outputs/totaltext/res_tt.txt' 21 | img_root = project_root + 'data/totaltext/Images/Test/' 22 | 23 | allInputs = listdir(input_dir) 24 | 25 | ''' 26 | def get_union(pD, pG): 27 | areaA = pD.area() 28 | areaB = pG.area() 29 | return areaA + areaB - get_intersection(pD, pG) 30 | 31 | 32 | def get_intersection(pD, pG): 33 | pInt = pD & pG 34 | if len(pInt) == 0: 35 | return 0 36 | return pInt.area() 37 | ''' 38 | 39 | def PolyArea(x,y): 40 | return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) 41 | 42 | def get_union(pD, pG, H, W): 43 | # replace original polygon library by opencv function 44 | blank = np.zeros((H, W)) 45 | image1 = cv2.fillPoly(blank.copy(), [pD], 1) 46 | image2 = cv2.fillPoly(blank.copy(), [pG], 1) 47 | areaA = np.sum(image1) 48 | areaB = np.sum(image2) 49 | 50 | return areaA + areaB - get_intersection(pD, pG, H, W) 51 | 52 | def get_intersection(pD, pG, H, W): 53 | # replace original polygon library by opencv function 54 | blank = np.zeros((H, W)) 55 | image1 = cv2.fillPoly(blank.copy(), [pD], 1) 56 | image2 = cv2.fillPoly(blank.copy(), [pG], 1) 57 | intersection = np.logical_and(image1, image2) 58 | 59 | return np.sum(intersection) 60 | 61 | def input_reading_mod(input_dir, input): 62 | """This helper reads input from txt files""" 63 | with open('%s/%s' % (input_dir, input), 'r') as input_fid: 64 | pred = input_fid.readlines() 65 | det = [x.strip('\n') for x in pred] 66 | return det 67 | 68 | 69 | def gt_reading_mod(gt_dir, gt_id): 70 | """This helper reads groundtruths from mat files""" 71 | gt_id = gt_id.split('.')[0] 72 | gt = io.loadmat('%s/poly_gt_%s.mat' % (gt_dir, gt_id)) 73 | gt = gt['polygt'] 74 | return gt 75 | 76 | 77 | def detection_filtering(detections, groundtruths, threshold, H, W): 78 | for gt_id, gt in enumerate(groundtruths): 79 | if (gt[5] == '#') and (gt[1].shape[1] > 1): 80 | gt_x = map(int, np.squeeze(gt[1])) 81 | gt_y = map(int, np.squeeze(gt[3])) 82 | 83 | gt_p = np.concatenate((np.array(list(gt_x)), np.array(list(gt_y)))) 84 | gt_p = gt_p.reshape(2, -1).transpose() 85 | #gt_p = plg.Polygon(gt_p) # replaced 86 | 87 | for det_id, detection in enumerate(detections): 88 | detection = detection.split(',') 89 | # detection = map(int, detection[0:-1]) 90 | detection = map(int, detection) 91 | detection = np.array(list(detection)) 92 | det_x = detection[0::2] 93 | det_y = detection[1::2] 94 | 95 | det_p = np.concatenate((np.array(det_x), np.array(det_y))) 96 | det_p = det_p.reshape(2, -1).transpose() 97 | #det_p = plg.Polygon(det_p) replaced 98 | 99 | try: 100 | # det_gt_iou = iod(det_x, det_y, gt_x, gt_y) 101 | #det_gt_iou = get_intersection(det_p, gt_p, H, W) / det_p.area() # replaced 102 | det_gt_iou = get_intersection(det_p, gt_p, H, W) / PolyArea(det_p[:,0], det_p[:,1]) 103 | except: 104 | det_gt_iou = 0 # treat a failed polygon computation as zero overlap instead of reusing a stale value 105 | if det_gt_iou > threshold: 106 | detections[det_id] = [] 107 | 108 | detections[:] = [item for item in detections if item != []] 109 | return detections 110 | 111 | def sigma_calculation(det_p, gt_p, H, W): 112 | """ 113 | sigma = inter_area / gt_area 114 | """ 115 | # return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)), 2) 116 | #return get_intersection(det_p, gt_p) / gt_p.area() # replaced 117 | return get_intersection(det_p, 
gt_p, H, W) / PolyArea(gt_p[:,0], gt_p[:,1]) 118 | 119 | def tau_calculation(det_p, gt_p, H, W): 120 | """ 121 | tau = inter_area / det_area 122 | """ 123 | #return get_intersection(det_p, gt_p) / det_p.area() # replaced 124 | return get_intersection(det_p, gt_p, H, W) / PolyArea(det_p[:,0], det_p[:,1]) 125 | 126 | 127 | ##############################Initialization################################### 128 | global_tp = 0 129 | global_fp = 0 130 | global_fn = 0 131 | global_sigma = [] 132 | global_tau = [] 133 | tr = 0.7 134 | tp = 0.6 135 | fsc_k = 0.8 136 | k = 2 137 | ############################################################################### 138 | 139 | 140 | for input_id in allInputs: 141 | if (input_id != '.DS_Store'): 142 | img = cv2.imread(img_root + input_id[:-4] + '.jpg') 143 | H, W, _ = img.shape 144 | print('input_id', input_id) 145 | detections = input_reading_mod(input_dir, input_id) 146 | # from IPython import embed; 147 | groundtruths = gt_reading_mod(gt_dir, input_id) 148 | detections = detection_filtering(detections, groundtruths, 0.5, H, W) # filters detections overlapping with DC area 149 | dc_id = np.where(groundtruths[:, 5] == '#') 150 | groundtruths = np.delete(groundtruths, (dc_id), (0)) 151 | 152 | local_sigma_table = np.zeros((groundtruths.shape[0], len(detections))) 153 | local_tau_table = np.zeros((groundtruths.shape[0], len(detections))) 154 | 155 | for gt_id, gt in enumerate(groundtruths): 156 | if len(detections) > 0: 157 | for det_id, detection in enumerate(detections): 158 | detection = detection.split(',') 159 | # print (len(detection)) 160 | 161 | # detection = map(int, detection[:-1]) 162 | detection = map(int, detection) 163 | detection = np.array(list(detection)) 164 | # print (len(detection)) 165 | 166 | # from IPython import embed;embed() 167 | # detection = list(detection) 168 | gt_x = map(int, np.squeeze(gt[1])) 169 | gt_y = map(int, np.squeeze(gt[3])) 170 | gt_p = np.concatenate((np.array(list(gt_x)), np.array(list(gt_y)))) 171 | gt_p = gt_p.reshape(2, -1).transpose() 172 | #gt_p = plg.Polygon(gt_p) # replaced 173 | 174 | det_y = detection[1::2] 175 | det_x = detection[0::2] 176 | 177 | det_p = np.concatenate((np.array(det_x), np.array(det_y))) 178 | # print (det_p.shape) 179 | det_p = det_p.reshape(2, -1).transpose() 180 | #det_p = plg.Polygon(det_p) # replaced 181 | 182 | # gt_x = list(map(int, np.squeeze(gt[1]))) 183 | # gt_y = list(map(int, np.squeeze(gt[3]))) 184 | # try: 185 | # local_sigma_table[gt_id, det_id] = sigma_calculation(det_x, det_y, gt_x, gt_y) 186 | # local_tau_table[gt_id, det_id] = tau_calculation(det_x, det_y, gt_x, gt_y) 187 | # except: 188 | # embed() 189 | local_sigma_table[gt_id, det_id] = sigma_calculation(det_p, gt_p, H, W) 190 | local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p, H, W) 191 | # if input_id == 'img1199.txt': 192 | # embed() 193 | global_sigma.append(local_sigma_table) 194 | global_tau.append(local_tau_table) 195 | 196 | global_accumulative_recall = 0 197 | global_accumulative_precision = 0 198 | total_num_gt = 0 199 | total_num_det = 0 200 | 201 | 202 | def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, 203 | local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, 204 | gt_flag, det_flag): 205 | for gt_id in range(num_gt): 206 | qualified_sigma_candidates = np.where(local_sigma_table[gt_id, :] > tr) 207 | num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0] 208 | qualified_tau_candidates = 
np.where(local_tau_table[gt_id, :] > tp) 209 | num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0] 210 | 211 | if (num_qualified_sigma_candidates == 1) and (num_qualified_tau_candidates == 1): 212 | global_accumulative_recall = global_accumulative_recall + 1.0 213 | global_accumulative_precision = global_accumulative_precision + 1.0 214 | local_accumulative_recall = local_accumulative_recall + 1.0 215 | local_accumulative_precision = local_accumulative_precision + 1.0 216 | 217 | gt_flag[0, gt_id] = 1 218 | matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) 219 | det_flag[0, matched_det_id] = 1 220 | return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag 221 | 222 | 223 | def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, 224 | local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, 225 | gt_flag, det_flag): 226 | for gt_id in range(num_gt): 227 | # skip the following if the groundtruth was matched 228 | if gt_flag[0, gt_id] > 0: 229 | continue 230 | 231 | non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0) 232 | num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0] 233 | 234 | if num_non_zero_in_sigma >= k: 235 | ####search for all detections that overlap with this groundtruth 236 | qualified_tau_candidates = np.where((local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0)) 237 | num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0] 238 | 239 | if num_qualified_tau_candidates == 1: 240 | if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) and ( 241 | local_sigma_table[gt_id, qualified_tau_candidates] >= tr)): 242 | # becomes a one-to-one case 243 | global_accumulative_recall = global_accumulative_recall + 1.0 244 | global_accumulative_precision = global_accumulative_precision + 1.0 245 | local_accumulative_recall = local_accumulative_recall + 1.0 246 | local_accumulative_precision = local_accumulative_precision + 1.0 247 | 248 | gt_flag[0, gt_id] = 1 249 | det_flag[0, qualified_tau_candidates] = 1 250 | elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr): 251 | gt_flag[0, gt_id] = 1 252 | det_flag[0, qualified_tau_candidates] = 1 253 | 254 | global_accumulative_recall = global_accumulative_recall + fsc_k 255 | global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k 256 | 257 | local_accumulative_recall = local_accumulative_recall + fsc_k 258 | local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k 259 | 260 | return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag 261 | 262 | 
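# A note on the DetEval bookkeeping shared by one_to_one / one_to_many / many_to_many
# (summary added for readability; not part of the original release): sigma is
# intersection / GT area (recall-like) and tau is intersection / detection area
# (precision-like), thresholded by tr = 0.7 and tp = 0.6. Worked example for the
# one-to-many branch above: one GT split across 3 detections whose taus all qualify
# and whose sigmas sum past tr contributes fsc_k = 0.8 to recall and 3 * 0.8 = 2.4
# to precision, i.e. fragmented matches are accepted but discounted by fsc_k.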
263 | def many_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, 264 | local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, 265 | gt_flag, det_flag): 266 | for det_id in range(num_det): 267 | # skip the following if the detection was matched 268 | if det_flag[0, det_id] > 0: 269 | continue 270 | 271 | non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0) 272 | num_non_zero_in_tau = non_zero_in_tau[0].shape[0] 273 | 274 | if num_non_zero_in_tau >= k: 275 | ####search for all groundtruths that overlap with this detection 276 | qualified_sigma_candidates = np.where((local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)) 277 | num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0] 278 | 279 | if num_qualified_sigma_candidates == 1: 280 | if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp) and ( 281 | local_sigma_table[qualified_sigma_candidates, det_id] >= tr)): 282 | # becomes a one-to-one case 283 | global_accumulative_recall = global_accumulative_recall + 1.0 284 | global_accumulative_precision = global_accumulative_precision + 1.0 285 | local_accumulative_recall = local_accumulative_recall + 1.0 286 | local_accumulative_precision = local_accumulative_precision + 1.0 287 | 288 | gt_flag[0, qualified_sigma_candidates] = 1 289 | det_flag[0, det_id] = 1 290 | elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp): 291 | det_flag[0, det_id] = 1 292 | gt_flag[0, qualified_sigma_candidates] = 1 293 | 294 | global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k 295 | global_accumulative_precision = global_accumulative_precision + fsc_k 296 | 297 | local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k 298 | local_accumulative_precision = local_accumulative_precision + fsc_k 299 | return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag 300 | 301 | 302 | for idx in range(len(global_sigma)): 303 | print(allInputs[idx]) 304 | local_sigma_table = global_sigma[idx] 305 | local_tau_table = global_tau[idx] 306 | 307 | num_gt = local_sigma_table.shape[0] 308 | num_det = local_sigma_table.shape[1] 309 | 310 | total_num_gt = total_num_gt + num_gt 311 | total_num_det = total_num_det + num_det 312 | 313 | local_accumulative_recall = 0 314 | local_accumulative_precision = 0 315 | gt_flag = np.zeros((1, num_gt)) 316 | det_flag = np.zeros((1, num_det)) 317 | 318 | #######first check for one-to-one case########## 319 | local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ 320 | gt_flag, det_flag = one_to_one(local_sigma_table, local_tau_table, 321 | local_accumulative_recall, local_accumulative_precision, 322 | global_accumulative_recall, global_accumulative_precision, 323 | gt_flag, det_flag) 324 | 325 | #######then check for one-to-many case########## 326 | local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ 327 | gt_flag, det_flag = one_to_many(local_sigma_table, local_tau_table, 328 | local_accumulative_recall, local_accumulative_precision, 329 | global_accumulative_recall, global_accumulative_precision, 330 | gt_flag, det_flag) 331 | 332 | #######then check for many-to-many case########## 333 | local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ 334 | gt_flag, det_flag = many_to_many(local_sigma_table, local_tau_table, 335 | local_accumulative_recall, local_accumulative_precision, 336 | global_accumulative_recall, global_accumulative_precision, 337 | gt_flag, det_flag) 
338 | 339 | fid = open(fid_path, 'a+') 340 | try: 341 | local_precision = local_accumulative_precision / num_det 342 | except ZeroDivisionError: 343 | local_precision = 0 344 | 345 | try: 346 | local_recall = local_accumulative_recall / num_gt 347 | except ZeroDivisionError: 348 | local_recall = 0 349 | 350 | temp = ('%s______/Precision:_%s_______/Recall:_%s\n' % (allInputs[idx], str(local_precision), str(local_recall))) 351 | fid.write(temp) 352 | fid.close() 353 | try: 354 | recall = global_accumulative_recall / total_num_gt 355 | except ZeroDivisionError: 356 | recall = 0 357 | 358 | try: 359 | precision = global_accumulative_precision / total_num_det 360 | except ZeroDivisionError: 361 | precision = 0 362 | 363 | try: 364 | f_score = 2 * precision * recall / (precision + recall) 365 | except ZeroDivisionError: 366 | f_score = 0 367 | 368 | fid = open(fid_path, 'a') 369 | temp = ('Precision:_%s_______/Recall:_%s/Hmean:_%s\n' % (str(precision), str(recall), str(f_score))) 370 | print(temp) 371 | fid.write(temp) 372 | fid.close() -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is the main inference code. 
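Pipeline sketch (summarizing the code below): each test image is loaded at short size 640,
pushed through PAN under torch.no_grad(), and the 6-channel output map is upsampled back
to input resolution before get_results() turns it into text instances using the min_area
and min_score filters. Predictions are written one box per line as comma-separated
coordinates (x0,y0,x1,y1,...) under <output>/submit_ctw, the boxes are also drawn onto a
copy of each input image, and the average FPS over the test set is printed at the end.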
3 | ''' 4 | import os 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # set GPU id at the very beginning 6 | import argparse 7 | import random 8 | import math 9 | import numpy as np 10 | import torch 11 | import torch.nn.parallel 12 | import torch.optim as optim 13 | import torch.utils.data 14 | import torch.nn.functional as F 15 | from torch.multiprocessing import freeze_support 16 | import json 17 | import sys 18 | import time 19 | import pdb 20 | # internal package 21 | from dataset import testdataset 22 | from models.pan import PAN 23 | from utils.helper import get_results, write_result, draw_result, upsample 24 | 25 | # main function: 26 | if __name__ == '__main__': 27 | freeze_support() 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | '--worker', type=int, default=4, help='number of data loading workers') 31 | parser.add_argument('--input', type=str, default='', required=True, help='input folder name') 32 | parser.add_argument('--output', type=str, default='results', help='output folder name') 33 | parser.add_argument('--model', type=str, required=True, help='model path') 34 | parser.add_argument('--gpu', type=bool, default=False, help="GPU being used or not") # caution: argparse type=bool treats any non-empty string (even 'False') as True 35 | parser.add_argument('--bbox_type', type=str, default='poly', help="bounding box type - poly | rect") 36 | 37 | opt = parser.parse_args() 38 | print(opt) 39 | 40 | # turn on GPU for models: 41 | if opt.gpu == False: 42 | device = torch.device("cpu") 43 | print("CPU being used!") 44 | else: 45 | if torch.cuda.is_available() == True and opt.gpu == True: 46 | device = torch.device("cuda") 47 | print("GPU being used!") 48 | else: 49 | device = torch.device("cpu") 50 | print("CPU being used!") 51 | 52 | # set inference parameters 53 | batch_size = 1 54 | neck_channel = (64, 128, 256, 512) 55 | pa_in_channels = 512 56 | hidden_dim = 128 57 | num_classes = 6 58 | 59 | data_dirs = opt.input 60 | worker = opt.worker 61 | output_path = opt.output 62 | trained_model_path = opt.model 63 | bbox_type = opt.bbox_type 64 | min_area = 16 65 | min_score = 0.88 66 | 67 | # create dataset 68 | print("Create dataset......") 69 | test_dataset = testdataset.PAN_test(data_dirs, 640) 70 | 71 | # make dataloader 72 | test_dataloader = torch.utils.data.DataLoader( 73 | test_dataset, 74 | batch_size=1, 75 | shuffle=False, 76 | num_workers=int(worker)) 77 | 78 | print("Length of test dataset is:", len(test_dataset)) 79 | 80 | # make model prediction output folder 81 | try: 82 | os.makedirs(output_path) 83 | except OSError: 84 | pass 85 | 86 | # create model 87 | print("Create model......") 88 | model = PAN(pretrained=False, neck_channel=neck_channel, pa_in_channels=pa_in_channels, hidden_dim=hidden_dim, num_classes=num_classes) 89 | 90 | if trained_model_path != '': 91 | if torch.cuda.is_available() == True and opt.gpu == True: 92 | model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False) 93 | model = torch.nn.DataParallel(model).to(device) 94 | else: 95 | model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False) 96 | else: 97 | print("Error: Empty model path!") 98 | exit(1) 99 | 100 | # model inference 101 | print("Prediction on testset......") 102 | timer = [] 103 | model.eval() 104 | for idx, data in enumerate(test_dataloader): 105 | print('Testing %d/%d' % (idx, len(test_dataloader))) 106 | outputs = dict() 107 | # prepare input 108 | data['imgs'] = data['imgs'].to(device) 109 | # forward 110 | start = time.time() 111 | with 
torch.no_grad(): 112 | det_out = model(data['imgs']) 113 | det_out = upsample(det_out, data['imgs'].size(), 4) 114 | det_res = get_results(det_out, data['img_metas'], min_area, min_score, bbox_type) 115 | outputs.update(det_res) 116 | end = time.time() 117 | timer.append(end - start) 118 | 119 | # save result 120 | image_name, _ = os.path.splitext(os.path.basename(test_dataloader.dataset.img_paths[idx])) 121 | write_result(image_name, outputs, os.path.join(output_path, 'submit_ctw')) 122 | 123 | # draw and save images 124 | draw_result(test_dataloader.dataset.img_paths[idx], outputs, output_path) 125 | 126 | print("Average FPS:", 1/(sum(timer)/len(timer))) -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/loss/__init__.py -------------------------------------------------------------------------------- /loss/dice_loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to implement dice loss. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | 7 | class DiceLoss(nn.Module): 8 | def __init__(self, loss_weight=1.0): 9 | super(DiceLoss, self).__init__() 10 | self.loss_weight = loss_weight 11 | 12 | def forward(self, input, target, mask, reduce=True): 13 | batch_size = input.size(0) 14 | input = torch.sigmoid(input) 15 | 16 | input = input.contiguous().view(batch_size, -1) 17 | target = target.contiguous().view(batch_size, -1).float() 18 | mask = mask.contiguous().view(batch_size, -1).float() 19 | 20 | input = input * mask 21 | target = target * mask 22 | 23 | a = torch.sum(input * target, dim=1) 24 | b = torch.sum(input * input, dim=1) + 0.001 25 | c = torch.sum(target * target, dim=1) + 0.001 26 | d = (2 * a) / (b + c) 27 | loss = 1 - d 28 | 29 | loss = self.loss_weight * loss 30 | 31 | if reduce: 32 | loss = torch.mean(loss) 33 | 34 | return loss -------------------------------------------------------------------------------- /loss/emb_loss_v1.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to implement embedding loss. 
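A sketch of the three terms computed below (notation added for readability): with mu_i
the mean embedding of the kernel pixels of text instance i,
  L_agg = mean over instances i of mean over pixels p in i of log(relu(dist(e_p, mu_i) - delta_agg)^2 + 1)
  L_dis = mean over instance pairs (i, j) of log(relu(2 * delta_dis - dist(mu_i, mu_j))^2 + 1)
  L_reg = 0.001 * mean over instances of log(norm(mu_i) + 1)
where delta_agg = 0.5 (self.delta_v) and delta_dis = 1.5 (self.delta_d). L_agg pulls pixel
embeddings toward their own kernel mean, L_dis pushes the means of different text
instances apart, and L_reg keeps the means near the origin.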
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from torch.autograd import Function, Variable 9 | 10 | class EmbLoss_v1(nn.Module): 11 | def __init__(self, feature_dim=4, loss_weight=1.0): 12 | super(EmbLoss_v1, self).__init__() 13 | self.feature_dim = feature_dim 14 | self.loss_weight = loss_weight 15 | self.delta_v = 0.5 # delta_agg 16 | self.delta_d = 1.5 # delta_dis 17 | self.weights = (1.0, 1.0) 18 | 19 | def forward_single(self, emb, instance, kernel, training_mask, bboxes): 20 | training_mask = (training_mask > 0.5).long() 21 | kernel = (kernel > 0.5).long() 22 | instance = instance * training_mask 23 | instance_kernel = (instance * kernel).view(-1) 24 | instance = instance.view(-1) 25 | emb = emb.view(self.feature_dim, -1) 26 | 27 | unique_labels, unique_ids = torch.unique(instance_kernel, sorted=True, return_inverse=True) 28 | num_instance = unique_labels.size(0) 29 | if num_instance <= 1: 30 | return 0 31 | 32 | emb_mean = emb.new_zeros((self.feature_dim, num_instance), dtype=torch.float32) 33 | for i, lb in enumerate(unique_labels): 34 | if lb == 0: 35 | continue 36 | ind_k = instance_kernel == lb 37 | emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1) 38 | 39 | l_agg = emb.new_zeros(num_instance, dtype=torch.float32) # index 0 is the background label; it stays zero and is excluded by l_agg[1:] below 40 | for i, lb in enumerate(unique_labels): 41 | if lb == 0: 42 | continue 43 | ind = instance == lb 44 | emb_ = emb[:, ind] 45 | dist = (emb_ - emb_mean[:, i:i + 1]).norm(p=2, dim=0) 46 | dist = F.relu(dist - self.delta_v) ** 2 47 | l_agg[i] = torch.mean(torch.log(dist + 1.0)) 48 | l_agg = torch.mean(l_agg[1:]) 49 | 50 | if num_instance > 2: 51 | emb_interleave = emb_mean.permute(1, 0).repeat(num_instance, 1) 52 | emb_band = emb_mean.permute(1, 0).repeat(1, num_instance).view(-1, self.feature_dim) 53 | # print(emb_band) 54 | 55 | mask = (1 - torch.eye(num_instance, dtype=torch.int8)).view(-1, 1).repeat(1, self.feature_dim) 56 | mask = mask.view(num_instance, num_instance, -1) 57 | mask[0, :, :] = 0 58 | mask[:, 0, :] = 0 59 | mask = mask.view(num_instance * num_instance, -1) 60 | # print(mask) 61 | 62 | dist = emb_interleave - emb_band 63 | dist = dist[mask > 0].view(-1, self.feature_dim).norm(p=2, dim=1) 64 | dist = F.relu(2 * self.delta_d - dist) ** 2 65 | l_dis = torch.mean(torch.log(dist + 1.0)) 66 | else: 67 | l_dis = 0 68 | 69 | l_agg = self.weights[0] * l_agg 70 | l_dis = self.weights[1] * l_dis 71 | l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) + 1.0)) * 0.001 72 | loss = l_agg + l_dis + l_reg 73 | return loss 74 | 75 | def forward(self, emb, instance, kernel, training_mask, bboxes, reduce=True): 76 | # note: bboxes is unused by this loss and is kept only for interface compatibility 77 | loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32) 78 | 79 | for i in range(loss_batch.size(0)): 80 | loss_batch[i] = self.forward_single(emb[i], instance[i], kernel[i], training_mask[i], bboxes[i]) 81 | 82 | loss_batch = self.loss_weight * loss_batch 83 | 84 | if reduce: 85 | loss_batch = torch.mean(loss_batch) 86 | 87 | return loss_batch -------------------------------------------------------------------------------- /loss/iou.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is for IOU implementation. 
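A minimal usage sketch (tensors are illustrative):

    import torch
    a = torch.tensor([[1, 0], [1, 1]])   # predicted labels, batch of 2 samples
    b = torch.tensor([[1, 0], [0, 1]])   # ground-truth labels
    m = torch.ones(2, 2)                 # mask of valid pixels
    iou(a, b, m, n_class=2)              # per-class IoU averaged per sample, then
                                         # over the batch: (1.0 + 0.25) / 2 = 0.625

EPS keeps empty unions from dividing by zero.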
3 | ''' 4 | import torch 5 | 6 | EPS = 1e-6 7 | 8 | def iou_single(a, b, mask, n_class): 9 | valid = mask == 1 10 | a = a[valid] 11 | b = b[valid] 12 | miou = [] 13 | for i in range(n_class): 14 | inter = ((a == i) & (b == i)).float() 15 | union = ((a == i) | (b == i)).float() 16 | 17 | miou.append(torch.sum(inter) / (torch.sum(union) + EPS)) 18 | miou = sum(miou) / len(miou) 19 | return miou 20 | 21 | def iou(a, b, mask, n_class=2, reduce=True): 22 | batch_size = a.size(0) 23 | 24 | a = a.view(batch_size, -1) 25 | b = b.view(batch_size, -1) 26 | mask = mask.view(batch_size, -1) 27 | 28 | iou = a.new_zeros((batch_size,), dtype=torch.float32) 29 | for i in range(batch_size): 30 | iou[i] = iou_single(a[i], b[i], mask[i], n_class) 31 | 32 | if reduce: 33 | iou = torch.mean(iou) 34 | return iou -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This function is to compute total loss for training 3 | ''' 4 | import torch 5 | from .ohem import ohem_batch 6 | from .iou import iou 7 | from .dice_loss import DiceLoss 8 | from .emb_loss_v1 import EmbLoss_v1 9 | 10 | def text_loss(input, target, mask, reduce, loss_weight): 11 | loss = DiceLoss(loss_weight) 12 | return loss(input, target, mask, reduce) 13 | 14 | def kernel_loss(input, target, mask, reduce, loss_weight): 15 | loss = DiceLoss(loss_weight) 16 | return loss(input, target, mask, reduce) 17 | 18 | def emb_loss(emb, instance, kernel, training_mask, bboxes, reduce, loss_weight): 19 | loss = EmbLoss_v1(feature_dim=4, loss_weight=loss_weight) 20 | return loss(emb, instance, kernel, training_mask, bboxes, reduce) 21 | 22 | def loss(out, gt_texts, gt_kernels, training_masks, gt_instances, gt_bboxes, loss_text_weight, loss_kernel_weight, loss_emb_weight): 23 | # output 24 | texts = out[:, 0, :, :] 25 | kernels = out[:, 1:2, :, :] 26 | embs = out[:, 2:, :, :] 27 | 28 | # text loss 29 | selected_masks = ohem_batch(texts, gt_texts, training_masks) 30 | loss_text = text_loss(texts, gt_texts, selected_masks, False, loss_text_weight) 31 | iou_text = iou((texts > 0).long(), gt_texts, training_masks, reduce=False) 32 | losses = dict( 33 | loss_text=loss_text, 34 | iou_text=iou_text 35 | ) 36 | 37 | # kernel loss 38 | loss_kernels = [] 39 | selected_masks = gt_texts * training_masks 40 | for i in range(kernels.size(1)): 41 | kernel_i = kernels[:, i, :, :] 42 | gt_kernel_i = gt_kernels[:, i, :, :] 43 | loss_kernel_i = kernel_loss(kernel_i, gt_kernel_i, selected_masks, False, loss_kernel_weight) 44 | loss_kernels.append(loss_kernel_i) 45 | loss_kernels = torch.mean(torch.stack(loss_kernels, dim=1), dim=1) 46 | iou_kernel = iou( 47 | (kernels[:, -1, :, :] > 0).long(), gt_kernels[:, -1, :, :], training_masks * gt_texts, reduce=False) 48 | losses.update(dict( 49 | loss_kernels=loss_kernels, 50 | iou_kernel=iou_kernel 51 | )) 52 | 53 | # embedding loss 54 | loss_emb = emb_loss(embs, gt_instances, gt_kernels[:, -1, :, :], training_masks, gt_bboxes, False, loss_emb_weight) 55 | losses.update(dict( 56 | loss_emb=loss_emb 57 | )) 58 | 59 | return losses -------------------------------------------------------------------------------- /loss/ohem.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is for online hard example mining algorithm. 
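The selection rule in brief, with made-up numbers: given 100 positive text pixels and
10,000 negatives, the 3 * 100 = 300 negatives with the highest text score (the hardest
false positives) are kept and all other negatives are masked out of the dice loss; if an
image has no positive pixels at all, the unmodified training mask is returned instead.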
3 | ''' 4 | import torch 5 | 6 | def ohem_single(score, gt_text, training_mask): 7 | pos_num = int(torch.sum(gt_text > 0.5)) - int(torch.sum((gt_text > 0.5) & (training_mask <= 0.5))) 8 | 9 | if pos_num == 0: 10 | # selected_mask = gt_text.copy() * 0 # may be not good 11 | selected_mask = training_mask 12 | selected_mask = selected_mask.view(1, selected_mask.shape[0], selected_mask.shape[1]).float() 13 | return selected_mask 14 | 15 | neg_num = int(torch.sum(gt_text <= 0.5)) 16 | neg_num = int(min(pos_num * 3, neg_num)) 17 | 18 | if neg_num == 0: 19 | selected_mask = training_mask 20 | selected_mask = selected_mask.view(1, selected_mask.shape[0], selected_mask.shape[1]).float() 21 | return selected_mask 22 | 23 | neg_score = score[gt_text <= 0.5] 24 | neg_score_sorted, _ = torch.sort(-neg_score) 25 | threshold = -neg_score_sorted[neg_num - 1] 26 | 27 | selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5) 28 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).float() 29 | return selected_mask 30 | 31 | def ohem_batch(scores, gt_texts, training_masks): 32 | selected_masks = [] 33 | for i in range(scores.shape[0]): 34 | selected_masks.append(ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])) 35 | 36 | selected_masks = torch.cat(selected_masks, 0).float() 37 | return selected_masks -------------------------------------------------------------------------------- /misc/ctw_statistics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/misc/ctw_statistics.png -------------------------------------------------------------------------------- /misc/synthtext_statistics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/misc/synthtext_statistics.png -------------------------------------------------------------------------------- /misc/tt_statistics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/misc/tt_statistics.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/models/__init__.py -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to build backbone model by pretrained ResNet from ImageNet. 
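Note that the stem differs from torchvision: three 3x3 convolutions (conv1-conv3) replace
the usual single 7x7 convolution, matching the MIT Scene Parsing checkpoints listed in
model_urls, and forward() returns the four stage outputs at strides 4/8/16/32 instead of
classification logits. For example, a 3x640x640 input through resnet18 yields feature
maps of 64x160x160, 128x80x80, 256x40x40 and 512x20x20.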
3 | ''' 4 | import os 5 | import sys 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | try: 11 | from urllib import urlretrieve 12 | except ImportError: 13 | from urllib.request import urlretrieve 14 | 15 | __all__ = ['resnet18', 'resnet50', 'resnet101'] 16 | 17 | model_urls = { 18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 21 | } 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | "3x3 convolution with padding" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * 4) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | class Convkxk(nn.Module): 99 | def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0): 100 | super(Convkxk, self).__init__() 101 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 102 | bias=False) 103 | self.bn = nn.BatchNorm2d(out_planes) 104 | self.relu = nn.ReLU(inplace=True) 105 | 106 | for m in self.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 109 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 110 | elif isinstance(m, nn.BatchNorm2d): 111 | m.weight.data.fill_(1) 112 | m.bias.data.zero_() 113 | 114 | def forward(self, x): 115 | return self.relu(self.bn(self.conv(x))) 116 | 117 | class ResNet(nn.Module): 118 | 119 | def __init__(self, block, layers, num_classes=1000): 120 | super(ResNet, self).__init__() 121 | self.inplanes = 128 122 | self.conv1 = conv3x3(3, 64, stride=2) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.relu1 = nn.ReLU(inplace=True) 125 | self.conv2 = conv3x3(64, 64) 126 | self.bn2 = nn.BatchNorm2d(64) 127 | self.relu2 = nn.ReLU(inplace=True) 128 | self.conv3 = conv3x3(64, 128) 129 | self.bn3 = nn.BatchNorm2d(128) 130 | self.relu3 = nn.ReLU(inplace=True) 131 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 132 | 133 | self.layer1 = self._make_layer(block, 64, layers[0]) 134 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 135 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 136 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 137 | # self.avgpool = nn.AvgPool2d(7, stride=1) 138 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 139 | 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 143 | m.weight.data.normal_(0, math.sqrt(2. / n)) 144 | elif isinstance(m, nn.BatchNorm2d): 145 | m.weight.data.fill_(1) 146 | m.bias.data.zero_() 147 | 148 | def _make_layer(self, block, planes, blocks, stride=1): 149 | downsample = None 150 | if stride != 1 or self.inplanes != planes * block.expansion: 151 | downsample = nn.Sequential( 152 | nn.Conv2d(self.inplanes, planes * block.expansion, 153 | kernel_size=1, stride=stride, bias=False), 154 | nn.BatchNorm2d(planes * block.expansion), 155 | ) 156 | 157 | layers = [] 158 | layers.append(block(self.inplanes, planes, stride, downsample)) 159 | self.inplanes = planes * block.expansion 160 | for i in range(1, blocks): 161 | layers.append(block(self.inplanes, planes)) 162 | 163 | return nn.Sequential(*layers) 164 | 165 | def forward(self, x): 166 | x = self.relu1(self.bn1(self.conv1(x))) 167 | x = self.relu2(self.bn2(self.conv2(x))) 168 | x = self.relu3(self.bn3(self.conv3(x))) 169 | x = self.maxpool(x) 170 | 171 | f = [] 172 | x = self.layer1(x) 173 | f.append(x) 174 | x = self.layer2(x) 175 | f.append(x) 176 | x = self.layer3(x) 177 | f.append(x) 178 | x = self.layer4(x) 179 | f.append(x) 180 | 181 | return tuple(f) 182 | 183 | # x = self.avgpool(x) 184 | # x = x.view(x.size(0), -1) 185 | # x = self.fc(x) 186 | 187 | # return x 188 | 189 | def resnet18(pretrained=False, **kwargs): 190 | """Constructs a ResNet-18 model. 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on Places 193 | """ 194 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 195 | if pretrained: 196 | model.load_state_dict(load_url(model_urls['resnet18']), strict=False) 197 | return model 198 | 199 | 200 | def resnet50(pretrained=False, **kwargs): 201 | """Constructs a ResNet-50 model. 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on Places 204 | """ 205 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 208 | return model 209 | 210 | 211 | def resnet101(pretrained=False, **kwargs): 212 | """Constructs a ResNet-101 model. 
213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on Places 215 | """ 216 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 217 | if pretrained: 218 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 219 | return model 220 | 221 | 222 | def load_url(url, model_dir='./pretrained', map_location=None): 223 | if not os.path.exists(model_dir): 224 | os.makedirs(model_dir) 225 | filename = url.split('/')[-1] 226 | cached_file = os.path.join(model_dir, filename) 227 | if not os.path.exists(cached_file): 228 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 229 | urlretrieve(url, cached_file) 230 | return torch.load(cached_file, map_location=lambda storage, loc: storage) 231 | 232 | # unit test 233 | if __name__ == '__main__': 234 | batch_size = 32 235 | Height = 48 236 | Width = 160 237 | Channel = 3 238 | 239 | input_images = torch.randn(batch_size,Channel,Height,Width) 240 | model = resnet18(pretrained=False) 241 | output_features = model(input_images) 242 | 243 | print("Input size is:",input_images.shape) 244 | print("Output feature map size is:", len(output_features)) 245 | for layer in range(len(output_features)): 246 | print("Shape of layer {} is {}".format(layer, output_features[layer].shape)) -------------------------------------------------------------------------------- /models/ffm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is for FFM model in PAN. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | __all__ = ['FFM'] 9 | 10 | class FFM(nn.Module): 11 | def __init__(self): 12 | super(FFM, self).__init__() 13 | 14 | def _upsample(self, x, size, scale=1): 15 | _, _, H, W = size 16 | return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear') 17 | 18 | def forward(self, f1_1, f2_1, f3_1, f4_1, f1_2, f2_2, f3_2, f4_2): 19 | f1 = f1_1 + f1_2 20 | f2 = f2_1 + f2_2 21 | f3 = f3_1 + f3_2 22 | f4 = f4_1 + f4_2 23 | f2 = self._upsample(f2, f1.size()) 24 | f3 = self._upsample(f3, f1.size()) 25 | f4 = self._upsample(f4, f1.size()) 26 | f = torch.cat((f1, f2, f3, f4), 1) 27 | 28 | return f 29 | 30 | # unit testing 31 | if __name__ == '__main__': 32 | batch_size = 32 33 | Height = 512 34 | Width = 768 35 | Channel = 128 36 | f1_1 = torch.randn(batch_size,Channel,Height//4,Width//4) 37 | f2_1 = torch.randn(batch_size,Channel,Height//8,Width//8) 38 | f3_1 = torch.randn(batch_size,Channel,Height//16,Width//16) 39 | f4_1 = torch.randn(batch_size,Channel,Height//32,Width//32) 40 | 41 | f1_2 = torch.randn(batch_size,Channel,Height//4,Width//4) 42 | f2_2 = torch.randn(batch_size,Channel,Height//8,Width//8) 43 | f3_2 = torch.randn(batch_size,Channel,Height//16,Width//16) 44 | f4_2 = torch.randn(batch_size,Channel,Height//32,Width//32) 45 | 46 | ffm_model = FFM() 47 | f = ffm_model(f1_1, f2_1, f3_1, f4_1, f1_2, f2_2, f3_2, f4_2) 48 | print("FFM input layer 1 shape:", f1_1.shape) 49 | print("FFM input layer 2 shape:", f2_1.shape) 50 | print("FFM input layer 3 shape:", f3_1.shape) 51 | print("FFM input layer 4 shape:", f4_1.shape) 52 | print("FFM output shape:", f.shape) -------------------------------------------------------------------------------- /models/fpem.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is the FPEM module for PAN. 
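Each FPEM pass refines the four pyramid levels twice (a summary of the code below): an
up-scale enhancement walks from the smallest map f4 up to f1 with upsample-and-add
followed by a separable convolution (depthwise 3x3 then pointwise 1x1 via Conv_BN_ReLU),
and a down-scale enhancement walks back down using stride-2 depthwise convolutions.
FPEMs are cheap because of the depthwise separable design, so PAN stacks two of them
(see models/pan.py).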
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import torch.nn.functional as F 8 | 9 | __all__ = ['Conv_BN_ReLU','FPEM'] 10 | 11 | class Conv_BN_ReLU(nn.Module): 12 | def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0): 13 | super(Conv_BN_ReLU, self).__init__() 14 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 15 | bias=False) 16 | self.bn = nn.BatchNorm2d(out_planes) 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. / n)) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | def forward(self, x): 28 | return self.relu(self.bn(self.conv(x))) 29 | 30 | class FPEM(nn.Module): 31 | def __init__(self, in_channels, out_channels): 32 | super(FPEM, self).__init__() 33 | planes = out_channels 34 | self.dwconv3_1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, groups=planes, bias=False) 35 | self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes) 36 | 37 | self.dwconv2_1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, groups=planes, bias=False) 38 | self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes) 39 | 40 | self.dwconv1_1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, groups=planes, bias=False) 41 | self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes) 42 | 43 | self.dwconv2_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1, groups=planes, bias=False) 44 | self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes) 45 | 46 | self.dwconv3_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1, groups=planes, bias=False) 47 | self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes) 48 | 49 | self.dwconv4_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1, groups=planes, bias=False) 50 | self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes) 51 | 52 | def _upsample_add(self, x, y): 53 | _, _, H, W = y.size() 54 | return F.interpolate(x, size=(H, W), mode='bilinear') + y 55 | 56 | def forward(self, f1, f2, f3, f4): 57 | f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3))) 58 | f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2))) 59 | f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1))) 60 | 61 | f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1))) 62 | f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2))) 63 | f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3))) 64 | 65 | return f1, f2, f3, f4 66 | 67 | # unit testing 68 | if __name__ == '__main__': 69 | 70 | batch_size = 32 71 | Height = 512 72 | Width = 512 73 | Channel = 128 74 | 75 | f1 = torch.randn(batch_size,Channel,Height//4,Width//4) 76 | f2 = torch.randn(batch_size,Channel,Height//8,Width//8) 77 | f3 = torch.randn(batch_size,Channel,Height//16,Width//16) 78 | f4 = torch.randn(batch_size,Channel,Height//32,Width//32) 79 | print("Input of FPEM layer 1:", f1.shape) 80 | print("Input of FPEM layer 2:", f2.shape) 81 | print("Input of FPEM layer 3:", f3.shape) 82 | print("Input of FPEM layer 4:", f4.shape) 83 | 84 | fpem_model = FPEM(Channel, Channel) 85 | 86 | f1, f2, f3, f4 = fpem_model(f1, f2, f3, f4) 87 | print("Output of FPEM layer 1:", f1.shape) 88 | print("Output of FPEM layer 2:", f2.shape) 89 | print("Output of FPEM layer 3:", f3.shape) 90 | print("Output of 
FPEM layer 4:", f4.shape) -------------------------------------------------------------------------------- /models/head.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is for head detection for PAN. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | __all__ = ['PA_Head'] 9 | 10 | class PA_Head(nn.Module): 11 | def __init__(self, in_channels, hidden_dim, num_classes): 12 | super(PA_Head, self).__init__() 13 | self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) 14 | self.bn1 = nn.BatchNorm2d(hidden_dim) 15 | self.relu1 = nn.ReLU(inplace=True) 16 | 17 | self.conv2 = nn.Conv2d(hidden_dim, num_classes, kernel_size=1, stride=1, padding=0) 18 | 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. / n)) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | def forward(self, f): 28 | out = self.conv1(f) 29 | out = self.relu1(self.bn1(out)) 30 | out = self.conv2(out) 31 | 32 | return out -------------------------------------------------------------------------------- /models/pan.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is the integrted model for PAN. 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .backbone import resnet18 10 | from .fpem import FPEM, Conv_BN_ReLU 11 | from .ffm import FFM 12 | from .head import PA_Head 13 | 14 | __all__ = ['PAN'] 15 | 16 | class PAN(nn.Module): 17 | def __init__(self, pretrained, neck_channel, pa_in_channels, hidden_dim, num_classes): 18 | super(PAN, self).__init__() 19 | self.backbone = resnet18(pretrained=pretrained) 20 | in_channels = neck_channel 21 | self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128) 22 | self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128) 23 | self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128) 24 | self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128) 25 | 26 | self.fpem1 = FPEM(128, 128) 27 | self.fpem2 = FPEM(128, 128) 28 | 29 | self.ffm = FFM() 30 | 31 | self.det_head = PA_Head(pa_in_channels, hidden_dim, num_classes) 32 | 33 | def _upsample(self, x, size, scale=1): 34 | _, _, H, W = size 35 | return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear') 36 | 37 | def forward(self, imgs): 38 | # backbone 39 | f = self.backbone(imgs) 40 | 41 | # reduce channel 42 | f1 = self.reduce_layer1(f[0]) 43 | f2 = self.reduce_layer2(f[1]) 44 | f3 = self.reduce_layer3(f[2]) 45 | f4 = self.reduce_layer4(f[3]) 46 | 47 | # FPEM 48 | f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4) 49 | f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1) 50 | 51 | # FFM 52 | f = self.ffm(f1_1, f2_1, f3_1, f4_1, f1_2, f2_2, f3_2, f4_2) 53 | 54 | # detection 55 | det_out = self.det_head(f) 56 | 57 | return det_out 58 | 59 | # unit testing 60 | if __name__ == '__main__': 61 | 62 | batch_size = 32 63 | Height = 32 64 | Width = 64 65 | neck_channel = [64, 128, 256, 512] 66 | pa_in_channels = 512 67 | hidden_dim = 128 68 | Channel = 3 69 | 70 | input_images = torch.randn(batch_size,Channel,Height,Width) 71 | 72 | model = PAN(pretrained=False, neck_channel=neck_channel, pa_in_channels=pa_in_channels, hidden_dim=hidden_dim, num_classes=6) 73 | 74 | det_out = model(input_images) 75 | print("PAN output size is:", det_out.shape) 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | editdistance
2 | opencv-python
3 | torch
4 | torchvision
5 | pyclipper
6 | Cython # optional, speeds up postprocessing
7 | Pillow # provides the PIL module; the legacy "PIL" package is not pip-installable
8 | scipy # for performance evaluation
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | This is the main training code.
3 | '''
4 | import os
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # set GPU id at the very beginning
6 | import argparse
7 | import random
8 | import math
9 | import numpy as np
10 | import torch
11 | import torch.nn.parallel
12 | import torch.optim as optim
13 | import torch.utils.data
14 | import torch.nn.functional as F
15 | from torch.multiprocessing import freeze_support
16 | import json
17 | import sys
18 | import time
19 | import pdb
20 | # internal package
21 | from dataset import ctw1500, totaltext, synthtext, msra, ic15
22 | from models.pan import PAN
23 | from loss.loss import loss
24 | from utils.helper import adjust_learning_rate, upsample
25 | from utils.average_meter import AverageMeter
26 | 
27 | # main function:
28 | if __name__ == '__main__':
29 |     freeze_support()
30 |     parser = argparse.ArgumentParser()
31 |     parser.add_argument(
32 |         '--batch', type=int, default=16, help='input batch size')
33 |     parser.add_argument(
34 |         '--worker', type=int, default=4, help='number of data loading workers')
35 |     parser.add_argument(
36 |         '--epoch', type=int, default=601, help='number of epochs')
37 |     parser.add_argument('--output', type=str, default='outputs', help='output folder name')
38 |     parser.add_argument('--model', type=str, default='', help='model path')
39 |     parser.add_argument('--dataset_type', type=str, default='ctw', help="dataset type - ctw | tt | synthtext | msra | ic15")
40 |     parser.add_argument('--gpu', type=lambda v: str(v).lower() == 'true', default=False, help="GPU being used or not") # note: plain type=bool would treat any non-empty string, even 'False', as True
41 | 
42 |     opt = parser.parse_args()
43 |     print(opt)
44 | 
45 |     opt.manualSeed = random.randint(1, 10000) # fix seed
46 |     print("Random Seed:", opt.manualSeed)
47 |     random.seed(opt.manualSeed)
48 |     torch.manual_seed(opt.manualSeed)
49 |     torch.cuda.manual_seed(opt.manualSeed)
50 |     np.random.seed(opt.manualSeed)
51 | 
52 |     # turn on GPU for models:
53 |     if not opt.gpu:
54 |         device = torch.device("cpu")
55 |         print("CPU being used!")
56 |     else:
57 |         if torch.cuda.is_available():
58 |             device = torch.device("cuda")
59 |             print("GPU being used!")
60 |         else:
61 |             device = torch.device("cpu")
62 |             print("CPU being used!")
63 | 
64 |     # set training parameters
65 |     batch_size = opt.batch
66 |     neck_channel = (64, 128, 256, 512)
67 |     pa_in_channels = 512
68 |     hidden_dim = 128
69 |     num_classes = 6
70 |     loss_text_weight = 1.0
71 |     loss_kernel_weight = 0.5
72 |     loss_emb_weight = 0.25
73 |     opt.optimizer = 'Adam'
74 |     opt.lr = 1e-3
75 |     opt.schedule = 'polylr'
76 | 
77 |     epochs = opt.epoch
78 |     worker = opt.worker
79 |     dataset_type = opt.dataset_type
80 |     output_path = opt.output
81 |     trained_model_path = opt.model
82 | 
83 |     # create dataset
84 |     print("Create dataset......")
85 |     if dataset_type == 'ctw': # ctw dataset
86 |         train_dataset = ctw1500.PAN_CTW(split='train',
87 |                                         is_transform=True,
88 |                                         img_size=640,
89 |                                         short_size=640,
90 |                                         kernel_scale=0.7,
91 |                                         report_speed=False)
92 |     elif dataset_type == 'tt': # totaltext dataset
93 |         train_dataset = totaltext.PAN_TT(split='train',
94 |                                          is_transform=True,
95 |                                          img_size=640,
96 |                                          short_size=640,
97 |                                          kernel_scale=0.7,
98 |                                          with_rec=False,
99 |                                          report_speed=False)
100 |     elif dataset_type == 'synthtext': # synthtext dataset
101 |         train_dataset = synthtext.PAN_Synth(is_transform=True,
102 |                                             img_size=640,
103 |                                             short_size=640,
104 |                                             kernel_scale=0.5,
105 |                                             with_rec=False)
106 |     elif dataset_type == 'msra': # msra dataset
107 |         train_dataset = msra.PAN_MSRA(split='train',
108 |                                       is_transform=True,
109 |                                       img_size=736,
110 |                                       short_size=736,
111 |                                       kernel_scale=0.7,
112 |                                       report_speed=False)
113 |     elif dataset_type == 'ic15': # icdar-2015 dataset
114 |         train_dataset = ic15.PAN_IC15(split='train',
115 |                                       is_transform=True,
116 |                                       img_size=736,
117 |                                       short_size=736,
118 |                                       kernel_scale=0.5,
119 |                                       with_rec=False)
120 |     else:
121 |         print("Not supported yet!")
122 |         exit(1)
123 | 
124 |     # make dataloader
125 |     train_dataloader = torch.utils.data.DataLoader(
126 |         train_dataset,
127 |         batch_size=batch_size,
128 |         shuffle=True,
129 |         num_workers=int(worker),
130 |         drop_last=True,
131 |         pin_memory=True)
132 | 
133 |     print("Length of train dataset is:", len(train_dataset))
134 | 
135 |     # make model output folder
136 |     try:
137 |         os.makedirs(output_path)
138 |     except OSError:
139 |         pass
140 | 
141 |     # create model
142 |     print("Create model......")
143 |     model = PAN(pretrained=False, neck_channel=neck_channel, pa_in_channels=pa_in_channels, hidden_dim=hidden_dim, num_classes=num_classes)
144 | 
145 |     if trained_model_path != '':
146 |         if opt.gpu and torch.cuda.is_available():
147 |             model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False)
148 |             model = torch.nn.DataParallel(model).to(device)
149 |         else:
150 |             model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False)
151 |     else:
152 |         if opt.gpu and torch.cuda.is_available():
153 |             model = torch.nn.DataParallel(model).to(device)
154 |         else:
155 |             model = model.to(device)
156 | 
157 |     if opt.optimizer == 'SGD':
158 |         optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.99, weight_decay=5e-4)
159 |     elif opt.optimizer == 'Adam':
160 |         optimizer = optim.Adam(model.parameters(), lr=opt.lr)
161 |     else:
162 |         print("Error: Please specify correct optimizer!")
163 |         exit(1)
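An aside before the training loop (not part of train.py): the 'polylr' choice above means adjust_learning_rate in utils/helper.py decays the learning rate once per iteration as lr(t) = base_lr * (1 - t/T)^0.9, where t is the global iteration index and T = num_epoch * len(dataloader). A tiny illustration, with an assumed T:

```python
# Illustrative poly decay (matches the 'polylr' branch in utils/helper.py).
base_lr, T = 1e-3, 10000   # T assumed for illustration
for t in (0, 5000, 9999):
    print(t, base_lr * (1 - t / T) ** 0.9)
# -> 1.0e-03, ~5.4e-04, ~2.5e-07
```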
164 | 
165 |     # train, evaluate, and save model
166 |     print("Training starts......")
167 | 
168 |     start_epoch = 0
169 | 
170 |     for epoch in range(start_epoch, epochs):
171 |         print('Epoch: [%d | %d]' % (epoch + 1, epochs))
172 |         model.train()
173 | 
174 |         # meters
175 |         losses = AverageMeter()
176 |         losses_text = AverageMeter()
177 |         losses_kernels = AverageMeter()
178 |         losses_emb = AverageMeter()
179 |         losses_rec = AverageMeter()
180 |         ious_text = AverageMeter()
181 |         ious_kernel = AverageMeter()
182 | 
183 |         for iter, data in enumerate(train_dataloader):
184 | 
185 |             # adjust learning rate
186 |             adjust_learning_rate(optimizer, train_dataloader, epoch, iter, opt.schedule, opt.lr, epochs)
187 | 
188 |             outputs = dict()
189 |             # forward for detection output
190 |             det_out = model(data['imgs'].to(device))
191 |             det_out = upsample(det_out, data['imgs'].size())
192 |             # retrieve ground truth labels
193 |             gt_texts = data['gt_texts'].to(device)
194 |             gt_kernels = data['gt_kernels'].to(device)
195 |             training_masks = data['training_masks'].to(device)
196 |             gt_instances = data['gt_instances'].to(device)
197 |             gt_bboxes = data['gt_bboxes'].to(device)
198 |             # calculate total loss
199 |             det_loss = loss(det_out, gt_texts, gt_kernels, training_masks, gt_instances, gt_bboxes, loss_text_weight, loss_kernel_weight, loss_emb_weight)
200 |             outputs.update(det_loss)
201 | 
202 |             # detection loss
203 |             loss_text = torch.mean(outputs['loss_text'])
204 |             losses_text.update(loss_text.item())
205 | 
206 |             loss_kernels = torch.mean(outputs['loss_kernels'])
207 |             losses_kernels.update(loss_kernels.item())
208 | 
209 |             loss_emb = torch.mean(outputs['loss_emb'])
210 |             losses_emb.update(loss_emb.item())
211 | 
212 |             loss_total = loss_text + loss_kernels + loss_emb
213 | 
214 |             iou_text = torch.mean(outputs['iou_text'])
215 |             ious_text.update(iou_text.item())
216 |             iou_kernel = torch.mean(outputs['iou_kernel'])
217 |             ious_kernel.update(iou_kernel.item())
218 | 
219 |             losses.update(loss_total.item())
220 | 
221 |             # backward
222 |             optimizer.zero_grad()
223 |             loss_total.backward()
224 |             optimizer.step()
225 | 
226 |             # print log
227 |             #print("batch: {} / total batch: {}".format(iter+1, len(train_dataloader)))
228 |             if iter % 20 == 0:
229 |                 output_log = '({batch}/{size}) LR: {lr:.6f} | ' \
230 |                              'Loss: {loss:.3f} | ' \
231 |                              'Loss (text/kernel/emb): {loss_text:.3f}/{loss_kernel:.3f}/{loss_emb:.3f} ' \
232 |                              '| IoU (text/kernel): {iou_text:.3f}/{iou_kernel:.3f}'.format(
233 |                     batch=iter + 1,
234 |                     size=len(train_dataloader),
235 |                     lr=optimizer.param_groups[0]['lr'],
236 |                     loss_text=losses_text.avg,
237 |                     loss_kernel=losses_kernels.avg,
238 |                     loss_emb=losses_emb.avg,
239 |                     loss=losses.avg,
240 |                     iou_text=ious_text.avg,
241 |                     iou_kernel=ious_kernel.avg,
242 |                 )
243 |                 print(output_log)
244 |                 sys.stdout.flush()
245 |                 with open(os.path.join(output_path,'statistics.txt'), 'a') as f:
246 |                     f.write("{} {} {} {} {} {}\n".format(losses_text.avg, losses_kernels.avg, losses_emb.avg, losses.avg, ious_text.avg, ious_kernel.avg))
247 | 
248 |         if epoch % 20 == 0:
249 |             print("Save model......")
250 |             if opt.gpu and torch.cuda.is_available():
251 |                 torch.save(model.module.state_dict(), '%s/model_epoch_%s.pth' % (output_path, str(epoch)))
252 |             else:
253 |                 torch.save(model.state_dict(), '%s/model_epoch_%s.pth' % (output_path, str(epoch)))
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/utils/__init__.py
--------------------------------------------------------------------------------
/utils/average_meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 |     """Computes and stores the average and current value"""
3 |     def __init__(self, max_len=-1):
4 |         self.val = []
5 |         self.count = []
6 |         self.max_len = max_len
7 |         self.avg = 0
8 | 
9 |     def update(self, val, n=1):
10 |         self.val.append(val * n)
11 |         self.count.append(n)
12 |         if self.max_len > 0 and len(self.val) > self.max_len:
13 |             self.val = self.val[-self.max_len:]
14 |             self.count = self.count[-self.max_len:]
15 |         self.avg = sum(self.val) / sum(self.count)
--------------------------------------------------------------------------------
/utils/helper.py:
--------------------------------------------------------------------------------
1 | '''
2 | Helper functions.
3 | ''' 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import os 9 | import cv2 10 | from .pa.pa import pa 11 | import pdb 12 | import zipfile 13 | 14 | def upsample(x, size, scale=1): 15 | _, _, H, W = size 16 | return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear') 17 | 18 | def adjust_learning_rate(optimizer, dataloader, epoch, iter, schedule, lr, num_epoch): 19 | if isinstance(schedule, str): 20 | assert schedule == 'polylr', 'Error: schedule should be polylr!' 21 | cur_iter = epoch * len(dataloader) + iter 22 | max_iter_num = num_epoch * len(dataloader) 23 | lr = lr * (1 - float(cur_iter) / max_iter_num) ** 0.9 24 | elif isinstance(schedule, tuple): 25 | for i in range(len(schedule)): 26 | if epoch < schedule[i]: 27 | break 28 | lr = lr * 0.1 29 | 30 | for param_group in optimizer.param_groups: 31 | param_group['lr'] = lr 32 | 33 | def get_results(out, img_meta, min_area, min_score, bbox_type): 34 | outputs = dict() 35 | 36 | score = torch.sigmoid(out[:, 0, :, :]) 37 | kernels = out[:, :2, :, :] > 0 38 | text_mask = kernels[:, :1, :, :] 39 | kernels[:, 1:, :, :] = kernels[:, 1:, :, :] * text_mask 40 | emb = out[:, 2:, :, :] 41 | emb = emb * text_mask.float() 42 | 43 | score = score.data.cpu().numpy()[0].astype(np.float32) 44 | kernels = kernels.data.cpu().numpy()[0].astype(np.uint8) 45 | emb = emb.cpu().numpy()[0].astype(np.float32) 46 | 47 | # pa 48 | label = pa(kernels, emb) 49 | 50 | # image size 51 | org_img_size = img_meta['org_img_size'][0] 52 | img_size = img_meta['img_size'][0] 53 | 54 | label_num = np.max(label) + 1 55 | label = cv2.resize(label, (img_size[1], img_size[0]), interpolation=cv2.INTER_NEAREST) 56 | score = cv2.resize(score, (img_size[1], img_size[0]), interpolation=cv2.INTER_NEAREST) 57 | 58 | scale = (float(org_img_size[1]) / float(img_size[1]), 59 | float(org_img_size[0]) / float(img_size[0])) 60 | 61 | bboxes = [] 62 | scores = [] 63 | for i in range(1, label_num): 64 | ind = label == i 65 | points = np.array(np.where(ind)).transpose((1, 0)) 66 | 67 | if points.shape[0] < min_area: 68 | label[ind] = 0 69 | continue 70 | 71 | score_i = np.mean(score[ind]) 72 | if score_i < min_score: 73 | label[ind] = 0 74 | continue 75 | 76 | if bbox_type == 'rect': 77 | rect = cv2.minAreaRect(points[:, ::-1]) 78 | bbox = cv2.boxPoints(rect) * scale 79 | elif bbox_type == 'poly': 80 | binary = np.zeros(label.shape, dtype='uint8') 81 | binary[ind] = 1 82 | contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # bug in official released code 83 | bbox = contours[0] * scale 84 | 85 | bbox = bbox.astype('int32') 86 | bboxes.append(bbox.reshape(-1)) 87 | scores.append(score_i) 88 | 89 | outputs.update(dict( 90 | bboxes=bboxes, 91 | scores=scores 92 | )) 93 | 94 | return outputs 95 | 96 | def write_result(image_name, outputs, result_path): 97 | bboxes = outputs['bboxes'] 98 | 99 | lines = [] 100 | for i, bbox in enumerate(bboxes): 101 | #bbox = bbox.reshape(-1, 2)[:, ::-1].reshape(-1) 102 | bbox = bbox.reshape(-1, 2).reshape(-1) # fix write output format in (x,y) order 103 | values = [int(v) for v in bbox] 104 | line = "%d" % values[0] 105 | for v_id in range(1, len(values)): 106 | line += ",%d" % values[v_id] 107 | line += '\n' 108 | lines.append(line) 109 | 110 | file_name = '%s.txt' % image_name 111 | file_path = os.path.join(result_path, file_name) 112 | with open(file_path, 'w') as f: 113 | for line in lines: 114 | f.write(line) 115 | 116 | def draw_result(image_path, outputs, output_path): 117 | 
    image_name, _ = os.path.splitext(os.path.basename(image_path))
118 |     num_contour = len(outputs['bboxes'])
119 |     contours = []
120 |     for i in range(num_contour):
121 |         contour = outputs['bboxes'][i]
122 |         num_pair = len(contour) // 2
123 |         contour = contour.reshape((num_pair, 2))
124 |         contours.append(contour)
125 |     contours = np.asarray(contours)
126 |     img = cv2.imread(image_path)
127 |     img = cv2.drawContours(img, contours, -1, (0,255,0), 2)
128 |     cv2.imwrite(os.path.join(output_path, image_name+'.png'), img)
129 | 
130 | def write_result_ic15(img_name, outputs, result_path):
131 |     assert result_path.endswith('.zip'), 'Error: ic15 result should be a zip file!'
132 | 
133 |     tmp_folder = result_path.replace('.zip', '')
134 |     os.makedirs(tmp_folder, exist_ok=True) # make sure the tmp folder exists before writing into it
135 |     bboxes = outputs['bboxes']
136 | 
137 |     lines = []
138 |     for i, bbox in enumerate(bboxes):
139 |         values = [int(v) for v in bbox]
140 |         line = "%d,%d,%d,%d,%d,%d,%d,%d\n" % tuple(values)
141 |         lines.append(line)
142 | 
143 |     file_name = 'res_%s.txt' % img_name
144 |     file_path = os.path.join(tmp_folder, file_name)
145 |     with open(file_path, 'w') as f:
146 |         for line in lines:
147 |             f.write(line)
148 | 
149 |     z = zipfile.ZipFile(result_path, 'a', zipfile.ZIP_DEFLATED)
150 |     z.write(file_path, file_name)
151 |     z.close()
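Putting the helpers above together, a hedged single-image sketch (not part of the repo): model, img_tensor, org_h/org_w, and H/W are assumed to exist, and min_area/min_score are illustrative thresholds, not the repo's defaults.

```python
import torch
import numpy as np
from utils.helper import upsample, get_results, write_result

model.eval()  # a trained PAN model, assumed loaded
with torch.no_grad():
    det_out = model(img_tensor)                     # img_tensor: (1, 3, H, W) resized input
    det_out = upsample(det_out, img_tensor.size())  # back to network input resolution

img_meta = {
    'org_img_size': [np.array([org_h, org_w])],  # image size before resizing
    'img_size': [np.array([H, W])],              # size fed to the network
}
outputs = get_results(det_out, img_meta, min_area=16, min_score=0.88, bbox_type='poly')
write_result('img_001', outputs, './outputs')    # writes ./outputs/img_001.txt
```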
--------------------------------------------------------------------------------
/utils/pa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuch37/pan-pytorch/e08ebcfa7568a47f8fcec48b302380749ef3776d/utils/pa/__init__.py
--------------------------------------------------------------------------------
/utils/pa/pa.py:
--------------------------------------------------------------------------------
1 | '''
2 | This is an unofficial implementation of the pixel aggregation function in pure Python, modified from pa.pyx
3 | '''
4 | 
5 | import numpy as np
6 | import cv2
7 | import pdb
8 | 
9 | def _pa(kernels, emb, label, cc, kernel_num, label_num, min_area=0):
10 |     pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
11 |     mean_emb = np.zeros((label_num, 4), dtype=np.float32)
12 |     area = np.full((label_num,), -1, dtype=np.float32)
13 |     flag = np.zeros((label_num,), dtype=np.int32)
14 |     inds = np.zeros((label_num, label.shape[0], label.shape[1]), dtype=np.uint8)
15 |     p = np.zeros((label_num, 2), dtype=np.int32)
16 | 
17 |     max_rate = 1024
18 |     for i in range(1, label_num):
19 |         ind = label == i
20 |         inds[i] = ind
21 | 
22 |         area[i] = np.sum(ind)
23 | 
24 |         if area[i] < min_area:
25 |             label[ind] = 0
26 |             continue
27 | 
28 |         px, py = np.where(ind)
29 |         p[i] = (px[0], py[0])
30 | 
31 |         for j in range(1, i):
32 |             if area[j] < min_area:
33 |                 continue
34 |             if cc[p[i, 0], p[i, 1]] != cc[p[j, 0], p[j, 1]]:
35 |                 continue
36 |             rate = area[i] / area[j]
37 |             if rate < 1 / max_rate or rate > max_rate:
38 |                 flag[i] = 1
39 |                 mean_emb[i] = np.mean(emb[:, ind], axis=1)
40 | 
41 |                 if flag[j] == 0:
42 |                     flag[j] = 1
43 |                     mean_emb[j] = np.mean(emb[:, inds[j].astype(bool)], axis=1) # np.bool was removed in NumPy >= 1.24; use the builtin bool
44 | 
45 |     que = []
46 |     nxt_que = []
47 |     dx = [-1, 1, 0, 0]
48 |     dy = [0, 0, -1, 1]
49 | 
50 |     points = np.array(np.where(label > 0)).transpose((1, 0))
51 |     for point_idx in range(points.shape[0]):
52 |         tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
53 |         que.append((tmpx, tmpy))
54 |         pred[tmpx, tmpy] = label[tmpx, tmpy]
55 | 
56 |     for kernel_idx in range(kernel_num - 2, -1, -1):
57 |         while que:
58 |             cur = que[0]
59 |             que.pop(0) # note: collections.deque with popleft() would make this O(1)
60 |             cur_label = pred[cur[0], cur[1]]
61 | 
62 |             is_edge = True
63 |             for j in range(4):
64 |                 tmpx = cur[0] + dx[j]
65 |                 tmpy = cur[1] + dy[j]
66 |                 if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
67 |                     continue
68 |                 if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
69 |                     continue
70 |                 if flag[cur_label] == 1 and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3:
71 |                     continue
72 | 
73 |                 que.append((tmpx, tmpy))
74 |                 pred[tmpx, tmpy] = cur_label
75 |                 is_edge = False
76 |             if is_edge:
77 |                 nxt_que.append(cur)
78 | 
79 |         que, nxt_que = nxt_que, que
80 | 
81 |     return pred
82 | 
83 | def pa(kernels, emb, min_area=0):
84 |     kernel_num = kernels.shape[0]
85 |     _, cc = cv2.connectedComponents(kernels[0], connectivity=4) # text region connected components
86 |     label_num, label = cv2.connectedComponents(kernels[1], connectivity=4) # kernel region connected components
87 |     return _pa(kernels[:-1], emb, label, cc, kernel_num, label_num, min_area)
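A small smoke test for the pure-Python pa above (not in the repo). Shapes and dtypes follow get_results in utils/helper.py: kernels is a (2, H, W) uint8 array with the text region in channel 0 and the shrunk kernel in channel 1, and emb is a (4, H, W) float32 embedding map; run from the repo root.

```python
import numpy as np
from utils.pa.pa import pa

H, W = 64, 64
kernels = np.zeros((2, H, W), dtype=np.uint8)
kernels[0, 10:30, 10:40] = 1   # text region
kernels[1, 15:25, 15:35] = 1   # kernel shrunk inside the text region
emb = np.random.randn(4, H, W).astype(np.float32)

label = pa(kernels, emb)       # int32 map: 0 = background, i > 0 = text instance i
print(label.shape, label.max())  # expected: (64, 64) 1
```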
--------------------------------------------------------------------------------
/utils/pa/pa.pyx:
--------------------------------------------------------------------------------
1 | import setuptools # important
2 | import numpy as np
3 | import cv2
4 | import torch
5 | cimport numpy as np
6 | cimport cython
7 | cimport libcpp
8 | cimport libcpp.pair
9 | cimport libcpp.queue
10 | from libcpp.pair cimport *
11 | from libcpp.queue cimport *
12 | 
13 | @cython.boundscheck(False)
14 | @cython.wraparound(False)
15 | cdef np.ndarray[np.int32_t, ndim=2] _pa(np.ndarray[np.uint8_t, ndim=3] kernels,
16 |                                         np.ndarray[np.float32_t, ndim=3] emb,
17 |                                         np.ndarray[np.int32_t, ndim=2] label,
18 |                                         np.ndarray[np.int32_t, ndim=2] cc,
19 |                                         int kernel_num,
20 |                                         int label_num,
21 |                                         float min_area=0):
22 |     cdef np.ndarray[np.int32_t, ndim=2] pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
23 |     cdef np.ndarray[np.float32_t, ndim=2] mean_emb = np.zeros((label_num, 4), dtype=np.float32)
24 |     cdef np.ndarray[np.float32_t, ndim=1] area = np.full((label_num,), -1, dtype=np.float32)
25 |     cdef np.ndarray[np.int32_t, ndim=1] flag = np.zeros((label_num,), dtype=np.int32)
26 |     cdef np.ndarray[np.uint8_t, ndim=3] inds = np.zeros((label_num, label.shape[0], label.shape[1]), dtype=np.uint8)
27 |     cdef np.ndarray[np.int32_t, ndim=2] p = np.zeros((label_num, 2), dtype=np.int32)
28 | 
29 |     cdef np.float32_t max_rate = 1024
30 |     for i in range(1, label_num):
31 |         ind = label == i
32 |         inds[i] = ind
33 | 
34 |         area[i] = np.sum(ind)
35 | 
36 |         if area[i] < min_area:
37 |             label[ind] = 0
38 |             continue
39 | 
40 |         px, py = np.where(ind)
41 |         p[i] = (px[0], py[0])
42 | 
43 |         for j in range(1, i):
44 |             if area[j] < min_area:
45 |                 continue
46 |             if cc[p[i, 0], p[i, 1]] != cc[p[j, 0], p[j, 1]]:
47 |                 continue
48 |             rate = area[i] / area[j]
49 |             if rate < 1 / max_rate or rate > max_rate:
50 |                 flag[i] = 1
51 |                 mean_emb[i] = np.mean(emb[:, ind], axis=1)
52 | 
53 |                 if flag[j] == 0:
54 |                     flag[j] = 1
55 |                     mean_emb[j] = np.mean(emb[:, inds[j].astype(bool)], axis=1) # np.bool was removed in NumPy >= 1.24; use the builtin bool
56 | 
57 |     cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t, np.int16_t]] que = \
58 |         queue[libcpp.pair.pair[np.int16_t, np.int16_t]]()
59 |     cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t, np.int16_t]] nxt_que = \
60 |         queue[libcpp.pair.pair[np.int16_t, np.int16_t]]()
61 |     cdef np.int16_t*dx = [-1, 1, 0, 0]
62 |     cdef np.int16_t*dy = [0, 0, -1, 1]
63 |     cdef np.int16_t tmpx, tmpy
64 | 
65 |     points = np.array(np.where(label > 0)).transpose((1, 0))
66 |     for point_idx in range(points.shape[0]):
67 |         tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
68 |         que.push(pair[np.int16_t, np.int16_t](tmpx, tmpy))
69 |         pred[tmpx, tmpy] = label[tmpx, tmpy]
70 | 
71 |     cdef libcpp.pair.pair[np.int16_t, np.int16_t] cur
72 |     cdef int cur_label
73 |     for kernel_idx in range(kernel_num - 2, -1, -1):
74 |         while not que.empty():
75 |             cur = que.front()
76 |             que.pop()
77 |             cur_label = pred[cur.first, cur.second]
78 | 
79 |             is_edge = True
80 |             for j in range(4):
81 |                 tmpx = cur.first + dx[j]
82 |                 tmpy = cur.second + dy[j]
83 |                 if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
84 |                     continue
85 |                 if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
86 |                     continue
87 |                 if flag[cur_label] == 1 and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3:
88 |                     continue
89 | 
90 |                 que.push(pair[np.int16_t, np.int16_t](tmpx, tmpy))
91 |                 pred[tmpx, tmpy] = cur_label
92 |                 is_edge = False
93 |             if is_edge:
94 |                 nxt_que.push(cur)
95 | 
96 |         que, nxt_que = nxt_que, que
97 | 
98 |     return pred
99 | 
100 | def pa(kernels, emb, min_area=0):
101 |     kernel_num = kernels.shape[0]
102 |     _, cc = cv2.connectedComponents(kernels[0], connectivity=4)
103 |     label_num, label = cv2.connectedComponents(kernels[1], connectivity=4)
104 | 
105 |     return _pa(kernels[:-1], emb, label, cc, kernel_num, label_num, min_area)
106 | 
--------------------------------------------------------------------------------
/utils/pa/readme.txt:
--------------------------------------------------------------------------------
1 | 1) To use the Cython version of pa, build the shared object first: python setup.py build_ext --inplace
2 | 2) Otherwise, use the pure Python version of pa by importing pa.py.
--------------------------------------------------------------------------------
/utils/pa/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup, Extension
2 | from Cython.Build import cythonize
3 | import numpy
4 | setup(ext_modules = cythonize(Extension(
5 |     'pa',
6 |     sources=['pa.pyx'],
7 |     language='c++',
8 |     include_dirs=[numpy.get_include()],
9 |     library_dirs=[],
10 |     libraries=[],
11 |     extra_compile_args=['-O3'],
12 |     extra_link_args=[]
13 | )))
--------------------------------------------------------------------------------
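Closing note (an addition, not a repo file): once the extension is built as readme.txt describes, importing utils.pa.pa should pick up the compiled module rather than pa.py, since Python's import machinery tries extension modules before source files of the same name. A quick way to confirm which implementation is active; the README's FPS columns suggest a several-fold postprocessing speedup from the Cython build.

```python
# If __file__ ends in .so (or .pyd on Windows), the Cython build is in use;
# if it ends in pa.py, the pure Python fallback is being imported.
import utils.pa.pa as pa_module
print(pa_module.__file__)
```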