├── LICENSE ├── README.md ├── dla ├── config │ ├── cascadeRCNN.py │ ├── cascadeRCNN_full_table_only.py │ ├── cascadeRCNN_ignore_all_but_cells.py │ └── cascadeRCNN_ignore_cells.py └── src │ ├── analyze_logs.py │ ├── construct_VOC.py │ ├── convert_ICDAR_to_VOC.py │ ├── convert_VOC_to_ICDAR.py │ ├── eval_ICDAR.py │ ├── eval_ICDAR_wF1.py │ ├── eval_VOC.py │ ├── eval_VOC_wF1.py │ ├── eval_rows_n_cols_only.py │ ├── get_dataset_summary.py │ ├── image_utils.py │ ├── inference.py │ ├── inference_original.py │ ├── inference_regiongiven.py │ ├── install.py │ ├── installation_files │ ├── cascade_rcnn_frozen.py │ ├── cascade_rcnn_frozenrpn.py │ ├── dataset__init__.py │ ├── detectors__init__.py │ ├── ignoringvoc.py │ ├── inference.py │ └── mean_ap.py │ ├── projection_histogram.py │ ├── table_structure_analysis.py │ ├── visualise_annotations.py │ └── xml_utils.py └── examples ├── table_detection_105.jpg ├── table_detection_3.jpg ├── table_detection_300.jpg ├── table_detection_enhanced_105.jpg ├── table_detection_enhanced_3.jpg ├── table_detection_enhanced_300.jpg ├── table_struct_recog_coarse_105.jpg ├── table_struct_recog_coarse_3.jpg ├── table_struct_recog_coarse_300.jpg ├── table_struct_recog_fine_105.jpg ├── table_struct_recog_fine_3.jpg └── table_struct_recog_fine_300.jpg /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, University of Southampton 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code or data must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | 3. All advertising and publication materials containing results from the use 15 | of this software or data must acknowledge the University of Southampton 16 | and cite the following paper which describes the dataset: 17 | 18 | Ziomek. J. Middleton, S.E. 19 | GloSAT Historical Measurement Table Dataset: Enhanced Table Structure Recognition Annotation for Downstream Historical Data Rescue, 20 | 6th International Workshop on Historical Document Imaging and Processing (HIP-2021), 21 | Sept 5-6, 2021, Lausanne, Switzerland 22 | 23 | 4. Neither the name of the University of Southampton nor the 24 | names of its contributors may be used to endorse or promote products 25 | derived from this software without specific prior written permission. 26 | 27 | THIS SOFTWARE AND DATA IS PROVIDED BY University of Southampton ''AS IS'' AND ANY 28 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 29 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 30 | DISCLAIMED. IN NO EVENT SHALL University of Southampton BE LIABLE FOR ANY 31 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 32 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 33 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 34 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 35 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 36 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
37 | -------------------------------------------------------------------------------- /dla/config/cascadeRCNN.py: -------------------------------------------------------------------------------- 1 | #Adapted from https://github.com/DevashishPrasad/CascadeTabNet 2 | 3 | import os 4 | 5 | # basic settings 6 | model_dir='/scratch/sem03/glosat/dla_models/model_29_09_2020_train' 7 | resume_from = None 8 | #total_epochs = 601 9 | total_epochs = 100 10 | 11 | # model settings 12 | model = dict( 13 | type='CascadeRCNN', 14 | num_stages=3, 15 | pretrained='open-mmlab://msra/hrnetv2_w32', 16 | backbone=dict( 17 | type='HRNet', 18 | extra=dict( 19 | stage1=dict( 20 | num_modules=1, 21 | num_branches=1, 22 | block='BOTTLENECK', 23 | num_blocks=(4, ), 24 | num_channels=(64, )), 25 | stage2=dict( 26 | num_modules=1, 27 | num_branches=2, 28 | block='BASIC', 29 | num_blocks=(4, 4), 30 | num_channels=(32, 64)), 31 | stage3=dict( 32 | num_modules=4, 33 | num_branches=3, 34 | block='BASIC', 35 | num_blocks=(4, 4, 4), 36 | num_channels=(32, 64, 128)), 37 | stage4=dict( 38 | num_modules=3, 39 | num_branches=4, 40 | block='BASIC', 41 | num_blocks=(4, 4, 4, 4), 42 | num_channels=(32, 64, 128, 256)))), 43 | neck=dict(type='HRFPN', in_channels=[32, 64, 128, 256], out_channels=256), 44 | rpn_head=dict( 45 | type='RPNHead', 46 | in_channels=256, 47 | feat_channels=256, 48 | anchor_scales=[8], 49 | anchor_ratios=[0.5, 1.0, 2.0], 50 | anchor_strides=[4, 8, 16, 32, 64], 51 | target_means=[.0, .0, .0, .0], 52 | target_stds=[1.0, 1.0, 1.0, 1.0], 53 | loss_cls=dict( 54 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 55 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), 56 | bbox_roi_extractor=dict( 57 | type='SingleRoIExtractor', 58 | roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), 59 | out_channels=256, 60 | featmap_strides=[4, 8, 16, 32]), 61 | bbox_head=[ 62 | dict( 63 | type='SharedFCBBoxHead', 64 | num_fcs=2, 65 | in_channels=256, 66 | fc_out_channels=1024, 67 | roi_feat_size=7, 68 | num_classes=81, 69 | target_means=[0., 0., 0., 0.], 70 | target_stds=[0.1, 0.1, 0.2, 0.2], 71 | reg_class_agnostic=True, 72 | loss_cls=dict( 73 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 74 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 75 | dict( 76 | type='SharedFCBBoxHead', 77 | num_fcs=2, 78 | in_channels=256, 79 | fc_out_channels=1024, 80 | roi_feat_size=7, 81 | num_classes=81, 82 | target_means=[0., 0., 0., 0.], 83 | target_stds=[0.05, 0.05, 0.1, 0.1], 84 | reg_class_agnostic=True, 85 | loss_cls=dict( 86 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 87 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 88 | dict( 89 | type='SharedFCBBoxHead', 90 | num_fcs=2, 91 | in_channels=256, 92 | fc_out_channels=1024, 93 | roi_feat_size=7, 94 | num_classes=81, 95 | target_means=[0., 0., 0., 0.], 96 | target_stds=[0.033, 0.033, 0.067, 0.067], 97 | reg_class_agnostic=True, 98 | loss_cls=dict( 99 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 100 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) 101 | ]) 102 | # model training and testing settings 103 | train_cfg = dict( 104 | rpn=dict( 105 | assigner=dict( 106 | type='MaxIoUAssigner', 107 | pos_iou_thr=0.7, 108 | neg_iou_thr=0.3, 109 | min_pos_iou=0.3, 110 | ignore_iof_thr=-1), 111 | sampler=dict( 112 | type='RandomSampler', 113 | num=256, 114 | pos_fraction=0.5, 115 | neg_pos_ub=-1, 116 | add_gt_as_proposals=False), 117 | allowed_border=0, 
118 | pos_weight=-1, 119 | debug=False), 120 | rpn_proposal=dict( 121 | nms_across_levels=False, 122 | nms_pre=2000, 123 | nms_post=2000, 124 | max_num=2000, 125 | nms_thr=0.7, 126 | min_bbox_size=0), 127 | rcnn=[ 128 | dict( 129 | assigner=dict( 130 | type='MaxIoUAssigner', 131 | pos_iou_thr=0.5, 132 | neg_iou_thr=0.5, 133 | min_pos_iou=0.5, 134 | ignore_iof_thr=-1), 135 | sampler=dict( 136 | type='RandomSampler', 137 | num=512, 138 | pos_fraction=0.25, 139 | neg_pos_ub=-1, 140 | add_gt_as_proposals=True), 141 | mask_size=28, 142 | pos_weight=-1, 143 | debug=False), 144 | dict( 145 | assigner=dict( 146 | type='MaxIoUAssigner', 147 | pos_iou_thr=0.6, 148 | neg_iou_thr=0.6, 149 | min_pos_iou=0.6, 150 | ignore_iof_thr=-1), 151 | sampler=dict( 152 | type='RandomSampler', 153 | num=512, 154 | pos_fraction=0.25, 155 | neg_pos_ub=-1, 156 | add_gt_as_proposals=True), 157 | mask_size=28, 158 | pos_weight=-1, 159 | debug=False), 160 | dict( 161 | assigner=dict( 162 | type='MaxIoUAssigner', 163 | pos_iou_thr=0.7, 164 | neg_iou_thr=0.7, 165 | min_pos_iou=0.7, 166 | ignore_iof_thr=-1), 167 | sampler=dict( 168 | type='RandomSampler', 169 | num=512, 170 | pos_fraction=0.25, 171 | neg_pos_ub=-1, 172 | add_gt_as_proposals=True), 173 | mask_size=28, 174 | pos_weight=-1, 175 | debug=False) 176 | ], 177 | stage_loss_weights=[1, 0.5, 0.25]) 178 | test_cfg = dict( 179 | rpn=dict( 180 | nms_across_levels=False, 181 | nms_pre=1000, 182 | nms_post=1000, 183 | max_num=1000, 184 | nms_thr=0.7, 185 | min_bbox_size=0), 186 | rcnn=dict( 187 | score_thr=0.05, 188 | nms=dict(type='nms', iou_thr=0.5), 189 | max_per_img=100, 190 | mask_thr_binary=0.5)) 191 | # dataset settings 192 | dataset_type = 'VOCDataset' 193 | data_root = model_dir + os.sep + 'VOC2007' 194 | img_norm_cfg = dict( 195 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 196 | train_pipeline = [ 197 | dict(type='LoadImageFromFile'), 198 | dict(type='LoadAnnotations', with_bbox=True), 199 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 200 | dict(type='RandomFlip', flip_ratio=0.5), 201 | dict(type='Normalize', **img_norm_cfg), 202 | dict(type='Pad', size_divisor=32), 203 | dict(type='DefaultFormatBundle'), 204 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 205 | ] 206 | test_pipeline = [ 207 | dict(type='LoadImageFromFile'), 208 | dict( 209 | type='MultiScaleFlipAug', 210 | img_scale=(1333, 800), 211 | flip=False, 212 | transforms=[ 213 | dict(type='Resize', keep_ratio=True), 214 | dict(type='RandomFlip'), 215 | dict(type='Normalize', **img_norm_cfg), 216 | dict(type='Pad', size_divisor=32), 217 | dict(type='ImageToTensor', keys=['img']), 218 | dict(type='Collect', keys=['img']), 219 | ]) 220 | ] 221 | data = dict( 222 | imgs_per_gpu=1, 223 | workers_per_gpu=1, 224 | train=dict( 225 | type=dataset_type, 226 | ann_file= model_dir + os.sep + 'VOC2007/ImageSets/main.txt', 227 | img_prefix= model_dir + os.sep + 'VOC2007/', 228 | pipeline=train_pipeline), 229 | test=dict( 230 | type=dataset_type, 231 | ann_file=data_root + 'VOC2007/test.json', 232 | img_prefix=data_root + 'VOC2007/Test/', 233 | pipeline=test_pipeline)) 234 | # evaluation = dict(interval=1, metric=['bbox']) 235 | # optimizer 236 | optimizer = dict(type='SGD', lr=0.0012, momentum=0.9, weight_decay=0.0001) 237 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 238 | # learning policy 239 | lr_config = dict( 240 | policy='step', 241 | warmup='linear', 242 | warmup_iters=500, 243 | warmup_ratio=1.0 / 3, 244 | step=[16, 19]) 
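# A quick reading of the schedule above, assuming mmdetection's default step
# policy (gamma = 0.1): the learning rate warms up linearly from
# warmup_ratio * lr = 0.0004 to lr = 0.0012 over the first 500 iterations, is
# multiplied by 0.1 at epoch 16 (0.00012) and again at epoch 19 (0.000012),
# and then stays there for the remaining epochs out of total_epochs = 100.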
245 | checkpoint_config = dict(interval=1,create_symlink=False) 246 | # yapf:disable 247 | log_config = dict( 248 | interval=50, 249 | hooks=[ 250 | dict(type='TextLoggerHook'), 251 | # dict(type='TensorboardLoggerHook') 252 | ]) 253 | # yapf:enable 254 | # runtime settings 255 | dist_params = dict(backend='nccl') 256 | log_level = 'DEBUG' 257 | load_from = None 258 | workflow = [('train', 1)] 259 | -------------------------------------------------------------------------------- /dla/config/cascadeRCNN_full_table_only.py: -------------------------------------------------------------------------------- 1 | #Adapted from https://github.com/DevashishPrasad/CascadeTabNet 2 | 3 | import os 4 | 5 | # basic settings 6 | model_dir='/scratch/sem03/glosat/dla_models/model_29_09_2020_train' 7 | resume_from = None 8 | #total_epochs = 601 9 | total_epochs = 100 10 | 11 | # model settings 12 | model = dict( 13 | type='CascadeRCNN', 14 | num_stages=3, 15 | pretrained='open-mmlab://msra/hrnetv2_w32', 16 | backbone=dict( 17 | type='HRNet', 18 | extra=dict( 19 | stage1=dict( 20 | num_modules=1, 21 | num_branches=1, 22 | block='BOTTLENECK', 23 | num_blocks=(4, ), 24 | num_channels=(64, )), 25 | stage2=dict( 26 | num_modules=1, 27 | num_branches=2, 28 | block='BASIC', 29 | num_blocks=(4, 4), 30 | num_channels=(32, 64)), 31 | stage3=dict( 32 | num_modules=4, 33 | num_branches=3, 34 | block='BASIC', 35 | num_blocks=(4, 4, 4), 36 | num_channels=(32, 64, 128)), 37 | stage4=dict( 38 | num_modules=3, 39 | num_branches=4, 40 | block='BASIC', 41 | num_blocks=(4, 4, 4, 4), 42 | num_channels=(32, 64, 128, 256)))), 43 | neck=dict(type='HRFPN', in_channels=[32, 64, 128, 256], out_channels=256), 44 | rpn_head=dict( 45 | type='RPNHead', 46 | in_channels=256, 47 | feat_channels=256, 48 | anchor_scales=[8], 49 | anchor_ratios=[0.5, 1.0, 2.0], 50 | anchor_strides=[4, 8, 16, 32, 64], 51 | target_means=[.0, .0, .0, .0], 52 | target_stds=[1.0, 1.0, 1.0, 1.0], 53 | loss_cls=dict( 54 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 55 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), 56 | bbox_roi_extractor=dict( 57 | type='SingleRoIExtractor', 58 | roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), 59 | out_channels=256, 60 | featmap_strides=[4, 8, 16, 32]), 61 | bbox_head=[ 62 | dict( 63 | type='SharedFCBBoxHead', 64 | num_fcs=2, 65 | in_channels=256, 66 | fc_out_channels=1024, 67 | roi_feat_size=7, 68 | num_classes=81, 69 | target_means=[0., 0., 0., 0.], 70 | target_stds=[0.1, 0.1, 0.2, 0.2], 71 | reg_class_agnostic=True, 72 | loss_cls=dict( 73 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 74 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 75 | dict( 76 | type='SharedFCBBoxHead', 77 | num_fcs=2, 78 | in_channels=256, 79 | fc_out_channels=1024, 80 | roi_feat_size=7, 81 | num_classes=81, 82 | target_means=[0., 0., 0., 0.], 83 | target_stds=[0.05, 0.05, 0.1, 0.1], 84 | reg_class_agnostic=True, 85 | loss_cls=dict( 86 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 87 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 88 | dict( 89 | type='SharedFCBBoxHead', 90 | num_fcs=2, 91 | in_channels=256, 92 | fc_out_channels=1024, 93 | roi_feat_size=7, 94 | num_classes=81, 95 | target_means=[0., 0., 0., 0.], 96 | target_stds=[0.033, 0.033, 0.067, 0.067], 97 | reg_class_agnostic=True, 98 | loss_cls=dict( 99 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 100 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, 
loss_weight=1.0)) 101 | ]) 102 | # model training and testing settings 103 | train_cfg = dict( 104 | rpn=dict( 105 | assigner=dict( 106 | type='MaxIoUAssigner', 107 | pos_iou_thr=0.7, 108 | neg_iou_thr=0.3, 109 | min_pos_iou=0.3, 110 | ignore_iof_thr=-1), 111 | sampler=dict( 112 | type='RandomSampler', 113 | num=256, 114 | pos_fraction=0.5, 115 | neg_pos_ub=-1, 116 | add_gt_as_proposals=False), 117 | allowed_border=0, 118 | pos_weight=-1, 119 | debug=False), 120 | rpn_proposal=dict( 121 | nms_across_levels=False, 122 | nms_pre=2000, 123 | nms_post=2000, 124 | max_num=2000, 125 | nms_thr=0.7, 126 | min_bbox_size=0), 127 | rcnn=[ 128 | dict( 129 | assigner=dict( 130 | type='MaxIoUAssigner', 131 | pos_iou_thr=0.5, 132 | neg_iou_thr=0.5, 133 | min_pos_iou=0.5, 134 | ignore_iof_thr=-1), 135 | sampler=dict( 136 | type='RandomSampler', 137 | num=512, 138 | pos_fraction=0.25, 139 | neg_pos_ub=-1, 140 | add_gt_as_proposals=True), 141 | mask_size=28, 142 | pos_weight=-1, 143 | debug=False), 144 | dict( 145 | assigner=dict( 146 | type='MaxIoUAssigner', 147 | pos_iou_thr=0.6, 148 | neg_iou_thr=0.6, 149 | min_pos_iou=0.6, 150 | ignore_iof_thr=-1), 151 | sampler=dict( 152 | type='RandomSampler', 153 | num=512, 154 | pos_fraction=0.25, 155 | neg_pos_ub=-1, 156 | add_gt_as_proposals=True), 157 | mask_size=28, 158 | pos_weight=-1, 159 | debug=False), 160 | dict( 161 | assigner=dict( 162 | type='MaxIoUAssigner', 163 | pos_iou_thr=0.7, 164 | neg_iou_thr=0.7, 165 | min_pos_iou=0.7, 166 | ignore_iof_thr=-1), 167 | sampler=dict( 168 | type='RandomSampler', 169 | num=512, 170 | pos_fraction=0.25, 171 | neg_pos_ub=-1, 172 | add_gt_as_proposals=True), 173 | mask_size=28, 174 | pos_weight=-1, 175 | debug=False) 176 | ], 177 | stage_loss_weights=[1, 0.5, 0.25]) 178 | test_cfg = dict( 179 | rpn=dict( 180 | nms_across_levels=False, 181 | nms_pre=1000, 182 | nms_post=1000, 183 | max_num=1000, 184 | nms_thr=0.7, 185 | min_bbox_size=0), 186 | rcnn=dict( 187 | score_thr=0.05, 188 | nms=dict(type='nms', iou_thr=0.5), 189 | max_per_img=100, 190 | mask_thr_binary=0.5)) 191 | # dataset settings 192 | dataset_type = 'IgnoringVOCDataset' 193 | data_root = model_dir + os.sep + 'VOC2007' 194 | img_norm_cfg = dict( 195 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 196 | train_pipeline = [ 197 | dict(type='LoadImageFromFile'), 198 | dict(type='LoadAnnotations', with_bbox=True), 199 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 200 | dict(type='RandomFlip', flip_ratio=0.5), 201 | dict(type='Normalize', **img_norm_cfg), 202 | dict(type='Pad', size_divisor=32), 203 | dict(type='DefaultFormatBundle'), 204 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 205 | ] 206 | test_pipeline = [ 207 | dict(type='LoadImageFromFile'), 208 | dict( 209 | type='MultiScaleFlipAug', 210 | img_scale=(1333, 800), 211 | flip=False, 212 | transforms=[ 213 | dict(type='Resize', keep_ratio=True), 214 | dict(type='RandomFlip'), 215 | dict(type='Normalize', **img_norm_cfg), 216 | dict(type='Pad', size_divisor=32), 217 | dict(type='ImageToTensor', keys=['img']), 218 | dict(type='Collect', keys=['img']), 219 | ]) 220 | ] 221 | data = dict( 222 | imgs_per_gpu=1, 223 | workers_per_gpu=1, 224 | train=dict( 225 | type=dataset_type, 226 | ann_file= model_dir + os.sep + 'VOC2007/ImageSets/main.txt', 227 | img_prefix= model_dir + os.sep + 'VOC2007/', 228 | pipeline=train_pipeline, 229 | ignore = ('table_body','cell','header','heading')), 230 | test=dict( 231 | type=dataset_type, 232 | ann_file=data_root + 
'VOC2007/test.json', 233 | img_prefix=data_root + 'VOC2007/Test/', 234 | pipeline=test_pipeline, 235 | ignore = ('table_body','cell','header','heading'))) 236 | # evaluation = dict(interval=1, metric=['bbox']) 237 | # optimizer 238 | optimizer = dict(type='SGD', lr=0.0012, momentum=0.9, weight_decay=0.0001) 239 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 240 | # learning policy 241 | lr_config = dict( 242 | policy='step', 243 | warmup='linear', 244 | warmup_iters=500, 245 | warmup_ratio=1.0 / 3, 246 | step=[16, 19]) 247 | checkpoint_config = dict(interval=1,create_symlink=False) 248 | # yapf:disable 249 | log_config = dict( 250 | interval=50, 251 | hooks=[ 252 | dict(type='TextLoggerHook'), 253 | # dict(type='TensorboardLoggerHook') 254 | ]) 255 | # yapf:enable 256 | # runtime settings 257 | dist_params = dict(backend='nccl') 258 | log_level = 'DEBUG' 259 | load_from = None 260 | workflow = [('train', 1)] 261 | -------------------------------------------------------------------------------- /dla/config/cascadeRCNN_ignore_all_but_cells.py: -------------------------------------------------------------------------------- 1 | #Adapted from https://github.com/DevashishPrasad/CascadeTabNet 2 | 3 | import os 4 | 5 | # basic settings 6 | model_dir='/scratch/sem03/glosat/dla_models/model_29_09_2020_train' 7 | resume_from = None 8 | #total_epochs = 601 9 | total_epochs = 100 10 | 11 | # model settings 12 | model = dict( 13 | type='CascadeRCNN', 14 | num_stages=3, 15 | pretrained='open-mmlab://msra/hrnetv2_w32', 16 | backbone=dict( 17 | type='HRNet', 18 | extra=dict( 19 | stage1=dict( 20 | num_modules=1, 21 | num_branches=1, 22 | block='BOTTLENECK', 23 | num_blocks=(4, ), 24 | num_channels=(64, )), 25 | stage2=dict( 26 | num_modules=1, 27 | num_branches=2, 28 | block='BASIC', 29 | num_blocks=(4, 4), 30 | num_channels=(32, 64)), 31 | stage3=dict( 32 | num_modules=4, 33 | num_branches=3, 34 | block='BASIC', 35 | num_blocks=(4, 4, 4), 36 | num_channels=(32, 64, 128)), 37 | stage4=dict( 38 | num_modules=3, 39 | num_branches=4, 40 | block='BASIC', 41 | num_blocks=(4, 4, 4, 4), 42 | num_channels=(32, 64, 128, 256)))), 43 | neck=dict(type='HRFPN', in_channels=[32, 64, 128, 256], out_channels=256), 44 | rpn_head=dict( 45 | type='RPNHead', 46 | in_channels=256, 47 | feat_channels=256, 48 | anchor_scales=[8], 49 | anchor_ratios=[0.5, 1.0, 2.0], 50 | anchor_strides=[4, 8, 16, 32, 64], 51 | target_means=[.0, .0, .0, .0], 52 | target_stds=[1.0, 1.0, 1.0, 1.0], 53 | loss_cls=dict( 54 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 55 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), 56 | bbox_roi_extractor=dict( 57 | type='SingleRoIExtractor', 58 | roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), 59 | out_channels=256, 60 | featmap_strides=[4, 8, 16, 32]), 61 | bbox_head=[ 62 | dict( 63 | type='SharedFCBBoxHead', 64 | num_fcs=2, 65 | in_channels=256, 66 | fc_out_channels=1024, 67 | roi_feat_size=7, 68 | num_classes=81, 69 | target_means=[0., 0., 0., 0.], 70 | target_stds=[0.1, 0.1, 0.2, 0.2], 71 | reg_class_agnostic=True, 72 | loss_cls=dict( 73 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 74 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 75 | dict( 76 | type='SharedFCBBoxHead', 77 | num_fcs=2, 78 | in_channels=256, 79 | fc_out_channels=1024, 80 | roi_feat_size=7, 81 | num_classes=81, 82 | target_means=[0., 0., 0., 0.], 83 | target_stds=[0.05, 0.05, 0.1, 0.1], 84 | reg_class_agnostic=True, 85 | 
loss_cls=dict( 86 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 87 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 88 | dict( 89 | type='SharedFCBBoxHead', 90 | num_fcs=2, 91 | in_channels=256, 92 | fc_out_channels=1024, 93 | roi_feat_size=7, 94 | num_classes=81, 95 | target_means=[0., 0., 0., 0.], 96 | target_stds=[0.033, 0.033, 0.067, 0.067], 97 | reg_class_agnostic=True, 98 | loss_cls=dict( 99 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 100 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) 101 | ]) 102 | # model training and testing settings 103 | train_cfg = dict( 104 | rpn=dict( 105 | assigner=dict( 106 | type='MaxIoUAssigner', 107 | pos_iou_thr=0.7, 108 | neg_iou_thr=0.3, 109 | min_pos_iou=0.3, 110 | ignore_iof_thr=-1), 111 | sampler=dict( 112 | type='RandomSampler', 113 | num=256, 114 | pos_fraction=0.5, 115 | neg_pos_ub=-1, 116 | add_gt_as_proposals=False), 117 | allowed_border=0, 118 | pos_weight=-1, 119 | debug=False), 120 | rpn_proposal=dict( 121 | nms_across_levels=False, 122 | nms_pre=2000, 123 | nms_post=2000, 124 | max_num=2000, 125 | nms_thr=0.7, 126 | min_bbox_size=0), 127 | rcnn=[ 128 | dict( 129 | assigner=dict( 130 | type='MaxIoUAssigner', 131 | pos_iou_thr=0.5, 132 | neg_iou_thr=0.5, 133 | min_pos_iou=0.5, 134 | ignore_iof_thr=-1), 135 | sampler=dict( 136 | type='RandomSampler', 137 | num=512, 138 | pos_fraction=0.25, 139 | neg_pos_ub=-1, 140 | add_gt_as_proposals=True), 141 | mask_size=28, 142 | pos_weight=-1, 143 | debug=False), 144 | dict( 145 | assigner=dict( 146 | type='MaxIoUAssigner', 147 | pos_iou_thr=0.6, 148 | neg_iou_thr=0.6, 149 | min_pos_iou=0.6, 150 | ignore_iof_thr=-1), 151 | sampler=dict( 152 | type='RandomSampler', 153 | num=512, 154 | pos_fraction=0.25, 155 | neg_pos_ub=-1, 156 | add_gt_as_proposals=True), 157 | mask_size=28, 158 | pos_weight=-1, 159 | debug=False), 160 | dict( 161 | assigner=dict( 162 | type='MaxIoUAssigner', 163 | pos_iou_thr=0.7, 164 | neg_iou_thr=0.7, 165 | min_pos_iou=0.7, 166 | ignore_iof_thr=-1), 167 | sampler=dict( 168 | type='RandomSampler', 169 | num=512, 170 | pos_fraction=0.25, 171 | neg_pos_ub=-1, 172 | add_gt_as_proposals=True), 173 | mask_size=28, 174 | pos_weight=-1, 175 | debug=False) 176 | ], 177 | stage_loss_weights=[1, 0.5, 0.25]) 178 | test_cfg = dict( 179 | rpn=dict( 180 | nms_across_levels=False, 181 | nms_pre=1000, 182 | nms_post=1000, 183 | max_num=1000, 184 | nms_thr=0.7, 185 | min_bbox_size=0), 186 | rcnn=dict( 187 | score_thr=0.05, 188 | nms=dict(type='nms', iou_thr=0.5), 189 | max_per_img=100, 190 | mask_thr_binary=0.5)) 191 | # dataset settings 192 | dataset_type = 'IgnoringVOCDataset' 193 | data_root = model_dir + os.sep + 'VOC2007' 194 | img_norm_cfg = dict( 195 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 196 | train_pipeline = [ 197 | dict(type='LoadImageFromFile'), 198 | dict(type='LoadAnnotations', with_bbox=True), 199 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 200 | dict(type='RandomFlip', flip_ratio=0.5), 201 | dict(type='Normalize', **img_norm_cfg), 202 | dict(type='Pad', size_divisor=32), 203 | dict(type='DefaultFormatBundle'), 204 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 205 | ] 206 | test_pipeline = [ 207 | dict(type='LoadImageFromFile'), 208 | dict( 209 | type='MultiScaleFlipAug', 210 | img_scale=(1333, 800), 211 | flip=False, 212 | transforms=[ 213 | dict(type='Resize', keep_ratio=True), 214 | dict(type='RandomFlip'), 215 | dict(type='Normalize', 
**img_norm_cfg), 216 | dict(type='Pad', size_divisor=32), 217 | dict(type='ImageToTensor', keys=['img']), 218 | dict(type='Collect', keys=['img']), 219 | ]) 220 | ] 221 | data = dict( 222 | imgs_per_gpu=1, 223 | workers_per_gpu=1, 224 | train=dict( 225 | type=dataset_type, 226 | ann_file= model_dir + os.sep + 'VOC2007/ImageSets/main.txt', 227 | img_prefix= model_dir + os.sep + 'VOC2007/', 228 | pipeline=train_pipeline, 229 | ignore = ('table_body','full_table','header','heading')), 230 | test=dict( 231 | type=dataset_type, 232 | ann_file=data_root + 'VOC2007/test.json', 233 | img_prefix=data_root + 'VOC2007/Test/', 234 | pipeline=test_pipeline, 235 | ignore = ('table_body','full_table','header','heading'))) 236 | # evaluation = dict(interval=1, metric=['bbox']) 237 | # optimizer 238 | optimizer = dict(type='SGD', lr=0.0012, momentum=0.9, weight_decay=0.0001) 239 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 240 | # learning policy 241 | lr_config = dict( 242 | policy='step', 243 | warmup='linear', 244 | warmup_iters=500, 245 | warmup_ratio=1.0 / 3, 246 | step=[16, 19]) 247 | checkpoint_config = dict(interval=1,create_symlink=False) 248 | # yapf:disable 249 | log_config = dict( 250 | interval=50, 251 | hooks=[ 252 | dict(type='TextLoggerHook'), 253 | # dict(type='TensorboardLoggerHook') 254 | ]) 255 | # yapf:enable 256 | # runtime settings 257 | dist_params = dict(backend='nccl') 258 | log_level = 'DEBUG' 259 | load_from = None 260 | workflow = [('train', 1)] 261 | -------------------------------------------------------------------------------- /dla/config/cascadeRCNN_ignore_cells.py: -------------------------------------------------------------------------------- 1 | #Adapted from https://github.com/DevashishPrasad/CascadeTabNet 2 | 3 | import os 4 | 5 | # basic settings 6 | model_dir='/scratch/sem03/glosat/dla_models/model_29_09_2020_train' 7 | resume_from = None 8 | #total_epochs = 601 9 | total_epochs = 100 10 | 11 | # model settings 12 | model = dict( 13 | type='CascadeRCNN', 14 | num_stages=3, 15 | pretrained='open-mmlab://msra/hrnetv2_w32', 16 | backbone=dict( 17 | type='HRNet', 18 | extra=dict( 19 | stage1=dict( 20 | num_modules=1, 21 | num_branches=1, 22 | block='BOTTLENECK', 23 | num_blocks=(4, ), 24 | num_channels=(64, )), 25 | stage2=dict( 26 | num_modules=1, 27 | num_branches=2, 28 | block='BASIC', 29 | num_blocks=(4, 4), 30 | num_channels=(32, 64)), 31 | stage3=dict( 32 | num_modules=4, 33 | num_branches=3, 34 | block='BASIC', 35 | num_blocks=(4, 4, 4), 36 | num_channels=(32, 64, 128)), 37 | stage4=dict( 38 | num_modules=3, 39 | num_branches=4, 40 | block='BASIC', 41 | num_blocks=(4, 4, 4, 4), 42 | num_channels=(32, 64, 128, 256)))), 43 | neck=dict(type='HRFPN', in_channels=[32, 64, 128, 256], out_channels=256), 44 | rpn_head=dict( 45 | type='RPNHead', 46 | in_channels=256, 47 | feat_channels=256, 48 | anchor_scales=[8], 49 | anchor_ratios=[0.5, 1.0, 2.0], 50 | anchor_strides=[4, 8, 16, 32, 64], 51 | target_means=[.0, .0, .0, .0], 52 | target_stds=[1.0, 1.0, 1.0, 1.0], 53 | loss_cls=dict( 54 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 55 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), 56 | bbox_roi_extractor=dict( 57 | type='SingleRoIExtractor', 58 | roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), 59 | out_channels=256, 60 | featmap_strides=[4, 8, 16, 32]), 61 | bbox_head=[ 62 | dict( 63 | type='SharedFCBBoxHead', 64 | num_fcs=2, 65 | in_channels=256, 66 | fc_out_channels=1024, 67 | roi_feat_size=7, 
68 | num_classes=81, 69 | target_means=[0., 0., 0., 0.], 70 | target_stds=[0.1, 0.1, 0.2, 0.2], 71 | reg_class_agnostic=True, 72 | loss_cls=dict( 73 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 74 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 75 | dict( 76 | type='SharedFCBBoxHead', 77 | num_fcs=2, 78 | in_channels=256, 79 | fc_out_channels=1024, 80 | roi_feat_size=7, 81 | num_classes=81, 82 | target_means=[0., 0., 0., 0.], 83 | target_stds=[0.05, 0.05, 0.1, 0.1], 84 | reg_class_agnostic=True, 85 | loss_cls=dict( 86 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 87 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), 88 | dict( 89 | type='SharedFCBBoxHead', 90 | num_fcs=2, 91 | in_channels=256, 92 | fc_out_channels=1024, 93 | roi_feat_size=7, 94 | num_classes=81, 95 | target_means=[0., 0., 0., 0.], 96 | target_stds=[0.033, 0.033, 0.067, 0.067], 97 | reg_class_agnostic=True, 98 | loss_cls=dict( 99 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 100 | loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) 101 | ]) 102 | # model training and testing settings 103 | train_cfg = dict( 104 | rpn=dict( 105 | assigner=dict( 106 | type='MaxIoUAssigner', 107 | pos_iou_thr=0.7, 108 | neg_iou_thr=0.3, 109 | min_pos_iou=0.3, 110 | ignore_iof_thr=-1), 111 | sampler=dict( 112 | type='RandomSampler', 113 | num=256, 114 | pos_fraction=0.5, 115 | neg_pos_ub=-1, 116 | add_gt_as_proposals=False), 117 | allowed_border=0, 118 | pos_weight=-1, 119 | debug=False), 120 | rpn_proposal=dict( 121 | nms_across_levels=False, 122 | nms_pre=2000, 123 | nms_post=2000, 124 | max_num=2000, 125 | nms_thr=0.7, 126 | min_bbox_size=0), 127 | rcnn=[ 128 | dict( 129 | assigner=dict( 130 | type='MaxIoUAssigner', 131 | pos_iou_thr=0.5, 132 | neg_iou_thr=0.5, 133 | min_pos_iou=0.5, 134 | ignore_iof_thr=-1), 135 | sampler=dict( 136 | type='RandomSampler', 137 | num=512, 138 | pos_fraction=0.25, 139 | neg_pos_ub=-1, 140 | add_gt_as_proposals=True), 141 | mask_size=28, 142 | pos_weight=-1, 143 | debug=False), 144 | dict( 145 | assigner=dict( 146 | type='MaxIoUAssigner', 147 | pos_iou_thr=0.6, 148 | neg_iou_thr=0.6, 149 | min_pos_iou=0.6, 150 | ignore_iof_thr=-1), 151 | sampler=dict( 152 | type='RandomSampler', 153 | num=512, 154 | pos_fraction=0.25, 155 | neg_pos_ub=-1, 156 | add_gt_as_proposals=True), 157 | mask_size=28, 158 | pos_weight=-1, 159 | debug=False), 160 | dict( 161 | assigner=dict( 162 | type='MaxIoUAssigner', 163 | pos_iou_thr=0.7, 164 | neg_iou_thr=0.7, 165 | min_pos_iou=0.7, 166 | ignore_iof_thr=-1), 167 | sampler=dict( 168 | type='RandomSampler', 169 | num=512, 170 | pos_fraction=0.25, 171 | neg_pos_ub=-1, 172 | add_gt_as_proposals=True), 173 | mask_size=28, 174 | pos_weight=-1, 175 | debug=False) 176 | ], 177 | stage_loss_weights=[1, 0.5, 0.25]) 178 | test_cfg = dict( 179 | rpn=dict( 180 | nms_across_levels=False, 181 | nms_pre=1000, 182 | nms_post=1000, 183 | max_num=1000, 184 | nms_thr=0.7, 185 | min_bbox_size=0), 186 | rcnn=dict( 187 | score_thr=0.05, 188 | nms=dict(type='nms', iou_thr=0.5), 189 | max_per_img=100, 190 | mask_thr_binary=0.5)) 191 | # dataset settings 192 | dataset_type = 'IgnoringVOCDataset' 193 | data_root = model_dir + os.sep + 'VOC2007' 194 | img_norm_cfg = dict( 195 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 196 | train_pipeline = [ 197 | dict(type='LoadImageFromFile'), 198 | dict(type='LoadAnnotations', with_bbox=True), 199 | dict(type='Resize', img_scale=(1333, 800), 
keep_ratio=True), 200 | dict(type='RandomFlip', flip_ratio=0.5), 201 | dict(type='Normalize', **img_norm_cfg), 202 | dict(type='Pad', size_divisor=32), 203 | dict(type='DefaultFormatBundle'), 204 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 205 | ] 206 | test_pipeline = [ 207 | dict(type='LoadImageFromFile'), 208 | dict( 209 | type='MultiScaleFlipAug', 210 | img_scale=(1333, 800), 211 | flip=False, 212 | transforms=[ 213 | dict(type='Resize', keep_ratio=True), 214 | dict(type='RandomFlip'), 215 | dict(type='Normalize', **img_norm_cfg), 216 | dict(type='Pad', size_divisor=32), 217 | dict(type='ImageToTensor', keys=['img']), 218 | dict(type='Collect', keys=['img']), 219 | ]) 220 | ] 221 | data = dict( 222 | imgs_per_gpu=1, 223 | workers_per_gpu=1, 224 | train=dict( 225 | type=dataset_type, 226 | ann_file= model_dir + os.sep + 'VOC2007/ImageSets/main.txt', 227 | img_prefix= model_dir + os.sep + 'VOC2007/', 228 | pipeline=train_pipeline, 229 | ignore = ("cell",),), 230 | test=dict( 231 | type=dataset_type, 232 | ann_file=data_root + 'VOC2007/test.json', 233 | img_prefix=data_root + 'VOC2007/Test/', 234 | pipeline=test_pipeline, 235 | ignore = ("cell",))) 236 | # evaluation = dict(interval=1, metric=['bbox']) 237 | # optimizer 238 | optimizer = dict(type='SGD', lr=0.0012, momentum=0.9, weight_decay=0.0001) 239 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 240 | # learning policy 241 | lr_config = dict( 242 | policy='step', 243 | warmup='linear', 244 | warmup_iters=500, 245 | warmup_ratio=1.0 / 3, 246 | step=[16, 19]) 247 | checkpoint_config = dict(interval=1,create_symlink=False) 248 | # yapf:disable 249 | log_config = dict( 250 | interval=50, 251 | hooks=[ 252 | dict(type='TextLoggerHook'), 253 | # dict(type='TensorboardLoggerHook') 254 | ]) 255 | # yapf:enable 256 | # runtime settings 257 | dist_params = dict(backend='nccl') 258 | log_level = 'DEBUG' 259 | load_from = None 260 | workflow = [('train', 1)] 261 | -------------------------------------------------------------------------------- /dla/src/analyze_logs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import argparse 3 | import json 4 | import os 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--logs",type=str,default="logs") 8 | parser.add_argument("--outname",type=str,default="plots") 9 | parser.add_argument("-smooth",default=False,action="store_true") 10 | 11 | args = parser.parse_args() 12 | 13 | losses={"s0.loss_cls":dict(), 14 | "s0.acc":dict(), 15 | "s0.loss_bbox":dict(), 16 | "s1.loss_cls":dict(), 17 | "s1.acc":dict(), 18 | "s1.loss_bbox":dict(), 19 | "s2.loss_cls":dict(), 20 | "s2.acc":dict(), 21 | "s2.loss_bbox":dict(), 22 | "loss_rpn_cls":dict(), 23 | "loss_rpn_bbox":dict(), 24 | "loss":dict()} 25 | 26 | def plot(losses,*args,smooth=True,outname=""): 27 | for arg in args: 28 | value_by_epoch = [] 29 | for epoch in sorted(losses[arg].keys()): 30 | if len(losses[arg][epoch])>0: 31 | value_by_epoch.append(sum(losses[arg][epoch])/len(losses[arg][epoch])) 32 | else: 33 | value_by_epoch.append(value_by_epoch[-1]) 34 | 35 | plt.plot(value_by_epoch,label=arg) 36 | 37 | if smooth: 38 | avg_by_epoch = [sum(value_by_epoch[i + j] for j in range(10))/10 for i in range(len(value_by_epoch)-10)] 39 | plt.plot(avg_by_epoch,label=arg + " smooth") 40 | 41 | plt.legend(loc="best") 42 | plt.savefig((outname if outname.endswith("/") else outname+ "/") + "".join(["".join(arg.split(".")) + "_" for arg in
args]) + ".jpg") 43 | plt.clf() 44 | 45 | directory_name = "" 46 | if args.logs.endswith(".log.json" ): 47 | logfiles = [args.logs] 48 | else: 49 | logfiles = [file for file in os.listdir(args.logs) if file.endswith(".log.json")] 50 | directory_name = args.logs if args.logs.endswith("/") else args.logs + "/" 51 | 52 | for logfile in logfiles: 53 | for line in open(directory_name + logfile,"r").readlines(): 54 | logs = json.loads(line) 55 | for log in logs: 56 | if log in losses: 57 | if logs["epoch"] in losses[log]: 58 | losses[log][logs["epoch"]].append(logs[log]) 59 | else: 60 | losses[log][logs["epoch"]] = [logs[log]] 61 | 62 | plot(losses,"loss_rpn_cls","loss_rpn_bbox",smooth=args.smooth,outname=args.outname) 63 | plot(losses,"s0.loss_cls","s0.loss_bbox",smooth=args.smooth,outname=args.outname) 64 | plot(losses,"s1.loss_cls","s1.loss_bbox",smooth=args.smooth,outname=args.outname) 65 | plot(losses,"s2.loss_cls","s2.loss_bbox",smooth=args.smooth,outname=args.outname) 66 | plot(losses,"loss",smooth=args.smooth,outname=args.outname) 67 | plot(losses,"s0.acc","s1.acc","s2.acc",smooth=args.smooth,outname=args.outname) 68 | 69 | 70 | -------------------------------------------------------------------------------- /dla/src/construct_VOC.py: -------------------------------------------------------------------------------- 1 | import os, random, codecs 2 | 3 | # source dataset 4 | # //VOC_annotations ==> VOC xml labels 5 | # //labels ==> ICDAR xml labels 6 | # //page ==> Transkribus xml labels (original source) 7 | dataset_dir = '/data/glosat_table_dataset/datasets/GloSAT_dataset_coarse' 8 | #dataset_dir = '/data/glosat_table_dataset/datasets/GloSAT_dataset_fine' 9 | 10 | 11 | # resulting model dir structure 12 | # _train/VOC2007/JPEGImages 13 | # _train/VOC2007/Annotations 14 | # _train/VOC2007/ICDAR 15 | # _train/VOC2007/Transkribus 16 | # _test/VOC2007/JPEGImages 17 | # _test/VOC2007/Annotations 18 | # _test/VOC2007/ICDAR 19 | # _test/VOC2007/Transkribus 20 | model_dir = '/data/glosat_table_dataset/dla_models/model_table_det_full_table' 21 | #model_dir = '/data/glosat_table_dataset/dla_models/model_table_det_enhanced' 22 | #model_dir = '/data/glosat_table_dataset/dla_models/model_table_struct_coarse' 23 | #model_dir = '/data/glosat_table_dataset/dla_models/model_table_struct_fine' 24 | 25 | # images will be copied as 66% train 33% test 26 | train_dir = model_dir + '_train' 27 | holdout_dir = model_dir + '_test' 28 | 29 | data_sources = { 30 | "20cr_DWR_MO/": dataset_dir + "/20cr_DWR_MO", 31 | "20cr_DWR_NOAA/": dataset_dir +"/20cr_DWR_NOAA", 32 | "20cr_Kubota/": dataset_dir + "/20cr_Kubota", 33 | "20cr_Natal_Witnes/": dataset_dir + "/20cr_Natal_Witnes", 34 | "DWR/": dataset_dir + "/DWR", 35 | "WR_10_years/": dataset_dir + "/WR_10_years", 36 | "WesTech/": dataset_dir + "/WesTech_Rodgers", 37 | "WR_Devon_Extern/": dataset_dir +"/WR_Devon_Extern", 38 | "Ben_Nevis/": dataset_dir +"/Ben_Nevis"} 39 | 40 | if not os.path.exists( train_dir ) : 41 | os.mkdir( train_dir ) 42 | if not os.path.exists( train_dir + "/VOC2007" ) : 43 | os.mkdir( train_dir + "/VOC2007" ) 44 | if not os.path.exists( train_dir + "/VOC2007/Annotations" ) : 45 | os.mkdir( train_dir + "/VOC2007/Annotations" ) 46 | if not os.path.exists( train_dir + "/VOC2007/ICDAR" ) : 47 | os.mkdir( train_dir + "/VOC2007/ICDAR" ) 48 | if not os.path.exists( train_dir + "/VOC2007/Transkribus" ) : 49 | os.mkdir( train_dir + "/VOC2007/Transkribus" ) 50 | if not os.path.exists( train_dir + "/VOC2007/JPEGImages" ) : 51 | os.mkdir( train_dir + 
"/VOC2007/JPEGImages" ) 52 | if not os.path.exists( train_dir + "/VOC2007/ImageSets" ) : 53 | os.mkdir( train_dir + "/VOC2007/ImageSets" ) 54 | 55 | if not os.path.exists( holdout_dir ) : 56 | os.mkdir( holdout_dir ) 57 | if not os.path.exists( holdout_dir + "/VOC2007" ) : 58 | os.mkdir( holdout_dir + "/VOC2007" ) 59 | if not os.path.exists( holdout_dir + "/VOC2007/Annotations" ) : 60 | os.mkdir( holdout_dir + "/VOC2007/Annotations" ) 61 | if not os.path.exists( holdout_dir + "/VOC2007/ICDAR" ) : 62 | os.mkdir( holdout_dir + "/VOC2007/ICDAR" ) 63 | if not os.path.exists( holdout_dir + "/VOC2007/Transkribus" ) : 64 | os.mkdir( holdout_dir + "/VOC2007/Transkribus" ) 65 | if not os.path.exists( holdout_dir + "/VOC2007/JPEGImages" ) : 66 | os.mkdir( holdout_dir + "/VOC2007/JPEGImages" ) 67 | if not os.path.exists( holdout_dir + "/VOC2007/ImageSets" ) : 68 | os.mkdir( holdout_dir + "/VOC2007/ImageSets" ) 69 | 70 | list_test = [] 71 | list_train = [] 72 | 73 | for data_source in data_sources.values(): 74 | available_files = [] 75 | for file in os.listdir(data_source): 76 | if file.endswith(".jpg") and file.strip(".jpg") + ".xml" in os.listdir(data_source + "/VOC_annotations/"): 77 | available_files.append(file) 78 | 79 | for _ in range(int(len(available_files) * 0.75)): 80 | file = random.choice(available_files) 81 | os.system("cp {} {}".format(data_source + '/' + file,train_dir + "/VOC2007/JPEGImages/" + file)) 82 | os.system("cp {} {}".format(data_source + "/VOC_annotations/" + file.strip(".jpg") + ".xml",train_dir + "/VOC2007/Annotations/" + file.strip(".jpg") + ".xml")) 83 | os.system("cp {} {}".format(data_source + "/labels/" + file.strip(".jpg") + ".xml",train_dir + "/VOC2007/ICDAR/" + file.strip(".jpg") + ".xml")) 84 | os.system("cp {} {}".format(data_source + "/page/" + file.strip(".jpg") + ".xml",train_dir + "/VOC2007/Transkribus/" + file.strip(".jpg") + ".xml")) 85 | list_train.append( int( file.strip(".jpg") ) ) 86 | available_files.remove(file) 87 | 88 | for _ in range(len(available_files)): 89 | file = random.choice(available_files) 90 | os.system("cp {} {}".format(data_source + '/' + file,holdout_dir + "/VOC2007/JPEGImages/" + file)) 91 | os.system("cp {} {}".format(data_source + "/VOC_annotations/" + file.strip(".jpg") + ".xml", holdout_dir + "/VOC2007/Annotations/" + file.strip(".jpg") + ".xml")) 92 | os.system("cp {} {}".format(data_source + "/labels/" + file.strip(".jpg") + ".xml", holdout_dir + "/VOC2007/ICDAR/" + file.strip(".jpg") + ".xml")) 93 | os.system("cp {} {}".format(data_source + "/page/" + file.strip(".jpg") + ".xml", holdout_dir + "/VOC2007/Transkribus/" + file.strip(".jpg") + ".xml")) 94 | list_test.append( int( file.strip(".jpg") ) ) 95 | available_files.remove(file) 96 | 97 | list_train = sorted( list_train ) 98 | write_handle = codecs.open( train_dir + "/VOC2007/ImageSets/main.txt", 'w', 'utf-8', errors = 'replace' ) 99 | for image_id in list_train : 100 | write_handle.write( str(image_id) + '\n' ) 101 | write_handle.close() 102 | 103 | list_test = sorted( list_test ) 104 | write_handle = codecs.open( holdout_dir + "/VOC2007/ImageSets/main.txt", 'w', 'utf-8', errors = 'replace' ) 105 | for image_id in list_test : 106 | write_handle.write( str(image_id) + '\n' ) 107 | write_handle.close() 108 | -------------------------------------------------------------------------------- /dla/src/convert_ICDAR_to_VOC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml_utils as xml_utils 3 | 4 | # 5 | # note: this 
is not needed for GloSAT dataset as it is shipped with VOC, ICDAR and Transkribus xml labels 6 | # 7 | 8 | voc_dir = '/home/juliuszziomek/Documents/Python/GloSAT_fine_test/VOC2007_test_noheader/Annotations/' 9 | icdar_dir = '/home/juliuszziomek/Documents/Python/GloSAT_fine_test/VOC2007/ICDAR/' 10 | 11 | if not os.path.exists( icdar_dir ) : 12 | raise Exception( 'model dir not found : ' + repr(icdar_dir) ) 13 | if not os.path.exists( voc_dir ) : 14 | os.mkdir( voc_dir ) 15 | 16 | # see inference.py code for the basis of this code 17 | CLASSES = ("table_body","cell","full_table","header","heading") 18 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 19 | 20 | for file in os.listdir( icdar_dir ) : 21 | if file.endswith(".xml"): 22 | 23 | icdar_parsed = xml_utils.load_ICDAR_xml( icdar_dir + '/' + file ) 24 | 25 | tables = [] 26 | cells = [] 27 | 28 | for entry in icdar_parsed: 29 | tables.append(entry["region"]) 30 | cells += entry["cells"] 31 | 32 | 33 | xml_utils.save_VOC_xml_from_cells(headings=[],headers=[],bodies=tables,full_tables=tables,cells=cells,filename=os.path.join(voc_dir,file),width=1000,height=1000) 34 | -------------------------------------------------------------------------------- /dla/src/convert_VOC_to_ICDAR.py: -------------------------------------------------------------------------------- 1 | import os, random, codecs 2 | import table_structure_analysis as tsa 3 | import xml_utils as xml_utils 4 | from image_utils import put_box, put_line 5 | import argparse 6 | 7 | # 8 | # note: this is not needed for GloSAT dataset as it is shipped with VOC, ICDAR and Transkribus xml labels 9 | # 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('voc',type=str) 14 | parser.add_argument('icdar',type=str) 15 | args = parser.parse_args() 16 | 17 | voc_dir = args.voc 18 | icdar_dir = args.icdar 19 | 20 | if not os.path.exists( voc_dir ) : 21 | raise Exception( 'model dir not found : ' + repr(voc_dir) ) 22 | if not os.path.exists( icdar_dir ) : 23 | os.mkdir( icdar_dir ) 24 | 25 | # see inference.py code for the basis of this code 26 | CLASSES = ("table_body","cell","full_table","header","heading") 27 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 28 | 29 | for file in os.listdir( voc_dir ) : 30 | if not(file.endswith(".xml")): 31 | continue 32 | 33 | # load VOC xml 34 | # voc_parsed = [ {"name":name,"bbox":[xmin,ymin,xmax,ymax]}, ... 
] 35 | voc_parsed = xml_utils.load_VOC_xml( voc_dir + '/' + file ) 36 | 37 | headings = [] 38 | headers = [] 39 | tables = [] 40 | full_tables = [] 41 | predicted_cells = [] 42 | 43 | for entry in voc_parsed : 44 | if entry['name'] == 'header' : 45 | headers.append( entry['bbox'] ) 46 | elif entry['name'] == 'table_body' : 47 | tables.append( entry['bbox'] ) 48 | elif entry['name'] == 'heading' : 49 | headings.append( entry['bbox'] ) 50 | elif entry['name'] == 'full_table' : 51 | full_tables.append( entry['bbox'] ) 52 | elif entry['name'] == 'cell' : 53 | predicted_cells.append( entry['bbox'] ) 54 | 55 | for table in tables: 56 | if all(tsa.how_much_contained(table,full_table)<0.5 for full_table in full_tables): 57 | full_tables.append(table) 58 | 59 | rows_by_table = [] 60 | cols_by_table = [] 61 | 62 | for table in full_tables: 63 | cells = [] 64 | for cell in predicted_cells : 65 | 66 | # assign a cell to a table if cell area > 0 AND cell overlap with table is > 50% 67 | if area(cell) > 0 : 68 | if tsa.how_much_contained(cell,table)>0.5: 69 | cells.append(cell) 70 | 71 | if cells != []: 72 | rows, cols = tsa.reconstruct_table(cells,table,eps=0.02) 73 | else: 74 | rows,cols = [],[] 75 | 76 | rows_by_table.append(rows) 77 | cols_by_table.append(cols) 78 | 79 | xml_utils.save_ICDAR_xml( full_tables, cols_by_table, rows_by_table, icdar_dir +'/' + file ) 80 | 81 | -------------------------------------------------------------------------------- /dla/src/eval_ICDAR.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | try: 18 | import xml_utils as xml_utils 19 | except: 20 | import dla.src.xml_utils as xml_utils 21 | 22 | import argparse 23 | import os 24 | 25 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 26 | 27 | def IoU(box1,box2): 28 | 29 | area1 = area(box1) 30 | area2 = area(box2) 31 | 32 | intersection_box = [max(box1[0],box2[0]), 33 | max(box1[1],box2[1]), 34 | min(box1[2],box2[2]), 35 | min(box1[3],box2[3])] 36 | 37 | intersection_area = area(intersection_box) 38 | 39 | return intersection_area/(area1 + area2 - intersection_area) 40 | 41 | def calculate_scores(output,gt,IoU_threshold): 42 | scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 43 | duplicated_positives = 0 44 | total_IoU = 0 45 | 46 | for gt_object in gt: 47 | object_detected = False 48 | for output_object in output: 49 | if IoU(gt_object,output_object)>IoU_threshold: 50 | total_IoU += IoU(gt_object,output_object) 51 | if object_detected: 52 | duplicated_positives += 1 53 | else: 54 | object_detected = True 55 | scores["true_pos"] += 1 56 | 57 | if not object_detected: 58 | scores["false_neg"] += 1 59 | 60 | scores["mean_IoU"] = total_IoU/(scores["true_pos"] + duplicated_positives) if total_IoU else None 61 | scores["false_pos"] = max(len(output) - scores["true_pos"] - duplicated_positives,0) 62 | 63 | return scores 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('gt',type=str) 67 | 
parser.add_argument('output',type=str) 68 | parser.add_argument('--IoU_threshold',type=float) 69 | args = parser.parse_args() 70 | 71 | IoU_threshold = args.IoU_threshold if args.IoU_threshold else 0.5 72 | 73 | gt_path = args.gt if args.gt.endswith("/") else args.gt + "/" 74 | output_path = args.output if args.output.endswith("/") else args.output + "/" 75 | 76 | table_scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 77 | cell_scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 78 | table_no = 0 79 | cell_no = 0 80 | 81 | for file in os.listdir(output_path): 82 | if file in os.listdir(gt_path): 83 | output = xml_utils.load_ICDAR_xml(output_path + file) 84 | gt = xml_utils.load_ICDAR_xml(gt_path + file) 85 | 86 | output_tables = [] 87 | output_cells = [] 88 | gt_tables = [] 89 | gt_cells = [] 90 | 91 | for table in output: 92 | output_tables.append(table["region"]) 93 | output_cells += table["cells"] 94 | 95 | for table in gt: 96 | gt_tables.append(table["region"]) 97 | gt_cells += table["cells"] 98 | 99 | new_table_scores = calculate_scores(output_tables,gt_tables,IoU_threshold ) 100 | new_cell_scores =calculate_scores(output_cells,gt_cells,IoU_threshold ) 101 | 102 | for score_type in new_table_scores: 103 | table_scores[score_type] += new_table_scores[score_type] if new_table_scores[score_type] else 0 104 | cell_scores[score_type] += new_cell_scores[score_type] if new_cell_scores[score_type] else 0 105 | 106 | if new_table_scores["mean_IoU"]: 107 | table_no += 1 108 | if new_cell_scores["mean_IoU"]: 109 | cell_no += 1 110 | 111 | precision_tables = table_scores["true_pos"] / (table_scores["true_pos"] + table_scores["false_pos"]) if (table_scores["true_pos"] + table_scores["false_pos"])!=0 else None 112 | precision_cells = cell_scores["true_pos"] / (cell_scores["true_pos"] + cell_scores["false_pos"]) if (cell_scores["true_pos"] + cell_scores["false_pos"]) !=0 else None 113 | recall_tables = table_scores["true_pos"] / (table_scores["true_pos"] + table_scores["false_neg"]) if (table_scores["true_pos"] + table_scores["false_neg"]) !=0 else None 114 | recall_cells = cell_scores["true_pos"] / (cell_scores["true_pos"] + cell_scores["false_neg"]) if (cell_scores["true_pos"] + cell_scores["false_neg"]) !=0 else None 115 | 116 | print("Table:","Precision:",precision_tables,"Recall:",recall_tables,"Mean IoU:",table_scores["mean_IoU"]/table_no if table_no>0 else None) 117 | print("Cell:","Precision:",precision_cells,"Recall:",recall_cells,"Mean IoU:",cell_scores["mean_IoU"]/cell_no if cell_no>0 else None) -------------------------------------------------------------------------------- /dla/src/eval_ICDAR_wF1.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | try: 18 | import xml_utils as xml_utils 19 | except: 20 | import dla.src.xml_utils as xml_utils 21 | 22 | import argparse 23 | import os 24 | 25 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 26 | 27 | def IoU(box1,box2): 
28 | 29 | area1 = area(box1) 30 | area2 = area(box2) 31 | 32 | intersection_box = [max(box1[0],box2[0]), 33 | max(box1[1],box2[1]), 34 | min(box1[2],box2[2]), 35 | min(box1[3],box2[3])] 36 | 37 | intersection_area = area(intersection_box) 38 | 39 | return intersection_area/(area1 + area2 - intersection_area) 40 | 41 | def calculate_scores(output,gt,IoU_threshold): 42 | scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 43 | duplicated_positives = 0 44 | total_IoU = 0 45 | 46 | for gt_object in gt: 47 | object_detected = False 48 | for output_object in output: 49 | if IoU(gt_object,output_object)>IoU_threshold: 50 | total_IoU += IoU(gt_object,output_object) 51 | if object_detected: 52 | duplicated_positives += 1 53 | else: 54 | object_detected = True 55 | scores["true_pos"] += 1 56 | 57 | if not object_detected: 58 | scores["false_neg"] += 1 59 | 60 | scores["mean_IoU"] = total_IoU/(scores["true_pos"] + duplicated_positives) if total_IoU else None 61 | scores["false_pos"] = max(len(output) - scores["true_pos"] - duplicated_positives,0) 62 | 63 | return scores 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('gt',type=str) 67 | parser.add_argument('output',type=str) 68 | 69 | args = parser.parse_args() 70 | 71 | 72 | gt_path = args.gt if args.gt.endswith("/") else args.gt + "/" 73 | output_path = args.output if args.output.endswith("/") else args.output + "/" 74 | 75 | table_scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 76 | cell_scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 77 | table_no = 0 78 | cell_no = 0 79 | F1scores = 0 80 | IoU_thresholds = {0.6,0.7,0.8,0.9} 81 | 82 | for IoU_threshold in IoU_thresholds: 83 | for file in os.listdir(output_path): 84 | if file in os.listdir(gt_path): 85 | output = xml_utils.load_ICDAR_xml(output_path + file) 86 | gt = xml_utils.load_ICDAR_xml(gt_path + file) 87 | 88 | output_tables = [] 89 | output_cells = [] 90 | gt_tables = [] 91 | gt_cells = [] 92 | 93 | for table in output: 94 | output_tables.append(table["region"]) 95 | output_cells += table["cells"] 96 | 97 | for table in gt: 98 | gt_tables.append(table["region"]) 99 | gt_cells += table["cells"] 100 | 101 | new_table_scores = calculate_scores(output_tables,gt_tables,IoU_threshold ) 102 | new_cell_scores =calculate_scores(output_cells,gt_cells,IoU_threshold ) 103 | 104 | for score_type in new_table_scores: 105 | table_scores[score_type] += new_table_scores[score_type] if new_table_scores[score_type] else 0 106 | cell_scores[score_type] += new_cell_scores[score_type] if new_cell_scores[score_type] else 0 107 | 108 | if new_table_scores["mean_IoU"]: 109 | table_no += 1 110 | if new_cell_scores["mean_IoU"]: 111 | cell_no += 1 112 | 113 | 114 | precision = cell_scores["true_pos"] / (cell_scores["true_pos"] + cell_scores["false_pos"]) if (cell_scores["true_pos"] + cell_scores["false_pos"]) !=0 else None 115 | recall = cell_scores["true_pos"] / (cell_scores["true_pos"] + cell_scores["false_neg"]) if (cell_scores["true_pos"] + cell_scores["false_neg"]) !=0 else None 116 | 117 | F1scores += 2 * precision * recall / (precision + recall) * IoU_threshold if precision!=None and recall!=None else 0 118 | 119 | print("F1 scores",F1scores/sum(IoU_thresholds)) -------------------------------------------------------------------------------- /dla/src/eval_VOC.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 
###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | CLASSES = ("heading","header","full_table","table_body","cell") 18 | 19 | try: 20 | import xml_utils as xml_utils 21 | except: 22 | import dla.src.xml_utils as xml_utils 23 | 24 | import argparse 25 | import os 26 | 27 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 28 | 29 | def IoU(box1,box2): 30 | 31 | box1 = [box1[0],box1[1],(box1[2] + 1) if box1[0]==box1[2] else box1[2],(box1[3] + 1) if box1[1]==box1[3] else box1[3]] 32 | 33 | area1 = area(box1) 34 | area2 = area(box2) 35 | 36 | intersection_box = [max(box1[0],box2[0]), 37 | max(box1[1],box2[1]), 38 | min(box1[2],box2[2]), 39 | min(box1[3],box2[3])] 40 | 41 | intersection_area = area(intersection_box) 42 | 43 | return intersection_area/(area1 + area2 - intersection_area) 44 | 45 | def calculate_scores(output,gt,IoU_threshold): 46 | scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 47 | duplicated_positives = 0 48 | total_IoU = 0 49 | 50 | for gt_object in gt: 51 | object_detected = False 52 | for output_object in output: 53 | if IoU(gt_object,output_object)>IoU_threshold: 54 | total_IoU += IoU(gt_object,output_object) 55 | if object_detected: 56 | duplicated_positives += 1 57 | else: 58 | object_detected = True 59 | scores["true_pos"] += 1 60 | 61 | if not object_detected: 62 | scores["false_neg"] += 1 63 | 64 | scores["mean_IoU"] = total_IoU/(scores["true_pos"] + duplicated_positives) if total_IoU else None 65 | scores["false_pos"] = max(len(output) - scores["true_pos"] - duplicated_positives,0) 66 | 67 | return scores 68 | 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('gt',type=str) 71 | parser.add_argument('output',type=str) 72 | parser.add_argument('--IoU_threshold',type=float) 73 | args = parser.parse_args() 74 | 75 | IoU_threshold = args.IoU_threshold if args.IoU_threshold else 0.5 76 | 77 | gt_path = args.gt if args.gt.endswith("/") else args.gt + "/" 78 | output_path = args.output if args.output.endswith("/") else args.output + "/" 79 | 80 | scores = {c : {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} for c in CLASSES} 81 | no = {c : 0 for c in CLASSES} 82 | 83 | 84 | for file in os.listdir(output_path): 85 | if file in os.listdir(gt_path): 86 | output = xml_utils.load_VOC_xml(output_path + file) 87 | gt = xml_utils.load_VOC_xml(gt_path + file) 88 | 89 | output_objects = {c:[] for c in CLASSES} 90 | gt_objects = {c:[] for c in CLASSES} 91 | 92 | for obj in output: 93 | output_objects[obj["name"]].append(obj["bbox"]) 94 | 95 | for obj in gt: 96 | gt_objects[obj["name"]].append(obj["bbox"]) 97 | 98 | for class_ in CLASSES: 99 | new_scores = calculate_scores(output_objects[class_],gt_objects[class_],IoU_threshold) 100 | 101 | for score_type in new_scores: 102 | scores[class_][score_type] += new_scores[score_type] if new_scores[score_type] else 0 103 | 104 | if new_scores["mean_IoU"]: 105 | no[class_] += 1 106 | 107 | 108 | for class_ in CLASSES: 109 | precision = scores[class_]["true_pos"] / (scores[class_]["true_pos"] + scores[class_]["false_pos"]) if (scores[class_]["true_pos"] + 
scores[class_]["false_pos"])!=0 else None 110 | 111 | recall = scores[class_]["true_pos"] / (scores[class_]["true_pos"] + scores[class_]["false_neg"]) if (scores[class_]["true_pos"] + scores[class_]["false_neg"]) !=0 else None 112 | 113 | 114 | print(class_,"Precision:",precision,"Recall:",recall,"Mean IoU:",scores[class_]["mean_IoU"]/no[class_] if no[class_]>0 else None) 115 | -------------------------------------------------------------------------------- /dla/src/eval_VOC_wF1.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | CLASSES = ("heading","header","full_table","table_body","cell") 18 | 19 | try: 20 | import xml_utils as xml_utils 21 | except: 22 | import dla.src.xml_utils as xml_utils 23 | 24 | import argparse 25 | import os 26 | import collections 27 | 28 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 29 | 30 | def IoU(box1,box2): 31 | 32 | box1 = [box1[0],box1[1],(box1[2] + 1) if box1[0]==box1[2] else box1[2],(box1[3] + 1) if box1[1]==box1[3] else box1[3]] 33 | 34 | area1 = area(box1) 35 | area2 = area(box2) 36 | 37 | intersection_box = [max(box1[0],box2[0]), 38 | max(box1[1],box2[1]), 39 | min(box1[2],box2[2]), 40 | min(box1[3],box2[3])] 41 | 42 | intersection_area = area(intersection_box) 43 | 44 | return intersection_area/(area1 + area2 - intersection_area) 45 | 46 | def calculate_scores(output,gt,IoU_threshold): 47 | scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 48 | duplicated_positives = 0 49 | total_IoU = 0 50 | 51 | for gt_object in gt: 52 | object_detected = False 53 | for output_object in output: 54 | if IoU(gt_object,output_object)>IoU_threshold: 55 | total_IoU += IoU(gt_object,output_object) 56 | if object_detected: 57 | duplicated_positives += 1 58 | else: 59 | object_detected = True 60 | scores["true_pos"] += 1 61 | 62 | if not object_detected: 63 | scores["false_neg"] += 1 64 | 65 | scores["mean_IoU"] = total_IoU/(scores["true_pos"] + duplicated_positives) if total_IoU else None 66 | scores["false_pos"] = max(len(output) - scores["true_pos"] - duplicated_positives,0) 67 | 68 | return scores 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('gt',type=str) 72 | parser.add_argument('output',type=str) 73 | args = parser.parse_args() 74 | 75 | IoU_thresholds = {0.6,0.7,0.8,0.9} 76 | 77 | F1scores = collections.defaultdict(lambda:0) 78 | 79 | for IoU_threshold in IoU_thresholds: 80 | 81 | gt_path = args.gt if args.gt.endswith("/") else args.gt + "/" 82 | output_path = args.output if args.output.endswith("/") else args.output + "/" 83 | 84 | scores = {c : {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} for c in CLASSES} 85 | no = {c : 0 for c in CLASSES} 86 | 87 | 88 | for file in os.listdir(output_path): 89 | if file in os.listdir(gt_path): 90 | output = xml_utils.load_VOC_xml(output_path + file) 91 | gt = xml_utils.load_VOC_xml(gt_path + file) 92 | 93 | output_objects = {c:[] for c in CLASSES} 94 | gt_objects = {c:[] for c in 
CLASSES} 95 | 96 | for obj in output: 97 | output_objects[obj["name"]].append(obj["bbox"]) 98 | 99 | for obj in gt: 100 | gt_objects[obj["name"]].append(obj["bbox"]) 101 | 102 | for class_ in CLASSES: 103 | new_scores = calculate_scores(output_objects[class_],gt_objects[class_],IoU_threshold) 104 | 105 | for score_type in new_scores: 106 | scores[class_][score_type] += new_scores[score_type] if new_scores[score_type] else 0 107 | 108 | if new_scores["mean_IoU"]: 109 | no[class_] += 1 110 | 111 | #if class_ == "full_table": 112 | # if new_scores["false_neg"]>1 or new_scores["false_pos"]>1: 113 | # print(file) 114 | 115 | #print("\n IoU threshold:",IoU_threshold,"\n") 116 | for class_ in CLASSES: 117 | precision = scores[class_]["true_pos"] / (scores[class_]["true_pos"] + scores[class_]["false_pos"]) if (scores[class_]["true_pos"] + scores[class_]["false_pos"])!=0 else None 118 | 119 | recall = scores[class_]["true_pos"] / (scores[class_]["true_pos"] + scores[class_]["false_neg"]) if (scores[class_]["true_pos"] + scores[class_]["false_neg"]) !=0 else None 120 | 121 | 122 | #print(class_,"Precision:",precision,"Recall:",recall,"Mean IoU:",scores[class_]["mean_IoU"]/no[class_] if no[class_]>0 else None) 123 | 124 | F1scores[class_] += 2 * precision * recall / (precision + recall) * IoU_threshold if precision!=None and recall!=None else 0 125 | 126 | print("\n\n") 127 | for class_ in CLASSES: 128 | print(class_,"Avg.W.F1",F1scores[class_]/sum(IoU_thresholds)) 129 | -------------------------------------------------------------------------------- /dla/src/eval_rows_n_cols_only.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | try: 18 | import xml_utils as xml_utils 19 | except: 20 | import dla.src.xml_utils as xml_utils 21 | 22 | import argparse 23 | import os 24 | 25 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 26 | 27 | def IoU(box1,box2): 28 | 29 | area1 = area(box1) 30 | area2 = area(box2) 31 | 32 | if area1==0 and area2==0: 33 | return 0 34 | 35 | intersection_box = [max(box1[0],box2[0]), 36 | max(box1[1],box2[1]), 37 | min(box1[2],box2[2]), 38 | min(box1[3],box2[3])] 39 | 40 | intersection_area = area(intersection_box) 41 | 42 | return intersection_area/(area1 + area2 - intersection_area) 43 | 44 | def calculate_scores(output,gt,IoU_threshold): 45 | scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 46 | duplicated_positives = 0 47 | total_IoU = 0 48 | 49 | for gt_object in gt: 50 | object_detected = False 51 | for output_object in output: 52 | if IoU(gt_object,output_object)>IoU_threshold: 53 | total_IoU += IoU(gt_object,output_object) 54 | if object_detected: 55 | duplicated_positives += 1 56 | else: 57 | object_detected = True 58 | scores["true_pos"] += 1 59 | 60 | if not object_detected: 61 | scores["false_neg"] += 1 62 | 63 | scores["mean_IoU"] = total_IoU/(scores["true_pos"] + duplicated_positives) if total_IoU else None 64 | scores["false_pos"] = max(len(output) - 
scores["true_pos"] - duplicated_positives,0) 65 | 66 | return scores 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('gt',type=str) 70 | parser.add_argument('output',type=str) 71 | parser.add_argument('--IoU_threshold',type=float) 72 | args = parser.parse_args() 73 | 74 | IoU_threshold = args.IoU_threshold if args.IoU_threshold else 0.5 75 | 76 | gt_path = args.gt if args.gt.endswith("/") else args.gt + "/" 77 | output_path = args.output if args.output.endswith("/") else args.output + "/" 78 | 79 | table_scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 80 | row_scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 81 | col_scores = {"true_pos":0,"false_pos":0,"false_neg":0,"mean_IoU":0} 82 | table_no = 0 83 | row_no = 0 84 | col_no = 0 85 | 86 | PROXIMITY_FACTOR = 5 87 | 88 | 89 | for file in os.listdir(output_path): 90 | if file in os.listdir(gt_path): 91 | output = xml_utils.load_ICDAR_xml_lines(output_path + file) 92 | gt = xml_utils.load_ICDAR_xml_lines(gt_path + file) 93 | 94 | output_tables = [] 95 | 96 | gt_tables = [] 97 | 98 | total_rows = 0 99 | total_cols = 0 100 | foundtable_rows = 0 101 | foundtable_cols = 0 102 | 103 | for table in output: 104 | output_tables.append(table["region"]) 105 | total_rows += len(table["rows"][1:]) 106 | total_cols += len(table["cols"][1:]) 107 | 108 | for table in gt: 109 | gt_tables.append(table["region"]) 110 | 111 | new_table_scores = calculate_scores(output_tables,gt_tables,IoU_threshold) 112 | 113 | for score_type in new_table_scores: 114 | table_scores[score_type] += new_table_scores[score_type] if new_table_scores[score_type] else 0 115 | 116 | if new_table_scores["mean_IoU"]: 117 | table_no += 1 118 | 119 | for gt_table in gt: 120 | match_found = False 121 | for out_table in output: 122 | if IoU(out_table["region"],gt_table["region"])>IoU_threshold: 123 | 124 | match_found = True 125 | 126 | gt_table["rows"] = gt_table["rows"][1:] 127 | out_table["rows"] = out_table["rows"][1:] 128 | gt_table["cols"] = gt_table["cols"][1:] 129 | out_table["cols"] = out_table["cols"][1:] 130 | 131 | foundtable_rows += len(out_table["rows"]) 132 | foundtable_cols += len(out_table["cols"]) 133 | 134 | row_proximity = (gt_table["region"][3] - gt_table["region"][1]) /(len(gt_table["rows"])+2) / PROXIMITY_FACTOR 135 | col_proximity = (gt_table["region"][2] - gt_table["region"][0]) / (len(gt_table["cols"])+2)/ PROXIMITY_FACTOR 136 | 137 | gt_rows = [[gt_table["region"][0],row - row_proximity,gt_table["region"][2],row + row_proximity] for row in gt_table["rows"]] 138 | gt_cols = [[col - col_proximity,gt_table["region"][1],col + col_proximity,gt_table["region"][3]] for col in gt_table["cols"]] 139 | 140 | out_rows = [[gt_table["region"][0],row - row_proximity,gt_table["region"][2],row + row_proximity] for row in out_table["rows"]] 141 | out_cols = [[col - col_proximity,gt_table["region"][1],col + col_proximity,gt_table["region"][3]] for col in out_table["cols"]] 142 | 143 | new_row_scores = calculate_scores(out_rows,gt_rows,IoU_threshold) 144 | new_col_scores = calculate_scores(out_cols,gt_cols,IoU_threshold) 145 | 146 | for score_type in new_row_scores: 147 | row_scores[score_type] += new_row_scores[score_type] if new_row_scores[score_type] else 0 148 | 149 | for score_type in new_col_scores: 150 | col_scores[score_type] += new_col_scores[score_type] if new_col_scores[score_type] else 0 151 | 152 | if new_row_scores["mean_IoU"]: 153 | row_no += 1 154 | 155 | if new_col_scores["mean_IoU"]: 156 | col_no += 1 157 | 
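# --- Illustrative note (not part of the original source): each predicted or ground-truth
# --- row/column separator is a single coordinate, so it is widened into a thin band before
# --- IoU matching. With hypothetical numbers, a ground-truth table spanning y = 100..600
# --- with 9 remaining row separators gives
# --- row_proximity = (600 - 100) / (9 + 2) / PROXIMITY_FACTOR ≈ 9.1 px,
# --- so a separator at y = 250 is matched as the box [x1, 240.9, x2, 259.1].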
break 158 | 159 | if not(match_found): 160 | row_scores["false_neg"] += len(gt_table["rows"][1:]) 161 | col_scores["false_neg"] += len(gt_table["cols"][1:]) 162 | 163 | row_scores["false_pos"] += total_rows - foundtable_rows 164 | col_scores["false_pos"] += total_cols - foundtable_cols 165 | 166 | 167 | 168 | precision_tables = table_scores["true_pos"] / (table_scores["true_pos"] + table_scores["false_pos"]) if (table_scores["true_pos"] + table_scores["false_pos"])!=0 else None 169 | precision_rows = row_scores["true_pos"] / (row_scores["true_pos"] + row_scores["false_pos"]) if (row_scores["true_pos"] + row_scores["false_pos"]) !=0 else None 170 | precision_cols = col_scores["true_pos"] / (col_scores["true_pos"] + col_scores["false_pos"]) if (col_scores["true_pos"] + col_scores["false_pos"]) !=0 else None 171 | recall_tables = table_scores["true_pos"] / (table_scores["true_pos"] + table_scores["false_neg"]) if (table_scores["true_pos"] + table_scores["false_neg"]) !=0 else None 172 | recall_rows = row_scores["true_pos"] / (row_scores["true_pos"] + row_scores["false_neg"]) if (row_scores["true_pos"] + row_scores["false_neg"]) !=0 else None 173 | recall_cols = col_scores["true_pos"] / (col_scores["true_pos"] + col_scores["false_neg"]) if (col_scores["true_pos"] + col_scores["false_neg"]) !=0 else None 174 | 175 | print("Table:","Precision:",precision_tables,"Recall:",recall_tables,"Mean IoU:",table_scores["mean_IoU"]/table_no if table_no else None) 176 | print("row:","Precision:",precision_rows,"Recall:",recall_rows,"Mean IoU:",row_scores["mean_IoU"]/row_no if row_no else None) 177 | print("col:","Precision:",precision_cols,"Recall:",recall_cols,"Mean IoU:",col_scores["mean_IoU"]/col_no if col_no else None) 178 | print("row F1",2*precision_rows*recall_rows/(precision_rows+recall_rows),"col F1",2*precision_cols*recall_cols/(precision_cols+recall_cols)) -------------------------------------------------------------------------------- /dla/src/get_dataset_summary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import xml_utils as xml_utils 4 | import pandas as pd 5 | 6 | summaries = [] 7 | dataset_root = '/media/DATA/GloSAT_dataset_fine' 8 | files = set() 9 | 10 | for source in os.listdir( dataset_root ) : 11 | 12 | cells = 0 13 | tables = 0 14 | images = 0 15 | header_cells = 0 16 | headers = 0 17 | table_per_type = collections.defaultdict(lambda:0) 18 | docs_per_type = collections.defaultdict(lambda:0) 19 | 20 | if "labels" not in os.listdir(os.path.join(dataset_root,source)): 21 | continue 22 | 23 | for file in os.listdir(os.path.join(dataset_root,source,"labels")): 24 | if file.endswith(".xml"): 25 | 26 | if file.strip(".xml") + ".jpg" not in os.listdir(os.path.join(dataset_root,source)) and file.strip(".xml") + ".JPG" not in os.listdir(os.path.join(dataset_root,source)): 27 | continue 28 | 29 | images += 1 30 | 31 | icdar_parsed, doc_type = xml_utils.get_ICDAR_summary( os.path.join(dataset_root,source,"labels",file) ) 32 | 33 | docs_per_type[doc_type] += 1 34 | 35 | for entry in icdar_parsed: 36 | tables += 1 37 | cells += len(entry["cells"]) 38 | header_cells += entry["header_no"] 39 | headers += entry["header_no"]!=0 40 | table_per_type[entry["type"]] += 1 41 | 42 | summary = pd.DataFrame([[source,images,cells,tables,header_cells,headers]],columns=["source","images","cells","tables","header_cells","headers"]) 43 | 44 | for table_type in pd.unique(list(table_per_type.keys())): 45 | summary[table_type + " tables"] 
= table_per_type[table_type] 46 | 47 | for doc_type in pd.unique(list(docs_per_type.keys())): 48 | summary[doc_type + " docs"] = docs_per_type[doc_type] 49 | 50 | summaries.append(summary) 51 | 52 | summaries = pd.concat(summaries).fillna(0) 53 | 54 | 55 | print("\n\n",summaries) 56 | 57 | #summaries.to_csv("dataset_summary_fine") 58 | 59 | print("\n\n",summaries.sum(axis=0)) -------------------------------------------------------------------------------- /dla/src/image_utils.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | import cv2 18 | import numpy as np 19 | from sklearn.cluster import KMeans 20 | 21 | #Ratios that indicate an edge 22 | VERTICAL_EDGE = 1/5 23 | HORIZONTAL_EDGE = 10 24 | CELL_DIM_THRESHOLD = 0.05 25 | 26 | #Character or cell decision 27 | MIN_THRESHOLD = 0.01 28 | MAX_CELL_AREA = 0.25 29 | 30 | def put_box(image,box,colour,text=None,thickness=2): 31 | image= cv2.rectangle(image,((int)(box[0]),(int)(box[1])),((int)(box[2]),(int)(box[3])),colour,thickness=thickness) 32 | 33 | if text: 34 | image = cv2.putText(image,text,((int)(box[0]),(int)(box[1])),cv2.FONT_HERSHEY_SIMPLEX,fontScale = 1,color = colour,thickness = thickness) 35 | 36 | return image 37 | 38 | def put_line(image,start,end,colour,thickness=3): 39 | image = cv2.line(image,start,end,color=colour,thickness=thickness) 40 | return image 41 | 42 | def erosion(image,kernel_size=3,iters = 1): 43 | image = cv2.erode(image.astype(np.uint8), np.ones((kernel_size,kernel_size), np.uint8) , iterations=iters) 44 | return image 45 | 46 | def find_contours(image,mode=cv2.RETR_EXTERNAL): 47 | 48 | ret, thresh = cv2.threshold(image.astype(np.uint8), 50, 255, 0) 49 | contours, hierarchy = cv2.findContours(thresh, mode, cv2.CHAIN_APPROX_TC89_KCOS) 50 | 51 | boxes = [] 52 | for cnt in contours: 53 | x,y,w,h = cv2.boundingRect(cnt) 54 | boxes.append([x,y,x+w,y+h]) 55 | 56 | return boxes 57 | 58 | def filter_pixels(labels_im,label): 59 | ''' 60 | Filter the mask, removing the pixels not belonging to the layer. 
61 | ''' 62 | 63 | filtered_im = np.zeros_like(labels_im) 64 | 65 | filtered_im[labels_im==label] = 255 66 | 67 | return filtered_im 68 | 69 | def normalise_contrast(img): 70 | ''' 71 | Performs local contrast normalisation 72 | ''' 73 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)) 74 | img_normalised = clahe.apply(img) 75 | 76 | return img_normalised 77 | 78 | def preprocess(image): 79 | image = normalise_contrast(image) 80 | image = cv2.threshold(image, 120, 255, cv2.THRESH_BINARY)[1] 81 | return image 82 | 83 | def postprocess(region,image_shape,area_threshold=0.001,dim_threshold=0.02): 84 | area = lambda box: (box[3] - box[1]) * (box[2] - box[0]) 85 | 86 | if area(region)< area((0,0) + image_shape) * area_threshold: 87 | return False 88 | 89 | x1,y1,x2,y2 = region 90 | h,w = image_shape 91 | 92 | if x2-x1MAX_CELL_AREA*image_shape[0] * image_shape[1]: 167 | continue 168 | 169 | if area(c)>=threshold : 170 | cells.append(c) 171 | 172 | else: 173 | characters.append(c) 174 | 175 | return cells,characters 176 | 177 | def find_background_contours(img_original,mode=cv2.RETR_LIST): 178 | img = cv2.erode(img_original, np.ones((3,3), np.uint8) , iterations=1) 179 | img = cv2.threshold(img, 125, 255, cv2.THRESH_BINARY)[1] 180 | 181 | num_labels, labels_im = cv2.connectedComponents(img) 182 | 183 | boxes = find_contours(invert_image(labels_im),mode=mode) 184 | cells_or_characters = [] 185 | 186 | for no,box in enumerate(boxes): 187 | 188 | #Do not process boxes outside the image 189 | if any(box[i]>=img_original.shape[(i+1)%2] for i in range(4)): 190 | continue 191 | 192 | #Check if box is not an edge 193 | if not(((box[2]-box[0])/(box[3]-box[1]) > HORIZONTAL_EDGE and (box[3]-box[1])0] = -1 201 | image += 1 202 | image *= 255 203 | return image 204 | 205 | def run_kmeans_areas(areas): 206 | kmeans = KMeans(n_clusters=2, random_state=0).fit(areas) 207 | return np.mean(kmeans.cluster_centers_) 208 | 209 | def non_text_removal(image): 210 | image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY) 211 | cells_or_characters = find_background_contours(image) 212 | _, characters = divide_cells_or_characters(cells_or_characters,image.shape) 213 | mask = np.zeros_like(image) 214 | 215 | for box in characters: 216 | mask[box[1]:box[3],box[0]:box[2]] = 1 217 | 218 | image *= mask 219 | 220 | image[image==0] = 255 221 | 222 | return image 223 | -------------------------------------------------------------------------------- /dla/src/inference.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | from mmdet.apis import init_detector, inference_detector 18 | import mmcv 19 | import dla.src.table_structure_analysis as tsa 20 | import dla.src.xml_utils as xml_utils 21 | from dla.src.image_utils import put_box, put_line 22 | import argparse 23 | import os 24 | import cv2 25 | import collections 26 | import numpy as np 27 | 28 | THRESHOLD = 0.5 29 | CLASSES = ("table_body","cell","full_table","header","heading") 30 | 31 | parser = argparse.ArgumentParser() 32 | 
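# --- Assumed usage sketch (paths are hypothetical, not taken from the repo): given the
# --- arguments registered below, inference might be launched roughly as
# ---   python dla/src/inference.py table_model.pth --cell_checkpoint cell_model.pth \
# ---       --load_from images/ --out results/ --visual True
# --- from the repository root (config_file is the relative path dla/config/cascadeRCNN.py,
# --- and the repo root needs to be on PYTHONPATH so the dla.* imports resolve).
# --- Output paths are built by string concatenation, so --out should end with "/";
# --- per-image XML is written there (ICDAR-style, or VOC-style with --voc True) and,
# --- with --visual, an annotated image named out_<image name>.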
parser.add_argument('table_checkpoint',type=str) 33 | parser.add_argument('--cell_checkpoint',type=str,dest='cell_checkpoint') 34 | parser.add_argument('--coarse_cell_checkpoint',type=str,dest='coarse_cell_checkpoint') 35 | parser.add_argument('--load_from',type=str,dest='load_from') 36 | parser.add_argument('--out',type=str,dest='out') 37 | parser.add_argument('--voc',type=bool,dest='voc') 38 | parser.add_argument('--visual',type=bool,dest='visual') 39 | parser.add_argument('--raw_cells',type=bool,dest='raw_cells') 40 | parser.add_argument('--skip_headers',type=bool,dest='skip_headers') 41 | args = parser.parse_args() 42 | 43 | # Load model 44 | table_checkpoint_file = args.table_checkpoint 45 | cell_checkpoint_file = args.cell_checkpoint 46 | coarse_cell_checkpoint_file = args.coarse_cell_checkpoint 47 | path = args.load_from 48 | save_to = args.out 49 | raw_cells = args.raw_cells 50 | image_list = os.listdir(path) 51 | config_file = "dla/config/cascadeRCNN.py" 52 | segment_headers = args.skip_headers 53 | 54 | cells_by_image = collections.defaultdict(list) 55 | coarse_cells_by_image = collections.defaultdict(list) 56 | 57 | if cell_checkpoint_file: 58 | model = init_detector(config_file, cell_checkpoint_file, device='cuda:0') 59 | 60 | for image_name in image_list: 61 | 62 | if args.visual: 63 | image = cv2.imread(os.path.join(path,image_name)) 64 | 65 | width, height,_ = image.shape 66 | 67 | # Run Inference 68 | result = inference_detector(model, os.path.join(path,image_name)) 69 | 70 | #Process cells 71 | cells_by_image[image_name] = result[CLASSES.index("cell")].tolist() 72 | 73 | del model 74 | 75 | if coarse_cell_checkpoint_file: 76 | model = init_detector(config_file, coarse_cell_checkpoint_file, device='cuda:0') 77 | 78 | 79 | for image_name in image_list: 80 | 81 | if args.visual: 82 | image = cv2.imread(os.path.join(path,image_name)) 83 | 84 | width, height,_ = image.shape 85 | 86 | # Run Inference 87 | result = inference_detector(model, os.path.join(path,image_name)) 88 | 89 | #Process cells 90 | coarse_cells_by_image[image_name] += result[CLASSES.index("cell")].tolist() 91 | 92 | del model 93 | 94 | model = init_detector(config_file, table_checkpoint_file, device='cuda:0') 95 | 96 | for image_name in image_list: 97 | 98 | if args.visual: 99 | image = cv2.imread(os.path.join(path,image_name)) 100 | 101 | width, height,_ = image.shape 102 | else: 103 | width, height = 1000, 1000 104 | 105 | # Run Inference 106 | result = inference_detector(model,os.path.join(path,image_name)) 107 | 108 | #Process table headings 109 | headings = [] 110 | for box in result[CLASSES.index("heading")]: 111 | if box[4]>THRESHOLD : 112 | headings.append(box[0:4]) 113 | 114 | if args.visual: 115 | put_box(image,box,(255,0,255),"heading") 116 | 117 | #Process table headers 118 | headers = [] 119 | for box in result[CLASSES.index("header")]: 120 | if box[4]>THRESHOLD : 121 | headers.append(box[0:4]) 122 | 123 | if args.visual: 124 | put_box(image,box,(255,0,0),"header") 125 | 126 | #Process table bodies 127 | tables = [] 128 | for box in result[CLASSES.index("table_body")]: 129 | if box[4]>THRESHOLD : 130 | tables.append(box[0:4]) 131 | 132 | if args.visual: 133 | put_box(image,box,(0,0,255),"table_body") 134 | 135 | #Process tables 136 | full_tables = [] 137 | for box in result[CLASSES.index("full_table")]: 138 | if box[4]>THRESHOLD : 139 | full_tables.append(box[0:4]) 140 | 141 | if all(tsa.how_much_contained(table,box)<0.5 for table in tables): 142 | tables.append(box[0:4]) 143 | if args.visual: 144 
| put_box(image,box,(0,0,255),"table_body") 145 | 146 | if args.visual: 147 | put_box(image,box,(0,255,255),"Full_Table") 148 | 149 | for table in tables: 150 | if all(tsa.how_much_contained(table,full_table)<0.5 for full_table in full_tables): 151 | full_tables.append(table) 152 | 153 | 154 | if raw_cells: 155 | cells = [] 156 | for box in cells_by_image[image_name]: 157 | if box[4]>THRESHOLD : 158 | cells.append(box[0:4]) 159 | 160 | if args.visual: 161 | put_box(image,box,(0,0,255),"cell") 162 | 163 | xml_utils.save_VOC_xml_from_cells(headings,headers,tables,full_tables,cells,save_to + image_name.split(".")[-2] + ".xml",width,height) 164 | else: 165 | #Process cells 166 | 167 | rows_by_table = [] 168 | cols_by_table = [] 169 | full_table_by_table = [] 170 | 171 | for table in tables: 172 | cells = [] 173 | coarse_cells = [] 174 | 175 | found_fulltable = False 176 | for full_table in full_tables: 177 | if tsa.how_much_contained(table,full_table)>0.5: 178 | full_table_by_table.append(full_table) 179 | found_fulltable = True 180 | break 181 | 182 | if not found_fulltable: 183 | full_table_by_table.append(table) 184 | 185 | 186 | if cell_checkpoint_file: 187 | for box in cells_by_image[image_name]: 188 | cell = box[0:4] 189 | 190 | if box[4]>THRESHOLD: 191 | if tsa.how_much_contained(cell,table if not segment_headers else full_table_by_table[-1])>0.5: 192 | cells.append(cell) 193 | #if args.visual: 194 | # put_box(image,box,(0,0,0)) 195 | 196 | if coarse_cell_checkpoint_file: 197 | for box in coarse_cells_by_image[image_name]: 198 | cell = box[0:4] 199 | 200 | if box[4]>THRESHOLD: 201 | if tsa.how_much_contained(cell,table if not segment_headers else full_table_by_table[-1])>0.5: 202 | coarse_cells.append(cell) 203 | #if args.visual: 204 | # put_box(image,box,(0,0,0)) 205 | 206 | if cells != [] or coarse_cells!=[]: 207 | if coarse_cell_checkpoint_file: 208 | rows, cols = tsa.reconstruct_table_coarse_and_fine(coarse_cells,cells,table if segment_headers else full_table_by_table[-1],eps=0.02) 209 | else: 210 | rows, cols = tsa.reconstruct_table(cells,table if segment_headers else full_table_by_table[-1],eps=0.02) 211 | else: 212 | rows,cols = [],[] 213 | 214 | if args.visual: 215 | for row in rows: 216 | put_line(image,((int)(table[0]),(int)(row)),((int)(table[2]),(int)(row)),colour=(0,255,0)) 217 | 218 | for col in cols: 219 | put_line(image,((int)(col),(int)(table[1])),((int)(col),(int)(table[3])),colour=(0,255,0)) 220 | 221 | rows_by_table.append(rows) 222 | cols_by_table.append(cols) 223 | 224 | if not(args.voc): 225 | xml_utils.save_ICDAR_xml(full_table_by_table,cols_by_table,rows_by_table,save_to + image_name.split(".")[-2] + ".xml") 226 | else: 227 | xml_utils.save_VOC_xml(headings,headers,tables,full_tables,cols_by_table,rows_by_table,save_to + image_name.split(".")[-2] + ".xml",width,height) 228 | 229 | if args.visual: 230 | cv2.imwrite(save_to + "out_%s"%(image_name),image) 231 | -------------------------------------------------------------------------------- /dla/src/inference_original.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : 
GloSAT 14 | # 15 | ###################################################################### 16 | 17 | from mmdet.apis import init_detector, inference_detector 18 | import mmcv 19 | import dla.src.table_structure_analysis as tsa 20 | import dla.src.xml_utils as xml_utils 21 | from dla.src.image_utils import put_box, put_line 22 | import argparse 23 | import os 24 | import cv2 25 | 26 | THRESHOLD = 0.5 27 | CLASSES = ("bordered_table","cell","borderless_table") 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('table_checkpoint',type=str) 31 | parser.add_argument('--load_from',type=str,dest='load_from') 32 | parser.add_argument('--use_cells',type=bool,dest='use_cells',default=False) 33 | parser.add_argument('--out',type=str,dest='out') 34 | parser.add_argument('--voc',type=bool,dest='voc') 35 | parser.add_argument('--visual',type=bool,dest='visual') 36 | args = parser.parse_args() 37 | 38 | # Load model 39 | table_checkpoint_file = args.table_checkpoint 40 | path = args.load_from 41 | save_to = args.out 42 | image_list = os.listdir(path) 43 | config_file = "dla/config/cascadeRCNN.py" 44 | 45 | 46 | model = init_detector(config_file, table_checkpoint_file, device='cuda:0') 47 | 48 | for image_name in image_list: 49 | 50 | if args.visual: 51 | image = cv2.imread(os.path.join(path,image_name)) 52 | 53 | width, height,_ = image.shape 54 | 55 | # Run Inference 56 | result = inference_detector(model,os.path.join(path,image_name)) 57 | 58 | #Process tables 59 | bordered_tables = [] 60 | for box in result[CLASSES.index("bordered_table")]: 61 | if box[4]>THRESHOLD : 62 | bordered_tables.append(box[0:4].tolist()) 63 | 64 | if args.visual: 65 | put_box(image,box,(0,255,255),"Full_Table") 66 | 67 | borderless_tables = [] 68 | for box in result[CLASSES.index("borderless_table")]: 69 | if box[4]>THRESHOLD : 70 | borderless_tables.append(box[0:4].tolist()) 71 | 72 | if args.visual: 73 | put_box(image,box,(0,255,255),"Full_Table") 74 | 75 | full_tables = [] 76 | 77 | #If borderless table and bordered table are both detected, save only the one with greater area 78 | for bordered_table in bordered_tables: 79 | borderless_match = None 80 | for borderless_table in borderless_tables: 81 | if tsa.IoU(bordered_table,borderless_table)>0.5: 82 | if tsa.area(borderless_table)>tsa.area(bordered_table): 83 | borderless_match = borderless_table 84 | 85 | if borderless_match: 86 | full_tables.append(borderless_match) 87 | else: 88 | full_tables.append(bordered_table) 89 | 90 | for borderless_table in borderless_tables: 91 | bordered_match = None 92 | for bordered_table in bordered_tables: 93 | if tsa.IoU(borderless_table,bordered_table)>0.5: 94 | if tsa.area(borderless_table)THRESHOLD : 113 | cells.append(box[0:4].tolist()) 114 | 115 | if args.visual: 116 | put_box(image,box,(0,255,255),"Cell") 117 | 118 | for table in full_tables: 119 | cells = [] 120 | for box in cells: 121 | cell = box[0:4].tolist() 122 | 123 | if box[4]>THRESHOLD: 124 | if tsa.how_much_contained(cell,table)>0.5: 125 | cells.append(cell) 126 | if args.visual: 127 | put_box(image,box,(0,255,0)) 128 | 129 | if cells != []: 130 | rows, cols = tsa.reconstruct_table(cells,table,eps=0.02) 131 | else: 132 | rows,cols = [], [] 133 | 134 | if args.visual: 135 | for row in rows: 136 | put_line(image,((int)(table[0]),(int)(row)),((int)(table[2]),(int)(row)),colour=(0,255,0)) 137 | 138 | for col in cols: 139 | put_line(image,((int)(col),(int)(table[1])),((int)(col),(int)(table[3])),colour=(0,255,0)) 140 | 141 | rows_by_table.append(rows) 142 | 
cols_by_table.append(cols) 143 | else: 144 | rows_by_table = [[] for _ in full_tables] 145 | cols_by_table = [[] for _ in full_tables] 146 | 147 | if not(args.voc): 148 | xml_utils.save_ICDAR_xml(full_tables,cols_by_table,rows_by_table,save_to + image_name.split(".")[-2] + ".xml") 149 | else: 150 | xml_utils.save_VOC_xml([],[],[],full_tables,cols_by_table,rows_by_table,save_to + image_name.split(".")[-2] + ".xml",width,height) 151 | 152 | if args.visual: 153 | cv2.imwrite(save_to + "out_%s"%(image_name),image) 154 | -------------------------------------------------------------------------------- /dla/src/inference_regiongiven.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | from mmdet.apis import init_detector, inference_detector 18 | import mmcv 19 | import dla.src.table_structure_analysis as tsa 20 | import dla.src.xml_utils as xml_utils 21 | from dla.src.image_utils import put_box, put_line 22 | import argparse 23 | import os 24 | import cv2 25 | import collections 26 | import numpy as np 27 | 28 | from xml_utils import load_VOC_xml 29 | 30 | THRESHOLD = 0.5 31 | CLASSES = ("table_body","cell","full_table","header","heading") 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('annotation',type=str) 35 | parser.add_argument('--cell_checkpoint',type=str,dest='cell_checkpoint') 36 | parser.add_argument('--coarse_cell_checkpoint',type=str,dest='coarse_cell_checkpoint') 37 | parser.add_argument('--load_from',type=str,dest='load_from') 38 | parser.add_argument('--out',type=str,dest='out') 39 | parser.add_argument('--voc', type=bool,dest='voc') 40 | parser.add_argument('--visual', type=bool,dest='visual') 41 | parser.add_argument('--raw_cells', type=bool,dest='raw_cells') 42 | parser.add_argument('--skip_headers', type=bool,dest='skip_headers') 43 | args = parser.parse_args() 44 | 45 | # Load model 46 | annotation_folder = args.annotation 47 | cell_checkpoint_file = args.cell_checkpoint 48 | coarse_cell_checkpoint_file = args.coarse_cell_checkpoint 49 | path = args.load_from 50 | save_to = args.out 51 | raw_cells = args.raw_cells 52 | annotations = set(file.strip(".xml") for file in os.listdir(annotation_folder)) 53 | image_list = [file for file in os.listdir(path) if file.strip(".jpg") in annotations] 54 | config_file = "dla/config/cascadeRCNN.py" 55 | segment_headers = not args.skip_headers 56 | 57 | cells_by_image = collections.defaultdict(list) 58 | coarse_cells_by_image = collections.defaultdict(list) 59 | 60 | if cell_checkpoint_file: 61 | model = init_detector(config_file, cell_checkpoint_file, device='cuda:0') 62 | 63 | for image_name in image_list: 64 | 65 | if args.visual: 66 | image = cv2.imread(os.path.join(path,image_name)) 67 | 68 | width, height,_ = image.shape 69 | 70 | # Run Inference 71 | result = inference_detector(model, os.path.join(path,image_name)) 72 | 73 | #Process cells 74 | cells_by_image[image_name] = result[CLASSES.index("cell")].tolist() 75 | 76 | del model 77 | 78 | if coarse_cell_checkpoint_file: 79 | 
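# When a --coarse_cell_checkpoint is supplied, a second detector is loaded here and its raw
# cell boxes are collected into coarse_cells_by_image; further down, these coarse cells are
# combined with the fine cells of each table via tsa.reconstruct_table_coarse_and_fine,
# whereas with only --cell_checkpoint the rows/columns come from tsa.reconstruct_table.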
model = init_detector(config_file, coarse_cell_checkpoint_file, device='cuda:0') 80 | 81 | 82 | for image_name in image_list: 83 | 84 | if args.visual: 85 | image = cv2.imread(os.path.join(path,image_name)) 86 | 87 | width, height,_ = image.shape 88 | 89 | # Run Inference 90 | result = inference_detector(model, os.path.join(path,image_name)) 91 | 92 | #Process cells 93 | coarse_cells_by_image[image_name] += result[CLASSES.index("cell")].tolist() 94 | 95 | del model 96 | 97 | 98 | for image_name in image_list: 99 | 100 | if args.visual: 101 | image = cv2.imread(os.path.join(path,image_name)) 102 | 103 | width, height,_ = image.shape 104 | else: 105 | width, height = 1000, 1000 106 | 107 | # Run Inference 108 | result = load_VOC_xml(os.path.join(annotation_folder,image_name.strip(".jpg") + ".xml")) 109 | 110 | headings = [] 111 | headers = [] 112 | tables = [] 113 | full_tables = [] 114 | 115 | for object in result: 116 | if object["name"] =="heading": 117 | headings.append(object["bbox"]) 118 | 119 | if object["name"] =="header": 120 | headers.append(object["bbox"]) 121 | 122 | if object["name"] =="table_body": 123 | tables.append(object["bbox"]) 124 | 125 | if object["name"] =="full_table": 126 | full_tables.append(object["bbox"]) 127 | 128 | if raw_cells: 129 | cells = [] 130 | for box in cells_by_image[image_name]: 131 | if box[4]>THRESHOLD : 132 | cells.append(box[0:4]) 133 | 134 | if args.visual: 135 | put_box(image,box,(0,0,255),"cell") 136 | 137 | xml_utils.save_VOC_xml_from_cells(headings,headers,tables,full_tables,cells,save_to + image_name.split(".")[-2] + ".xml",width,height) 138 | else: 139 | #Process cells 140 | 141 | rows_by_table = [] 142 | cols_by_table = [] 143 | full_table_by_table = [] 144 | 145 | for table in tables: 146 | cells = [] 147 | coarse_cells = [] 148 | 149 | found_fulltable = False 150 | for full_table in full_tables: 151 | if tsa.how_much_contained(table,full_table)>0.5: 152 | full_table_by_table.append(full_table) 153 | found_fulltable = True 154 | break 155 | 156 | if not found_fulltable: 157 | full_table_by_table.append(table) 158 | 159 | 160 | if cell_checkpoint_file: 161 | for box in cells_by_image[image_name]: 162 | cell = box[0:4] 163 | 164 | if box[4]>THRESHOLD: 165 | if tsa.how_much_contained(cell,table if not segment_headers else full_table_by_table[-1])>0.5: 166 | cells.append(cell) 167 | #if args.visual: 168 | # put_box(image,box,(0,0,0)) 169 | 170 | if coarse_cell_checkpoint_file: 171 | for box in coarse_cells_by_image[image_name]: 172 | cell = box[0:4] 173 | 174 | if box[4]>THRESHOLD: 175 | if tsa.how_much_contained(cell,table if not segment_headers else full_table_by_table[-1])>0.5: 176 | coarse_cells.append(cell) 177 | #if args.visual: 178 | # put_box(image,box,(0,0,0)) 179 | 180 | if cells != [] or coarse_cells!=[]: 181 | if coarse_cell_checkpoint_file: 182 | rows, cols = tsa.reconstruct_table_coarse_and_fine(coarse_cells,cells,table if not segment_headers else full_table_by_table[-1],eps=0.02) 183 | else: 184 | rows, cols = tsa.reconstruct_table(cells,table if not segment_headers else full_table_by_table[-1],eps=0.02) 185 | else: 186 | rows,cols = [],[] 187 | 188 | if args.visual: 189 | for row in rows: 190 | put_line(image,((int)(table[0]),(int)(row)),((int)(table[2]),(int)(row)),colour=(0,255,0)) 191 | 192 | for col in cols: 193 | put_line(image,((int)(col),(int)(table[1])),((int)(col),(int)(table[3])),colour=(0,255,0)) 194 | 195 | rows_by_table.append(rows) 196 | cols_by_table.append(cols) 197 | 198 | if not(args.voc): 199 | 
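# Without --voc, each image is written as ICDAR-style structure XML (one region per table
# plus its reconstructed row/column separators); with --voc True, a Pascal-VOC-style XML
# containing the headings, headers, table bodies, full tables and the same rows/columns is
# written via save_VOC_xml instead.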
xml_utils.save_ICDAR_xml(full_table_by_table,cols_by_table,rows_by_table,save_to + image_name.split(".")[-2] + ".xml") 200 | else: 201 | xml_utils.save_VOC_xml(headings,headers,tables,full_tables,cols_by_table,rows_by_table,save_to + image_name.split(".")[-2] + ".xml",width,height) 202 | 203 | if args.visual: 204 | cv2.imwrite(save_to + "out_%s"%(image_name),image) 205 | -------------------------------------------------------------------------------- /dla/src/install.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | os.system(f"cp dla/src/installation_files/dataset__init__.py mmdet/datasets/__init__.py") 5 | os.system(f"cp dla/src/installation_files/ignoringvoc.py mmdet/datasets/ignoringvoc.py") 6 | os.system(f"cp dla/src/installation_files/mean_ap.py mmdet/core/evaluation/mean_ap.py") 7 | os.system(f"cp dla/src/installation_files/inference.py mmdet/apis/inference.py") 8 | 9 | os.system(f"cp dla/src/installation_files/detectors__init__.py mmdet/models/detectors/__init__.py") 10 | os.system(f"cp dla/src/installation_files/cascade_rcnn_frozen.py mmdet/models/detectors/cascade_rcnn_frozen.py") 11 | os.system(f"cp dla/src/installation_files/cascade_rcnn_frozenrpn.py mmdet/models/detectors/cascade_rcnn_frozenrpn.py") 12 | 13 | sys.stdout.write("Installation complete. \n") 14 | -------------------------------------------------------------------------------- /dla/src/installation_files/cascade_rcnn_frozen.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner, 7 | build_sampler, merge_aug_bboxes, merge_aug_masks, 8 | multiclass_nms) 9 | from .. 
import builder 10 | from ..registry import DETECTORS 11 | from .base import BaseDetector 12 | from .test_mixins import RPNTestMixin 13 | 14 | 15 | @DETECTORS.register_module 16 | class CascadeRCNNFrozen(BaseDetector, RPNTestMixin): 17 | 18 | def __init__(self, 19 | num_stages, 20 | backbone, 21 | neck=None, 22 | shared_head=None, 23 | rpn_head=None, 24 | bbox_roi_extractor=None, 25 | bbox_head=None, 26 | mask_roi_extractor=None, 27 | mask_head=None, 28 | train_cfg=None, 29 | test_cfg=None, 30 | pretrained=None): 31 | assert bbox_roi_extractor is not None 32 | assert bbox_head is not None 33 | super(CascadeRCNNFrozen, self).__init__() 34 | 35 | self.num_stages = num_stages 36 | self.backbone = builder.build_backbone(backbone) 37 | 38 | if neck is not None: 39 | self.neck = builder.build_neck(neck) 40 | 41 | if rpn_head is not None: 42 | self.rpn_head = builder.build_head(rpn_head) 43 | 44 | if shared_head is not None: 45 | self.shared_head = builder.build_shared_head(shared_head) 46 | 47 | if bbox_head is not None: 48 | self.bbox_roi_extractor = nn.ModuleList() 49 | self.bbox_head = nn.ModuleList() 50 | if not isinstance(bbox_roi_extractor, list): 51 | bbox_roi_extractor = [ 52 | bbox_roi_extractor for _ in range(num_stages) 53 | ] 54 | if not isinstance(bbox_head, list): 55 | bbox_head = [bbox_head for _ in range(num_stages)] 56 | assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages 57 | for roi_extractor, head in zip(bbox_roi_extractor, bbox_head): 58 | self.bbox_roi_extractor.append( 59 | builder.build_roi_extractor(roi_extractor)) 60 | self.bbox_head.append(builder.build_head(head)) 61 | 62 | if mask_head is not None: 63 | self.mask_head = nn.ModuleList() 64 | if not isinstance(mask_head, list): 65 | mask_head = [mask_head for _ in range(num_stages)] 66 | assert len(mask_head) == self.num_stages 67 | for head in mask_head: 68 | self.mask_head.append(builder.build_head(head)) 69 | if mask_roi_extractor is not None: 70 | self.share_roi_extractor = False 71 | self.mask_roi_extractor = nn.ModuleList() 72 | if not isinstance(mask_roi_extractor, list): 73 | mask_roi_extractor = [ 74 | mask_roi_extractor for _ in range(num_stages) 75 | ] 76 | assert len(mask_roi_extractor) == self.num_stages 77 | for roi_extractor in mask_roi_extractor: 78 | self.mask_roi_extractor.append( 79 | builder.build_roi_extractor(roi_extractor)) 80 | else: 81 | self.share_roi_extractor = True 82 | self.mask_roi_extractor = self.bbox_roi_extractor 83 | 84 | self.train_cfg = train_cfg 85 | self.test_cfg = test_cfg 86 | 87 | self.init_weights(pretrained=pretrained) 88 | 89 | @property 90 | def with_rpn(self): 91 | return hasattr(self, 'rpn_head') and self.rpn_head is not None 92 | 93 | def init_weights(self, pretrained=None): 94 | super(CascadeRCNNFrozen, self).init_weights(pretrained) 95 | self.backbone.init_weights(pretrained=pretrained) 96 | if self.with_neck: 97 | if isinstance(self.neck, nn.Sequential): 98 | for m in self.neck: 99 | m.init_weights() 100 | else: 101 | self.neck.init_weights() 102 | if self.with_rpn: 103 | self.rpn_head.init_weights() 104 | if self.with_shared_head: 105 | self.shared_head.init_weights(pretrained=pretrained) 106 | for i in range(self.num_stages): 107 | if self.with_bbox: 108 | self.bbox_roi_extractor[i].init_weights() 109 | self.bbox_head[i].init_weights() 110 | if self.with_mask: 111 | if not self.share_roi_extractor: 112 | self.mask_roi_extractor[i].init_weights() 113 | self.mask_head[i].init_weights() 114 | 115 | def extract_feat(self, img): 116 | x = 
self.backbone(img) 117 | if self.with_neck: 118 | x = self.neck(x) 119 | return x 120 | 121 | def forward_dummy(self, img): 122 | outs = () 123 | # backbone 124 | x = self.extract_feat(img) 125 | # rpn 126 | if self.with_rpn: 127 | rpn_outs = self.rpn_head(x) 128 | outs = outs + (rpn_outs, ) 129 | proposals = torch.randn(1000, 4).to(device=img.device) 130 | # bbox heads 131 | rois = bbox2roi([proposals]) 132 | if self.with_bbox: 133 | for i in range(self.num_stages): 134 | bbox_feats = self.bbox_roi_extractor[i]( 135 | x[:self.bbox_roi_extractor[i].num_inputs], rois) 136 | if self.with_shared_head: 137 | bbox_feats = self.shared_head(bbox_feats) 138 | cls_score, bbox_pred = self.bbox_head[i](bbox_feats) 139 | outs = outs + (cls_score, bbox_pred) 140 | # mask heads 141 | if self.with_mask: 142 | mask_rois = rois[:100] 143 | for i in range(self.num_stages): 144 | mask_feats = self.mask_roi_extractor[i]( 145 | x[:self.mask_roi_extractor[i].num_inputs], mask_rois) 146 | if self.with_shared_head: 147 | mask_feats = self.shared_head(mask_feats) 148 | mask_pred = self.mask_head[i](mask_feats) 149 | outs = outs + (mask_pred, ) 150 | return outs 151 | 152 | def forward_train(self, 153 | img, 154 | img_metas, 155 | gt_bboxes, 156 | gt_labels, 157 | gt_bboxes_ignore=None, 158 | gt_masks=None, 159 | proposals=None): 160 | """ 161 | Args: 162 | img (Tensor): of shape (N, C, H, W) encoding input images. 163 | Typically these should be mean centered and std scaled. 164 | 165 | img_metas (list[dict]): list of image info dict where each dict 166 | has: 'img_shape', 'scale_factor', 'flip', and my also contain 167 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 168 | For details on the values of these keys see 169 | `mmdet/datasets/pipelines/formatting.py:Collect`. 170 | 171 | gt_bboxes (list[Tensor]): each item are the truth boxes for each 172 | image in [tl_x, tl_y, br_x, br_y] format. 173 | 174 | gt_labels (list[Tensor]): class indices corresponding to each box 175 | 176 | gt_bboxes_ignore (None | list[Tensor]): specify which bounding 177 | boxes can be ignored when computing the loss. 178 | 179 | gt_masks (None | Tensor) : true segmentation masks for each box 180 | used if the architecture supports a segmentation task. 181 | 182 | proposals : override rpn proposals with custom proposals. Use when 183 | `with_rpn` is False. 
184 | 185 | Returns: 186 | dict[str, Tensor]: a dictionary of loss components 187 | """ 188 | 189 | with torch.no_grad(): 190 | x = self.extract_feat(img) 191 | 192 | losses = dict() 193 | 194 | if self.with_rpn: 195 | rpn_outs = self.rpn_head(x) 196 | rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas, 197 | self.train_cfg.rpn) 198 | rpn_losses = self.rpn_head.loss( 199 | *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) 200 | losses.update(rpn_losses) 201 | 202 | proposal_cfg = self.train_cfg.get('rpn_proposal', 203 | self.test_cfg.rpn) 204 | proposal_inputs = rpn_outs + (img_metas, proposal_cfg) 205 | proposal_list = self.rpn_head.get_bboxes(*proposal_inputs) 206 | else: 207 | proposal_list = proposals 208 | 209 | for i in range(self.num_stages): 210 | self.current_stage = i 211 | rcnn_train_cfg = self.train_cfg.rcnn[i] 212 | lw = self.train_cfg.stage_loss_weights[i] 213 | 214 | # assign gts and sample proposals 215 | sampling_results = [] 216 | if self.with_bbox or self.with_mask: 217 | bbox_assigner = build_assigner(rcnn_train_cfg.assigner) 218 | bbox_sampler = build_sampler( 219 | rcnn_train_cfg.sampler, context=self) 220 | num_imgs = img.size(0) 221 | if gt_bboxes_ignore is None: 222 | gt_bboxes_ignore = [None for _ in range(num_imgs)] 223 | 224 | for j in range(num_imgs): 225 | assign_result = bbox_assigner.assign( 226 | proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j], 227 | gt_labels[j]) 228 | sampling_result = bbox_sampler.sample( 229 | assign_result, 230 | proposal_list[j], 231 | gt_bboxes[j], 232 | gt_labels[j], 233 | feats=[lvl_feat[j][None] for lvl_feat in x]) 234 | sampling_results.append(sampling_result) 235 | 236 | # bbox head forward and loss 237 | bbox_roi_extractor = self.bbox_roi_extractor[i] 238 | bbox_head = self.bbox_head[i] 239 | 240 | rois = bbox2roi([res.bboxes for res in sampling_results]) 241 | 242 | if len(rois) == 0: 243 | # If there are no predicted and/or truth boxes, then we cannot 244 | # compute head / mask losses 245 | continue 246 | 247 | bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], 248 | rois) 249 | if self.with_shared_head: 250 | bbox_feats = self.shared_head(bbox_feats) 251 | cls_score, bbox_pred = bbox_head(bbox_feats) 252 | 253 | bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes, 254 | gt_labels, rcnn_train_cfg) 255 | loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets) 256 | for name, value in loss_bbox.items(): 257 | losses['s{}.{}'.format(i, name)] = ( 258 | value * lw if 'loss' in name else value) 259 | 260 | # mask head forward and loss 261 | if self.with_mask: 262 | if not self.share_roi_extractor: 263 | mask_roi_extractor = self.mask_roi_extractor[i] 264 | pos_rois = bbox2roi( 265 | [res.pos_bboxes for res in sampling_results]) 266 | mask_feats = mask_roi_extractor( 267 | x[:mask_roi_extractor.num_inputs], pos_rois) 268 | if self.with_shared_head: 269 | mask_feats = self.shared_head(mask_feats) 270 | else: 271 | # reuse positive bbox feats 272 | pos_inds = [] 273 | device = bbox_feats.device 274 | for res in sampling_results: 275 | pos_inds.append( 276 | torch.ones( 277 | res.pos_bboxes.shape[0], 278 | device=device, 279 | dtype=torch.uint8)) 280 | pos_inds.append( 281 | torch.zeros( 282 | res.neg_bboxes.shape[0], 283 | device=device, 284 | dtype=torch.uint8)) 285 | pos_inds = torch.cat(pos_inds) 286 | mask_feats = bbox_feats[pos_inds.type(torch.bool)] 287 | mask_head = self.mask_head[i] 288 | mask_pred = mask_head(mask_feats) 289 | mask_targets = mask_head.get_target(sampling_results, 
gt_masks, 290 | rcnn_train_cfg) 291 | pos_labels = torch.cat( 292 | [res.pos_gt_labels for res in sampling_results]) 293 | loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels) 294 | for name, value in loss_mask.items(): 295 | losses['s{}.{}'.format(i, name)] = ( 296 | value * lw if 'loss' in name else value) 297 | 298 | # refine bboxes 299 | if i < self.num_stages - 1: 300 | pos_is_gts = [res.pos_is_gt for res in sampling_results] 301 | roi_labels = bbox_targets[0] # bbox_targets is a tuple 302 | with torch.no_grad(): 303 | proposal_list = bbox_head.refine_bboxes( 304 | rois, roi_labels, bbox_pred, pos_is_gts, img_metas) 305 | 306 | 307 | 308 | return losses 309 | 310 | def simple_test(self, img, img_metas, proposals=None, rescale=False): 311 | """Run inference on a single image. 312 | 313 | Args: 314 | img (Tensor): must be in shape (N, C, H, W) 315 | img_metas (list[dict]): a list with one dictionary element. 316 | See `mmdet/datasets/pipelines/formatting.py:Collect` for 317 | details of meta dicts. 318 | proposals : if specified overrides rpn proposals 319 | rescale (bool): if True returns boxes in original image space 320 | 321 | Returns: 322 | dict: results 323 | """ 324 | x = self.extract_feat(img) 325 | 326 | proposal_list = self.simple_test_rpn( 327 | x, img_metas, 328 | self.test_cfg.rpn) if proposals is None else proposals 329 | 330 | img_shape = img_metas[0]['img_shape'] 331 | ori_shape = img_metas[0]['ori_shape'] 332 | scale_factor = img_metas[0]['scale_factor'] 333 | 334 | # "ms" in variable names means multi-stage 335 | ms_bbox_result = {} 336 | ms_segm_result = {} 337 | ms_scores = [] 338 | rcnn_test_cfg = self.test_cfg.rcnn 339 | 340 | rois = bbox2roi(proposal_list) 341 | for i in range(self.num_stages): 342 | bbox_roi_extractor = self.bbox_roi_extractor[i] 343 | bbox_head = self.bbox_head[i] 344 | 345 | bbox_feats = bbox_roi_extractor( 346 | x[:len(bbox_roi_extractor.featmap_strides)], rois) 347 | if self.with_shared_head: 348 | bbox_feats = self.shared_head(bbox_feats) 349 | 350 | cls_score, bbox_pred = bbox_head(bbox_feats) 351 | ms_scores.append(cls_score) 352 | 353 | if i < self.num_stages - 1: 354 | bbox_label = cls_score.argmax(dim=1) 355 | rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred, 356 | img_metas[0]) 357 | 358 | cls_score = sum(ms_scores) / self.num_stages 359 | det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes( 360 | rois, 361 | cls_score, 362 | bbox_pred, 363 | img_shape, 364 | scale_factor, 365 | rescale=rescale, 366 | cfg=rcnn_test_cfg) 367 | bbox_result = bbox2result(det_bboxes, det_labels, 368 | self.bbox_head[-1].num_classes) 369 | ms_bbox_result['ensemble'] = bbox_result 370 | 371 | if self.with_mask: 372 | if det_bboxes.shape[0] == 0: 373 | mask_classes = self.mask_head[-1].num_classes - 1 374 | segm_result = [[] for _ in range(mask_classes)] 375 | else: 376 | if isinstance(scale_factor, float): # aspect ratio fixed 377 | _bboxes = ( 378 | det_bboxes[:, :4] * 379 | scale_factor if rescale else det_bboxes) 380 | else: 381 | _bboxes = ( 382 | det_bboxes[:, :4] * 383 | torch.from_numpy(scale_factor).to(det_bboxes.device) 384 | if rescale else det_bboxes) 385 | 386 | mask_rois = bbox2roi([_bboxes]) 387 | aug_masks = [] 388 | for i in range(self.num_stages): 389 | mask_roi_extractor = self.mask_roi_extractor[i] 390 | mask_feats = mask_roi_extractor( 391 | x[:len(mask_roi_extractor.featmap_strides)], mask_rois) 392 | if self.with_shared_head: 393 | mask_feats = self.shared_head(mask_feats) 394 | mask_pred = 
self.mask_head[i](mask_feats) 395 | aug_masks.append(mask_pred.sigmoid().cpu().numpy()) 396 | merged_masks = merge_aug_masks(aug_masks, 397 | [img_metas] * self.num_stages, 398 | self.test_cfg.rcnn) 399 | segm_result = self.mask_head[-1].get_seg_masks( 400 | merged_masks, _bboxes, det_labels, rcnn_test_cfg, 401 | ori_shape, scale_factor, rescale) 402 | ms_segm_result['ensemble'] = segm_result 403 | 404 | if self.with_mask: 405 | results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble']) 406 | else: 407 | results = ms_bbox_result['ensemble'] 408 | 409 | return results 410 | 411 | def aug_test(self, imgs, img_metas, proposals=None, rescale=False): 412 | """Test with augmentations. 413 | 414 | If rescale is False, then returned bboxes and masks will fit the scale 415 | of imgs[0]. 416 | """ 417 | # recompute feats to save memory 418 | proposal_list = self.aug_test_rpn( 419 | self.extract_feats(imgs), img_metas, self.test_cfg.rpn) 420 | 421 | rcnn_test_cfg = self.test_cfg.rcnn 422 | aug_bboxes = [] 423 | aug_scores = [] 424 | for x, img_meta in zip(self.extract_feats(imgs), img_metas): 425 | # only one image in the batch 426 | img_shape = img_meta[0]['img_shape'] 427 | scale_factor = img_meta[0]['scale_factor'] 428 | flip = img_meta[0]['flip'] 429 | 430 | proposals = bbox_mapping(proposal_list[0][:, :4], img_shape, 431 | scale_factor, flip) 432 | # "ms" in variable names means multi-stage 433 | ms_scores = [] 434 | 435 | rois = bbox2roi([proposals]) 436 | for i in range(self.num_stages): 437 | bbox_roi_extractor = self.bbox_roi_extractor[i] 438 | bbox_head = self.bbox_head[i] 439 | 440 | bbox_feats = bbox_roi_extractor( 441 | x[:len(bbox_roi_extractor.featmap_strides)], rois) 442 | if self.with_shared_head: 443 | bbox_feats = self.shared_head(bbox_feats) 444 | 445 | cls_score, bbox_pred = bbox_head(bbox_feats) 446 | ms_scores.append(cls_score) 447 | 448 | if i < self.num_stages - 1: 449 | bbox_label = cls_score.argmax(dim=1) 450 | rois = bbox_head.regress_by_class(rois, bbox_label, 451 | bbox_pred, img_meta[0]) 452 | 453 | cls_score = sum(ms_scores) / float(len(ms_scores)) 454 | bboxes, scores = self.bbox_head[-1].get_det_bboxes( 455 | rois, 456 | cls_score, 457 | bbox_pred, 458 | img_shape, 459 | scale_factor, 460 | rescale=False, 461 | cfg=None) 462 | aug_bboxes.append(bboxes) 463 | aug_scores.append(scores) 464 | 465 | # after merging, bboxes will be rescaled to the original image size 466 | merged_bboxes, merged_scores = merge_aug_bboxes( 467 | aug_bboxes, aug_scores, img_metas, rcnn_test_cfg) 468 | det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores, 469 | rcnn_test_cfg.score_thr, 470 | rcnn_test_cfg.nms, 471 | rcnn_test_cfg.max_per_img) 472 | 473 | bbox_result = bbox2result(det_bboxes, det_labels, 474 | self.bbox_head[-1].num_classes) 475 | 476 | if self.with_mask: 477 | if det_bboxes.shape[0] == 0: 478 | segm_result = [[] 479 | for _ in range(self.mask_head[-1].num_classes - 480 | 1)] 481 | else: 482 | aug_masks = [] 483 | aug_img_metas = [] 484 | for x, img_meta in zip(self.extract_feats(imgs), img_metas): 485 | img_shape = img_meta[0]['img_shape'] 486 | scale_factor = img_meta[0]['scale_factor'] 487 | flip = img_meta[0]['flip'] 488 | _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, 489 | scale_factor, flip) 490 | mask_rois = bbox2roi([_bboxes]) 491 | for i in range(self.num_stages): 492 | mask_feats = self.mask_roi_extractor[i]( 493 | x[:len(self.mask_roi_extractor[i].featmap_strides 494 | )], mask_rois) 495 | if self.with_shared_head: 496 | mask_feats = 
self.shared_head(mask_feats) 497 | mask_pred = self.mask_head[i](mask_feats) 498 | aug_masks.append(mask_pred.sigmoid().cpu().numpy()) 499 | aug_img_metas.append(img_meta) 500 | merged_masks = merge_aug_masks(aug_masks, aug_img_metas, 501 | self.test_cfg.rcnn) 502 | 503 | ori_shape = img_metas[0][0]['ori_shape'] 504 | segm_result = self.mask_head[-1].get_seg_masks( 505 | merged_masks, 506 | det_bboxes, 507 | det_labels, 508 | rcnn_test_cfg, 509 | ori_shape, 510 | scale_factor=1.0, 511 | rescale=False) 512 | return bbox_result, segm_result 513 | else: 514 | return bbox_result 515 | 516 | def show_result(self, data, result, **kwargs): 517 | if self.with_mask: 518 | ms_bbox_result, ms_segm_result = result 519 | if isinstance(ms_bbox_result, dict): 520 | result = (ms_bbox_result['ensemble'], 521 | ms_segm_result['ensemble']) 522 | else: 523 | if isinstance(result, dict): 524 | result = result['ensemble'] 525 | super(CascadeRCNNFrozen, self).show_result(data, result, **kwargs) 526 | -------------------------------------------------------------------------------- /dla/src/installation_files/cascade_rcnn_frozenrpn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner, 7 | build_sampler, merge_aug_bboxes, merge_aug_masks, 8 | multiclass_nms) 9 | from .. import builder 10 | from ..registry import DETECTORS 11 | from .base import BaseDetector 12 | from .test_mixins import RPNTestMixin 13 | 14 | 15 | @DETECTORS.register_module 16 | class CascadeRCNNFrozenRPN(BaseDetector, RPNTestMixin): 17 | 18 | def __init__(self, 19 | num_stages, 20 | backbone, 21 | neck=None, 22 | shared_head=None, 23 | rpn_head=None, 24 | bbox_roi_extractor=None, 25 | bbox_head=None, 26 | mask_roi_extractor=None, 27 | mask_head=None, 28 | train_cfg=None, 29 | test_cfg=None, 30 | pretrained=None): 31 | assert bbox_roi_extractor is not None 32 | assert bbox_head is not None 33 | super(CascadeRCNNFrozenRPN, self).__init__() 34 | 35 | self.num_stages = num_stages 36 | self.backbone = builder.build_backbone(backbone) 37 | 38 | if neck is not None: 39 | self.neck = builder.build_neck(neck) 40 | 41 | if rpn_head is not None: 42 | self.rpn_head = builder.build_head(rpn_head) 43 | 44 | if shared_head is not None: 45 | self.shared_head = builder.build_shared_head(shared_head) 46 | 47 | if bbox_head is not None: 48 | self.bbox_roi_extractor = nn.ModuleList() 49 | self.bbox_head = nn.ModuleList() 50 | if not isinstance(bbox_roi_extractor, list): 51 | bbox_roi_extractor = [ 52 | bbox_roi_extractor for _ in range(num_stages) 53 | ] 54 | if not isinstance(bbox_head, list): 55 | bbox_head = [bbox_head for _ in range(num_stages)] 56 | assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages 57 | for roi_extractor, head in zip(bbox_roi_extractor, bbox_head): 58 | self.bbox_roi_extractor.append( 59 | builder.build_roi_extractor(roi_extractor)) 60 | self.bbox_head.append(builder.build_head(head)) 61 | 62 | if mask_head is not None: 63 | self.mask_head = nn.ModuleList() 64 | if not isinstance(mask_head, list): 65 | mask_head = [mask_head for _ in range(num_stages)] 66 | assert len(mask_head) == self.num_stages 67 | for head in mask_head: 68 | self.mask_head.append(builder.build_head(head)) 69 | if mask_roi_extractor is not None: 70 | self.share_roi_extractor = False 71 | self.mask_roi_extractor = nn.ModuleList() 72 | if not 
isinstance(mask_roi_extractor, list): 73 | mask_roi_extractor = [ 74 | mask_roi_extractor for _ in range(num_stages) 75 | ] 76 | assert len(mask_roi_extractor) == self.num_stages 77 | for roi_extractor in mask_roi_extractor: 78 | self.mask_roi_extractor.append( 79 | builder.build_roi_extractor(roi_extractor)) 80 | else: 81 | self.share_roi_extractor = True 82 | self.mask_roi_extractor = self.bbox_roi_extractor 83 | 84 | self.train_cfg = train_cfg 85 | self.test_cfg = test_cfg 86 | 87 | self.init_weights(pretrained=pretrained) 88 | 89 | @property 90 | def with_rpn(self): 91 | return hasattr(self, 'rpn_head') and self.rpn_head is not None 92 | 93 | def init_weights(self, pretrained=None): 94 | super(CascadeRCNNFrozenRPN, self).init_weights(pretrained) 95 | self.backbone.init_weights(pretrained=pretrained) 96 | if self.with_neck: 97 | if isinstance(self.neck, nn.Sequential): 98 | for m in self.neck: 99 | m.init_weights() 100 | else: 101 | self.neck.init_weights() 102 | if self.with_rpn: 103 | self.rpn_head.init_weights() 104 | if self.with_shared_head: 105 | self.shared_head.init_weights(pretrained=pretrained) 106 | for i in range(self.num_stages): 107 | if self.with_bbox: 108 | self.bbox_roi_extractor[i].init_weights() 109 | self.bbox_head[i].init_weights() 110 | if self.with_mask: 111 | if not self.share_roi_extractor: 112 | self.mask_roi_extractor[i].init_weights() 113 | self.mask_head[i].init_weights() 114 | 115 | def extract_feat(self, img): 116 | x = self.backbone(img) 117 | if self.with_neck: 118 | x = self.neck(x) 119 | return x 120 | 121 | def forward_dummy(self, img): 122 | outs = () 123 | # backbone 124 | x = self.extract_feat(img) 125 | # rpn 126 | if self.with_rpn: 127 | rpn_outs = self.rpn_head(x) 128 | outs = outs + (rpn_outs, ) 129 | proposals = torch.randn(1000, 4).to(device=img.device) 130 | # bbox heads 131 | rois = bbox2roi([proposals]) 132 | if self.with_bbox: 133 | for i in range(self.num_stages): 134 | bbox_feats = self.bbox_roi_extractor[i]( 135 | x[:self.bbox_roi_extractor[i].num_inputs], rois) 136 | if self.with_shared_head: 137 | bbox_feats = self.shared_head(bbox_feats) 138 | cls_score, bbox_pred = self.bbox_head[i](bbox_feats) 139 | outs = outs + (cls_score, bbox_pred) 140 | # mask heads 141 | if self.with_mask: 142 | mask_rois = rois[:100] 143 | for i in range(self.num_stages): 144 | mask_feats = self.mask_roi_extractor[i]( 145 | x[:self.mask_roi_extractor[i].num_inputs], mask_rois) 146 | if self.with_shared_head: 147 | mask_feats = self.shared_head(mask_feats) 148 | mask_pred = self.mask_head[i](mask_feats) 149 | outs = outs + (mask_pred, ) 150 | return outs 151 | 152 | def forward_train(self, 153 | img, 154 | img_metas, 155 | gt_bboxes, 156 | gt_labels, 157 | gt_bboxes_ignore=None, 158 | gt_masks=None, 159 | proposals=None): 160 | """ 161 | Args: 162 | img (Tensor): of shape (N, C, H, W) encoding input images. 163 | Typically these should be mean centered and std scaled. 164 | 165 | img_metas (list[dict]): list of image info dict where each dict 166 | has: 'img_shape', 'scale_factor', 'flip', and my also contain 167 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 168 | For details on the values of these keys see 169 | `mmdet/datasets/pipelines/formatting.py:Collect`. 170 | 171 | gt_bboxes (list[Tensor]): each item are the truth boxes for each 172 | image in [tl_x, tl_y, br_x, br_y] format. 
173 | 174 | gt_labels (list[Tensor]): class indices corresponding to each box 175 | 176 | gt_bboxes_ignore (None | list[Tensor]): specify which bounding 177 | boxes can be ignored when computing the loss. 178 | 179 | gt_masks (None | Tensor) : true segmentation masks for each box 180 | used if the architecture supports a segmentation task. 181 | 182 | proposals : override rpn proposals with custom proposals. Use when 183 | `with_rpn` is False. 184 | 185 | Returns: 186 | dict[str, Tensor]: a dictionary of loss components 187 | """ 188 | losses = dict() 189 | with torch.no_grad(): 190 | x = self.extract_feat(img) 191 | 192 | if self.with_rpn: 193 | rpn_outs = self.rpn_head(x) 194 | 195 | proposal_cfg = self.train_cfg.get('rpn_proposal', 196 | self.test_cfg.rpn) 197 | proposal_inputs = rpn_outs + (img_metas, proposal_cfg) 198 | proposal_list = self.rpn_head.get_bboxes(*proposal_inputs) 199 | else: 200 | proposal_list = proposals 201 | 202 | for i in range(self.num_stages): 203 | self.current_stage = i 204 | rcnn_train_cfg = self.train_cfg.rcnn[i] 205 | lw = self.train_cfg.stage_loss_weights[i] 206 | 207 | # assign gts and sample proposals 208 | sampling_results = [] 209 | if self.with_bbox or self.with_mask: 210 | bbox_assigner = build_assigner(rcnn_train_cfg.assigner) 211 | bbox_sampler = build_sampler( 212 | rcnn_train_cfg.sampler, context=self) 213 | num_imgs = img.size(0) 214 | if gt_bboxes_ignore is None: 215 | gt_bboxes_ignore = [None for _ in range(num_imgs)] 216 | 217 | for j in range(num_imgs): 218 | assign_result = bbox_assigner.assign( 219 | proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j], 220 | gt_labels[j]) 221 | sampling_result = bbox_sampler.sample( 222 | assign_result, 223 | proposal_list[j], 224 | gt_bboxes[j], 225 | gt_labels[j], 226 | feats=[lvl_feat[j][None] for lvl_feat in x]) 227 | sampling_results.append(sampling_result) 228 | 229 | # bbox head forward and loss 230 | bbox_roi_extractor = self.bbox_roi_extractor[i] 231 | bbox_head = self.bbox_head[i] 232 | 233 | rois = bbox2roi([res.bboxes for res in sampling_results]) 234 | 235 | if len(rois) == 0: 236 | # If there are no predicted and/or truth boxes, then we cannot 237 | # compute head / mask losses 238 | continue 239 | 240 | bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], 241 | rois) 242 | if self.with_shared_head: 243 | bbox_feats = self.shared_head(bbox_feats) 244 | cls_score, bbox_pred = bbox_head(bbox_feats) 245 | 246 | bbox_targets = bbox_head.get_target(sampling_results, gt_bboxes, 247 | gt_labels, rcnn_train_cfg) 248 | loss_bbox = bbox_head.loss(cls_score, bbox_pred, *bbox_targets) 249 | for name, value in loss_bbox.items(): 250 | losses['s{}.{}'.format(i, name)] = ( 251 | value * lw if 'loss' in name else value) 252 | 253 | # mask head forward and loss 254 | if self.with_mask: 255 | if not self.share_roi_extractor: 256 | mask_roi_extractor = self.mask_roi_extractor[i] 257 | pos_rois = bbox2roi( 258 | [res.pos_bboxes for res in sampling_results]) 259 | mask_feats = mask_roi_extractor( 260 | x[:mask_roi_extractor.num_inputs], pos_rois) 261 | if self.with_shared_head: 262 | mask_feats = self.shared_head(mask_feats) 263 | else: 264 | # reuse positive bbox feats 265 | pos_inds = [] 266 | device = bbox_feats.device 267 | for res in sampling_results: 268 | pos_inds.append( 269 | torch.ones( 270 | res.pos_bboxes.shape[0], 271 | device=device, 272 | dtype=torch.uint8)) 273 | pos_inds.append( 274 | torch.zeros( 275 | res.neg_bboxes.shape[0], 276 | device=device, 277 | dtype=torch.uint8)) 278 | 
pos_inds = torch.cat(pos_inds) 279 | mask_feats = bbox_feats[pos_inds.type(torch.bool)] 280 | mask_head = self.mask_head[i] 281 | mask_pred = mask_head(mask_feats) 282 | mask_targets = mask_head.get_target(sampling_results, gt_masks, 283 | rcnn_train_cfg) 284 | pos_labels = torch.cat( 285 | [res.pos_gt_labels for res in sampling_results]) 286 | loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels) 287 | for name, value in loss_mask.items(): 288 | losses['s{}.{}'.format(i, name)] = ( 289 | value * lw if 'loss' in name else value) 290 | 291 | # refine bboxes 292 | if i < self.num_stages - 1: 293 | pos_is_gts = [res.pos_is_gt for res in sampling_results] 294 | roi_labels = bbox_targets[0] # bbox_targets is a tuple 295 | with torch.no_grad(): 296 | proposal_list = bbox_head.refine_bboxes( 297 | rois, roi_labels, bbox_pred, pos_is_gts, img_metas) 298 | 299 | 300 | 301 | return losses 302 | 303 | def simple_test(self, img, img_metas, proposals=None, rescale=False): 304 | """Run inference on a single image. 305 | 306 | Args: 307 | img (Tensor): must be in shape (N, C, H, W) 308 | img_metas (list[dict]): a list with one dictionary element. 309 | See `mmdet/datasets/pipelines/formatting.py:Collect` for 310 | details of meta dicts. 311 | proposals : if specified overrides rpn proposals 312 | rescale (bool): if True returns boxes in original image space 313 | 314 | Returns: 315 | dict: results 316 | """ 317 | x = self.extract_feat(img) 318 | 319 | proposal_list = self.simple_test_rpn( 320 | x, img_metas, 321 | self.test_cfg.rpn) if proposals is None else proposals 322 | 323 | img_shape = img_metas[0]['img_shape'] 324 | ori_shape = img_metas[0]['ori_shape'] 325 | scale_factor = img_metas[0]['scale_factor'] 326 | 327 | # "ms" in variable names means multi-stage 328 | ms_bbox_result = {} 329 | ms_segm_result = {} 330 | ms_scores = [] 331 | rcnn_test_cfg = self.test_cfg.rcnn 332 | 333 | rois = bbox2roi(proposal_list) 334 | for i in range(self.num_stages): 335 | bbox_roi_extractor = self.bbox_roi_extractor[i] 336 | bbox_head = self.bbox_head[i] 337 | 338 | bbox_feats = bbox_roi_extractor( 339 | x[:len(bbox_roi_extractor.featmap_strides)], rois) 340 | if self.with_shared_head: 341 | bbox_feats = self.shared_head(bbox_feats) 342 | 343 | cls_score, bbox_pred = bbox_head(bbox_feats) 344 | ms_scores.append(cls_score) 345 | 346 | if i < self.num_stages - 1: 347 | bbox_label = cls_score.argmax(dim=1) 348 | rois = bbox_head.regress_by_class(rois, bbox_label, bbox_pred, 349 | img_metas[0]) 350 | 351 | cls_score = sum(ms_scores) / self.num_stages 352 | det_bboxes, det_labels = self.bbox_head[-1].get_det_bboxes( 353 | rois, 354 | cls_score, 355 | bbox_pred, 356 | img_shape, 357 | scale_factor, 358 | rescale=rescale, 359 | cfg=rcnn_test_cfg) 360 | bbox_result = bbox2result(det_bboxes, det_labels, 361 | self.bbox_head[-1].num_classes) 362 | ms_bbox_result['ensemble'] = bbox_result 363 | 364 | if self.with_mask: 365 | if det_bboxes.shape[0] == 0: 366 | mask_classes = self.mask_head[-1].num_classes - 1 367 | segm_result = [[] for _ in range(mask_classes)] 368 | else: 369 | if isinstance(scale_factor, float): # aspect ratio fixed 370 | _bboxes = ( 371 | det_bboxes[:, :4] * 372 | scale_factor if rescale else det_bboxes) 373 | else: 374 | _bboxes = ( 375 | det_bboxes[:, :4] * 376 | torch.from_numpy(scale_factor).to(det_bboxes.device) 377 | if rescale else det_bboxes) 378 | 379 | mask_rois = bbox2roi([_bboxes]) 380 | aug_masks = [] 381 | for i in range(self.num_stages): 382 | mask_roi_extractor = 
self.mask_roi_extractor[i] 383 | mask_feats = mask_roi_extractor( 384 | x[:len(mask_roi_extractor.featmap_strides)], mask_rois) 385 | if self.with_shared_head: 386 | mask_feats = self.shared_head(mask_feats) 387 | mask_pred = self.mask_head[i](mask_feats) 388 | aug_masks.append(mask_pred.sigmoid().cpu().numpy()) 389 | merged_masks = merge_aug_masks(aug_masks, 390 | [img_metas] * self.num_stages, 391 | self.test_cfg.rcnn) 392 | segm_result = self.mask_head[-1].get_seg_masks( 393 | merged_masks, _bboxes, det_labels, rcnn_test_cfg, 394 | ori_shape, scale_factor, rescale) 395 | ms_segm_result['ensemble'] = segm_result 396 | 397 | if self.with_mask: 398 | results = (ms_bbox_result['ensemble'], ms_segm_result['ensemble']) 399 | else: 400 | results = ms_bbox_result['ensemble'] 401 | 402 | return results 403 | 404 | def aug_test(self, imgs, img_metas, proposals=None, rescale=False): 405 | """Test with augmentations. 406 | 407 | If rescale is False, then returned bboxes and masks will fit the scale 408 | of imgs[0]. 409 | """ 410 | # recompute feats to save memory 411 | proposal_list = self.aug_test_rpn( 412 | self.extract_feats(imgs), img_metas, self.test_cfg.rpn) 413 | 414 | rcnn_test_cfg = self.test_cfg.rcnn 415 | aug_bboxes = [] 416 | aug_scores = [] 417 | for x, img_meta in zip(self.extract_feats(imgs), img_metas): 418 | # only one image in the batch 419 | img_shape = img_meta[0]['img_shape'] 420 | scale_factor = img_meta[0]['scale_factor'] 421 | flip = img_meta[0]['flip'] 422 | 423 | proposals = bbox_mapping(proposal_list[0][:, :4], img_shape, 424 | scale_factor, flip) 425 | # "ms" in variable names means multi-stage 426 | ms_scores = [] 427 | 428 | rois = bbox2roi([proposals]) 429 | for i in range(self.num_stages): 430 | bbox_roi_extractor = self.bbox_roi_extractor[i] 431 | bbox_head = self.bbox_head[i] 432 | 433 | bbox_feats = bbox_roi_extractor( 434 | x[:len(bbox_roi_extractor.featmap_strides)], rois) 435 | if self.with_shared_head: 436 | bbox_feats = self.shared_head(bbox_feats) 437 | 438 | cls_score, bbox_pred = bbox_head(bbox_feats) 439 | ms_scores.append(cls_score) 440 | 441 | if i < self.num_stages - 1: 442 | bbox_label = cls_score.argmax(dim=1) 443 | rois = bbox_head.regress_by_class(rois, bbox_label, 444 | bbox_pred, img_meta[0]) 445 | 446 | cls_score = sum(ms_scores) / float(len(ms_scores)) 447 | bboxes, scores = self.bbox_head[-1].get_det_bboxes( 448 | rois, 449 | cls_score, 450 | bbox_pred, 451 | img_shape, 452 | scale_factor, 453 | rescale=False, 454 | cfg=None) 455 | aug_bboxes.append(bboxes) 456 | aug_scores.append(scores) 457 | 458 | # after merging, bboxes will be rescaled to the original image size 459 | merged_bboxes, merged_scores = merge_aug_bboxes( 460 | aug_bboxes, aug_scores, img_metas, rcnn_test_cfg) 461 | det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores, 462 | rcnn_test_cfg.score_thr, 463 | rcnn_test_cfg.nms, 464 | rcnn_test_cfg.max_per_img) 465 | 466 | bbox_result = bbox2result(det_bboxes, det_labels, 467 | self.bbox_head[-1].num_classes) 468 | 469 | if self.with_mask: 470 | if det_bboxes.shape[0] == 0: 471 | segm_result = [[] 472 | for _ in range(self.mask_head[-1].num_classes - 473 | 1)] 474 | else: 475 | aug_masks = [] 476 | aug_img_metas = [] 477 | for x, img_meta in zip(self.extract_feats(imgs), img_metas): 478 | img_shape = img_meta[0]['img_shape'] 479 | scale_factor = img_meta[0]['scale_factor'] 480 | flip = img_meta[0]['flip'] 481 | _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, 482 | scale_factor, flip) 483 | mask_rois = 
bbox2roi([_bboxes]) 484 | for i in range(self.num_stages): 485 | mask_feats = self.mask_roi_extractor[i]( 486 | x[:len(self.mask_roi_extractor[i].featmap_strides 487 | )], mask_rois) 488 | if self.with_shared_head: 489 | mask_feats = self.shared_head(mask_feats) 490 | mask_pred = self.mask_head[i](mask_feats) 491 | aug_masks.append(mask_pred.sigmoid().cpu().numpy()) 492 | aug_img_metas.append(img_meta) 493 | merged_masks = merge_aug_masks(aug_masks, aug_img_metas, 494 | self.test_cfg.rcnn) 495 | 496 | ori_shape = img_metas[0][0]['ori_shape'] 497 | segm_result = self.mask_head[-1].get_seg_masks( 498 | merged_masks, 499 | det_bboxes, 500 | det_labels, 501 | rcnn_test_cfg, 502 | ori_shape, 503 | scale_factor=1.0, 504 | rescale=False) 505 | return bbox_result, segm_result 506 | else: 507 | return bbox_result 508 | 509 | def show_result(self, data, result, **kwargs): 510 | if self.with_mask: 511 | ms_bbox_result, ms_segm_result = result 512 | if isinstance(ms_bbox_result, dict): 513 | result = (ms_bbox_result['ensemble'], 514 | ms_segm_result['ensemble']) 515 | else: 516 | if isinstance(result, dict): 517 | result = result['ensemble'] 518 | super(CascadeRCNNFrozenRPN, self).show_result(data, result, **kwargs) 519 | -------------------------------------------------------------------------------- /dla/src/installation_files/dataset__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_dataloader, build_dataset 2 | from .cityscapes import CityscapesDataset 3 | from .coco import CocoDataset 4 | from .custom import CustomDataset 5 | from .dataset_wrappers import ConcatDataset, RepeatDataset 6 | from .registry import DATASETS 7 | from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler 8 | from .voc import VOCDataset 9 | from .wider_face import WIDERFaceDataset 10 | from .xml_style import XMLDataset 11 | from .ignoringvoc import IgnoringVOCDataset 12 | 13 | __all__ = [ 14 | 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 15 | 'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler', 16 | 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', 17 | 'WIDERFaceDataset', 'DATASETS', 'build_dataset', 'IgnoringVOCDataset' 18 | ] 19 | -------------------------------------------------------------------------------- /dla/src/installation_files/detectors__init__.py: -------------------------------------------------------------------------------- 1 | from .atss import ATSS 2 | from .base import BaseDetector 3 | from .cascade_rcnn import CascadeRCNN 4 | from .double_head_rcnn import DoubleHeadRCNN 5 | from .fast_rcnn import FastRCNN 6 | from .faster_rcnn import FasterRCNN 7 | from .fcos import FCOS 8 | from .fovea import FOVEA 9 | from .grid_rcnn import GridRCNN 10 | from .htc import HybridTaskCascade 11 | from .mask_rcnn import MaskRCNN 12 | from .mask_scoring_rcnn import MaskScoringRCNN 13 | from .reppoints_detector import RepPointsDetector 14 | from .retinanet import RetinaNet 15 | from .rpn import RPN 16 | from .single_stage import SingleStageDetector 17 | from .two_stage import TwoStageDetector 18 | from .cascade_rcnn_frozen import CascadeRCNNFrozen 19 | from .cascade_rcnn_frozenrpn import CascadeRCNNFrozenRPN 20 | 21 | __all__ = [ 22 | 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 23 | 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 24 | 'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 25 | 
'RepPointsDetector', 'FOVEA', 'CascadeRCNNFrozen', 'CascadeRCNNFrozenRPN' 26 | ] 27 | -------------------------------------------------------------------------------- /dla/src/installation_files/ignoringvoc.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import xml.etree.ElementTree as ET 3 | 4 | import mmcv 5 | import numpy as np 6 | from mmdet.core import eval_map, eval_recalls 7 | from PIL import Image 8 | 9 | from .custom import CustomDataset 10 | from .registry import DATASETS 11 | 12 | 13 | @DATASETS.register_module() 14 | class IgnoringXMLDataset(CustomDataset): 15 | 16 | def __init__(self, min_size=None,ignore=[], **kwargs): 17 | 18 | super(IgnoringXMLDataset, self).__init__(**kwargs) 19 | self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)} 20 | self.min_size = min_size 21 | self.ignore = ignore 22 | 23 | def load_annotations(self, ann_file): 24 | img_infos = [] 25 | img_ids = mmcv.list_from_file(ann_file) 26 | for img_id in img_ids: 27 | filename = 'JPEGImages/{}.jpg'.format(img_id) 28 | xml_path = osp.join(self.img_prefix, 'Annotations', 29 | '{}.xml'.format(img_id)) 30 | tree = ET.parse(xml_path) 31 | root = tree.getroot() 32 | size = root.find('size') 33 | width = 0 34 | height = 0 35 | if size is not None: 36 | width = int(size.find('width').text) 37 | height = int(size.find('height').text) 38 | else: 39 | img_path = osp.join(self.img_prefix, 'JPEGImages', 40 | '{}.jpg'.format(img_id)) 41 | img = Image.open(img_path) 42 | width, height = img.size 43 | img_infos.append( 44 | dict(id=img_id, filename=filename, width=width, height=height)) 45 | return img_infos 46 | 47 | def get_ann_info(self, idx): 48 | img_id = self.img_infos[idx]['id'] 49 | xml_path = osp.join(self.img_prefix, 'Annotations', 50 | '{}.xml'.format(img_id)) 51 | tree = ET.parse(xml_path) 52 | root = tree.getroot() 53 | bboxes = [] 54 | labels = [] 55 | bboxes_ignore = [] 56 | labels_ignore = [] 57 | for obj in root.findall('object'): 58 | name = obj.find('name').text 59 | label = self.cat2label[name] 60 | difficult = int(obj.find('difficult').text) 61 | bnd_box = obj.find('bndbox') 62 | # Coordinates may be float type 63 | bbox = [ 64 | int(float(bnd_box.find('xmin').text)), 65 | int(float(bnd_box.find('ymin').text)), 66 | int(float(bnd_box.find('xmax').text)), 67 | int(float(bnd_box.find('ymax').text)) 68 | ] 69 | ignore = name in self.ignore 70 | if self.min_size: 71 | assert not self.test_mode 72 | w = bbox[2] - bbox[0] 73 | h = bbox[3] - bbox[1] 74 | if w < self.min_size or h < self.min_size: 75 | ignore = True 76 | if difficult or ignore: 77 | bboxes_ignore.append(bbox) 78 | labels_ignore.append(label) 79 | else: 80 | bboxes.append(bbox) 81 | labels.append(label) 82 | if not bboxes: 83 | bboxes = np.zeros((0, 4)) 84 | labels = np.zeros((0, )) 85 | else: 86 | bboxes = np.array(bboxes, ndmin=2) - 1 87 | labels = np.array(labels) 88 | if not bboxes_ignore: 89 | bboxes_ignore = np.zeros((0, 4)) 90 | labels_ignore = np.zeros((0, )) 91 | else: 92 | bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1 93 | labels_ignore = np.array(labels_ignore) 94 | ann = dict( 95 | bboxes=bboxes.astype(np.float32), 96 | labels=labels.astype(np.int64), 97 | bboxes_ignore=bboxes_ignore.astype(np.float32), 98 | labels_ignore=labels_ignore.astype(np.int64)) 99 | return ann 100 | 101 | 102 | @DATASETS.register_module() 103 | class IgnoringVOCDataset(IgnoringXMLDataset): 104 | 105 | CLASSES = ('table_body','cell','full_table','header','heading') 106 | 
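    # A minimal sketch of how this dataset could be selected in an mmdetection
    # config, assuming VOC-style paths (the paths, pipeline name and the ignored
    # class below are hypothetical illustrations, not taken from a real config):
    #
    #   train=dict(
    #       type='IgnoringVOCDataset',
    #       ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/train.txt',
    #       img_prefix='data/VOCdevkit/VOC2007/',
    #       pipeline=train_pipeline,
    #       ignore=['cell'])
    #
    # Any box whose class name appears in `ignore` is routed into `bboxes_ignore`
    # by IgnoringXMLDataset.get_ann_info above, so it is kept out of the training
    # targets and out of the mAP/recall evaluation.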
107 |     def __init__(self, ignore=[],**kwargs):
108 |         super(IgnoringVOCDataset, self).__init__(ignore=ignore,**kwargs)
109 |         if 'VOC2007' in self.img_prefix:
110 |             self.year = 2007
111 |         elif 'VOC2012' in self.img_prefix:
112 |             self.year = 2012
113 |         else:
114 |             raise ValueError('Cannot infer dataset year from img_prefix')
115 | 
116 |     def evaluate(self,
117 |                  results,
118 |                  metric='mAP',
119 |                  logger=None,
120 |                  proposal_nums=(100, 300, 1000),
121 |                  iou_thr=0.5,
122 |                  scale_ranges=None):
123 |         if not isinstance(metric, str):
124 |             assert len(metric) == 1
125 |             metric = metric[0]
126 |         allowed_metrics = ['mAP', 'recall']
127 |         if metric not in allowed_metrics:
128 |             raise KeyError('metric {} is not supported'.format(metric))
129 |         annotations = [self.get_ann_info(i) for i in range(len(self))]
130 |         eval_results = {}
131 |         if metric == 'mAP':
132 |             assert isinstance(iou_thr, float)
133 |             if self.year == 2007:
134 |                 ds_name = 'voc07'
135 |             else:
136 |                 ds_name = self.CLASSES
137 |             mean_ap, _ = eval_map(
138 |                 results,
139 |                 annotations,
140 |                 scale_ranges=None,
141 |                 iou_thr=iou_thr,
142 |                 dataset=ds_name,
143 |                 logger=logger)
144 |             eval_results['mAP'] = mean_ap
145 |         elif metric == 'recall':
146 |             gt_bboxes = [ann['bboxes'] for ann in annotations]
147 |             if isinstance(iou_thr, float):
148 |                 iou_thr = [iou_thr]
149 |             recalls = eval_recalls(
150 |                 gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
151 |             for i, num in enumerate(proposal_nums):
152 |                 for j, iou in enumerate(iou_thr):
153 |                     eval_results['recall@{}@{}'.format(num, iou)] = recalls[i,
154 |                                                                             j]
155 |             if recalls.shape[1] > 1:
156 |                 ar = recalls.mean(axis=1)
157 |                 for i, num in enumerate(proposal_nums):
158 |                     eval_results['AR@{}'.format(num)] = ar[i]
159 |         return eval_results
160 | 
-------------------------------------------------------------------------------- /dla/src/installation_files/inference.py: --------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import matplotlib.pyplot as plt
4 | import mmcv
5 | import numpy as np
6 | import pycocotools.mask as maskUtils
7 | import torch
8 | from mmcv.parallel import collate, scatter
9 | from mmcv.runner import load_checkpoint
10 | 
11 | from mmdet.core import get_classes
12 | from mmdet.datasets.pipelines import Compose
13 | from mmdet.models import build_detector
14 | 
15 | 
16 | def init_detector(config, checkpoint=None, device='cuda:0'):
17 |     """Initialize a detector from config file.
18 | 
19 |     Args:
20 |         config (str or :obj:`mmcv.Config`): Config file path or the config
21 |             object.
22 |         checkpoint (str, optional): Checkpoint path. If left as None, the model
23 |             will not load any weights.
24 | 
25 |     Returns:
26 |         nn.Module: The constructed detector.
27 | """ 28 | if isinstance(config, str): 29 | config = mmcv.Config.fromfile(config) 30 | elif not isinstance(config, mmcv.Config): 31 | raise TypeError('config must be a filename or Config object, ' 32 | 'but got {}'.format(type(config))) 33 | config.model.pretrained = None 34 | model = build_detector(config.model, test_cfg=config.test_cfg) 35 | if checkpoint is not None: 36 | checkpoint = load_checkpoint(model, checkpoint) 37 | if 'CLASSES' in checkpoint.get('meta', {}): 38 | model.CLASSES = checkpoint['meta']['CLASSES'] 39 | else: 40 | warnings.warn('Class names are not saved in the checkpoint\'s ' 41 | 'meta data, use COCO classes by default.') 42 | model.CLASSES = get_classes('coco') 43 | model.cfg = config # save the config in the model for convenience 44 | model.to(device) 45 | model.eval() 46 | return model 47 | 48 | 49 | class LoadImage(object): 50 | 51 | def __call__(self, results): 52 | if isinstance(results['img'], str): 53 | results['filename'] = results['img'] 54 | else: 55 | results['filename'] = None 56 | img = mmcv.imread(results['img']) 57 | results['img'] = img 58 | results['img_shape'] = img.shape 59 | results['ori_shape'] = img.shape 60 | return results 61 | 62 | 63 | def inference_detector(model, img): 64 | """Inference image(s) with the detector. 65 | 66 | Args: 67 | model (nn.Module): The loaded detector. 68 | imgs (str/ndarray or list[str/ndarray]): Either image files or loaded 69 | images. 70 | 71 | Returns: 72 | If imgs is a str, a generator will be returned, otherwise return the 73 | detection results directly. 74 | """ 75 | cfg = model.cfg 76 | device = next(model.parameters()).device # model device 77 | # build the data pipeline 78 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 79 | test_pipeline = Compose(test_pipeline) 80 | # prepare data 81 | data = dict(img=img) 82 | data = test_pipeline(data) 83 | data = scatter(collate([data], samples_per_gpu=1), [device])[0] 84 | # forward the model 85 | with torch.no_grad(): 86 | result = model(return_loss=False, rescale=True, **data) 87 | return result 88 | 89 | 90 | async def async_inference_detector(model, img): 91 | """Async inference image(s) with the detector. 92 | 93 | Args: 94 | model (nn.Module): The loaded detector. 95 | imgs (str/ndarray or list[str/ndarray]): Either image files or loaded 96 | images. 97 | 98 | Returns: 99 | Awaitable detection results. 100 | """ 101 | cfg = model.cfg 102 | device = next(model.parameters()).device # model device 103 | # build the data pipeline 104 | test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] 105 | test_pipeline = Compose(test_pipeline) 106 | # prepare data 107 | data = dict(img=img) 108 | data = test_pipeline(data) 109 | data = scatter(collate([data], samples_per_gpu=1), [device])[0] 110 | 111 | # We don't restore `torch.is_grad_enabled()` value during concurrent 112 | # inference since execution can overlap 113 | torch.set_grad_enabled(False) 114 | result = await model.aforward_test(rescale=True, **data) 115 | return result 116 | 117 | 118 | # TODO: merge this method with the one in BaseDetector 119 | def show_result(img, 120 | result, 121 | class_names, 122 | score_thr=0.3, 123 | wait_time=0, 124 | show=True, 125 | out_file=None): 126 | """Visualize the detection results on the image. 127 | 128 | Args: 129 | img (str or np.ndarray): Image filename or loaded image. 130 | result (tuple[list] or list): The detection result, can be either 131 | (bbox, segm) or just bbox. 132 | class_names (list[str] or tuple[str]): A list of class names. 
133 | score_thr (float): The threshold to visualize the bboxes and masks. 134 | wait_time (int): Value of waitKey param. 135 | show (bool, optional): Whether to show the image with opencv or not. 136 | out_file (str, optional): If specified, the visualization result will 137 | be written to the out file instead of shown in a window. 138 | 139 | Returns: 140 | np.ndarray or None: If neither `show` nor `out_file` is specified, the 141 | visualized image is returned, otherwise None is returned. 142 | """ 143 | assert isinstance(class_names, (tuple, list)) 144 | img = mmcv.imread(img) 145 | img = img.copy() 146 | if isinstance(result, tuple): 147 | bbox_result, segm_result = result 148 | else: 149 | bbox_result, segm_result = result, None 150 | bboxes = np.vstack(bbox_result) 151 | labels = [ 152 | np.full(bbox.shape[0], i, dtype=np.int32) 153 | for i, bbox in enumerate(bbox_result) 154 | ] 155 | labels = np.concatenate(labels) 156 | # draw segmentation masks 157 | if segm_result is not None: 158 | segms = mmcv.concat_list(segm_result) 159 | inds = np.where(bboxes[:, -1] > score_thr)[0] 160 | np.random.seed(42) 161 | color_masks = [ 162 | np.random.randint(0, 256, (1, 3), dtype=np.uint8) 163 | for _ in range(max(labels) + 1) 164 | ] 165 | for i in inds: 166 | i = int(i) 167 | color_mask = color_masks[labels[i]] 168 | mask = maskUtils.decode(segms[i]).astype(np.bool) 169 | img[mask] = img[mask] * 0.5 + color_mask * 0.5 170 | # if out_file specified, do not show image in window 171 | if out_file is not None: 172 | show = False 173 | # draw bounding boxes 174 | mmcv.imshow_det_bboxes( 175 | img, 176 | bboxes, 177 | labels, 178 | class_names=class_names, 179 | score_thr=score_thr, 180 | show=show, 181 | wait_time=wait_time, 182 | out_file=out_file) 183 | if not (show or out_file): 184 | return img 185 | 186 | 187 | def show_result_pyplot(img, 188 | result, 189 | class_names, 190 | score_thr=0.3, 191 | fig_size=(15, 10)): 192 | """Visualize the detection results on the image. 193 | 194 | Args: 195 | img (str or np.ndarray): Image filename or loaded image. 196 | result (tuple[list] or list): The detection result, can be either 197 | (bbox, segm) or just bbox. 198 | class_names (list[str] or tuple[str]): A list of class names. 199 | score_thr (float): The threshold to visualize the bboxes and masks. 200 | fig_size (tuple): Figure size of the pyplot figure. 201 | out_file (str, optional): If specified, the visualization result will 202 | be written to the out file instead of shown in a window. 203 | """ 204 | img = show_result( 205 | img, result, class_names, score_thr=score_thr, show=False) 206 | plt.figure(figsize=fig_size) 207 | plt.imshow(mmcv.bgr2rgb(img)) 208 | -------------------------------------------------------------------------------- /dla/src/installation_files/mean_ap.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | import mmcv 4 | import numpy as np 5 | from terminaltables import AsciiTable 6 | 7 | from mmdet.utils import print_log 8 | from .bbox_overlaps import bbox_overlaps 9 | from .class_names import get_classes 10 | 11 | 12 | def average_precision(recalls, precisions, mode='area'): 13 | """Calculate average precision (for single or multiple scales). 
14 | 15 | Args: 16 | recalls (ndarray): shape (num_scales, num_dets) or (num_dets, ) 17 | precisions (ndarray): shape (num_scales, num_dets) or (num_dets, ) 18 | mode (str): 'area' or '11points', 'area' means calculating the area 19 | under precision-recall curve, '11points' means calculating 20 | the average precision of recalls at [0, 0.1, ..., 1] 21 | 22 | Returns: 23 | float or ndarray: calculated average precision 24 | """ 25 | no_scale = False 26 | if recalls.ndim == 1: 27 | no_scale = True 28 | recalls = recalls[np.newaxis, :] 29 | precisions = precisions[np.newaxis, :] 30 | assert recalls.shape == precisions.shape and recalls.ndim == 2 31 | num_scales = recalls.shape[0] 32 | ap = np.zeros(num_scales, dtype=np.float32) 33 | if mode == 'area': 34 | zeros = np.zeros((num_scales, 1), dtype=recalls.dtype) 35 | ones = np.ones((num_scales, 1), dtype=recalls.dtype) 36 | mrec = np.hstack((zeros, recalls, ones)) 37 | mpre = np.hstack((zeros, precisions, zeros)) 38 | for i in range(mpre.shape[1] - 1, 0, -1): 39 | mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i]) 40 | for i in range(num_scales): 41 | ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0] 42 | ap[i] = np.sum( 43 | (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1]) 44 | elif mode == '11points': 45 | for i in range(num_scales): 46 | for thr in np.arange(0, 1 + 1e-3, 0.1): 47 | precs = precisions[i, recalls[i, :] >= thr] 48 | prec = precs.max() if precs.size > 0 else 0 49 | ap[i] += prec 50 | ap /= 11 51 | else: 52 | raise ValueError( 53 | 'Unrecognized mode, only "area" and "11points" are supported') 54 | if no_scale: 55 | ap = ap[0] 56 | return ap 57 | 58 | 59 | def tpfp_imagenet(det_bboxes, 60 | gt_bboxes, 61 | gt_bboxes_ignore=None, 62 | default_iou_thr=0.5, 63 | area_ranges=None): 64 | """Check if detected bboxes are true positive or false positive. 65 | 66 | Args: 67 | det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). 68 | gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). 69 | gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, 70 | of shape (k, 4). Default: None 71 | default_iou_thr (float): IoU threshold to be considered as matched for 72 | medium and large bboxes (small ones have special rules). 73 | Default: 0.5. 74 | area_ranges (list[tuple] | None): Range of bbox areas to be evaluated, 75 | in the format [(min1, max1), (min2, max2), ...]. Default: None. 76 | 77 | Returns: 78 | tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of 79 | each array is (num_scales, m). 80 | """ 81 | # an indicator of ignored gts 82 | gt_ignore_inds = np.concatenate( 83 | (np.zeros(gt_bboxes.shape[0], dtype=np.bool), 84 | np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) 85 | # stack gt_bboxes and gt_bboxes_ignore for convenience 86 | gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) 87 | 88 | num_dets = det_bboxes.shape[0] 89 | num_gts = gt_bboxes.shape[0] 90 | if area_ranges is None: 91 | area_ranges = [(None, None)] 92 | num_scales = len(area_ranges) 93 | # tp and fp are of shape (num_scales, num_gts), each row is tp or fp 94 | # of a certain scale. 95 | tp = np.zeros((num_scales, num_dets), dtype=np.float32) 96 | fp = np.zeros((num_scales, num_dets), dtype=np.float32) 97 | if gt_bboxes.shape[0] == 0: 98 | if area_ranges == [(None, None)]: 99 | fp[...] 
= 1 100 | else: 101 | det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * ( 102 | det_bboxes[:, 3] - det_bboxes[:, 1] + 1) 103 | for i, (min_area, max_area) in enumerate(area_ranges): 104 | fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 105 | return tp, fp 106 | ious = bbox_overlaps(det_bboxes, gt_bboxes - 1) 107 | gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1 108 | gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1 109 | iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)), 110 | default_iou_thr) 111 | # sort all detections by scores in descending order 112 | sort_inds = np.argsort(-det_bboxes[:, -1]) 113 | for k, (min_area, max_area) in enumerate(area_ranges): 114 | gt_covered = np.zeros(num_gts, dtype=bool) 115 | # if no area range is specified, gt_area_ignore is all False 116 | if min_area is None: 117 | gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) 118 | else: 119 | gt_areas = gt_w * gt_h 120 | gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) 121 | for i in sort_inds: 122 | max_iou = -1 123 | matched_gt = -1 124 | # find best overlapped available gt 125 | for j in range(num_gts): 126 | # different from PASCAL VOC: allow finding other gts if the 127 | # best overlaped ones are already matched by other det bboxes 128 | if gt_covered[j]: 129 | continue 130 | elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou: 131 | max_iou = ious[i, j] 132 | matched_gt = j 133 | # there are 4 cases for a det bbox: 134 | # 1. it matches a gt, tp = 1, fp = 0 135 | # 2. it matches an ignored gt, tp = 0, fp = 0 136 | # 3. it matches no gt and within area range, tp = 0, fp = 1 137 | # 4. it matches no gt but is beyond area range, tp = 0, fp = 0 138 | if matched_gt >= 0: 139 | gt_covered[matched_gt] = 1 140 | if not (gt_ignore_inds[matched_gt] 141 | or gt_area_ignore[matched_gt]): 142 | tp[k, i] = 1 143 | elif min_area is None: 144 | fp[k, i] = 1 145 | else: 146 | bbox = det_bboxes[i, :4] 147 | area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1) 148 | if area >= min_area and area < max_area: 149 | fp[k, i] = 1 150 | return tp, fp 151 | 152 | 153 | def tpfp_default(det_bboxes, 154 | gt_bboxes, 155 | gt_bboxes_ignore=None, 156 | iou_thr=0.5, 157 | area_ranges=None): 158 | """Check if detected bboxes are true positive or false positive. 159 | 160 | Args: 161 | det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5). 162 | gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4). 163 | gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, 164 | of shape (k, 4). Default: None 165 | iou_thr (float): IoU threshold to be considered as matched. 166 | Default: 0.5. 167 | area_ranges (list[tuple] | None): Range of bbox areas to be evaluated, 168 | in the format [(min1, max1), (min2, max2), ...]. Default: None. 169 | 170 | Returns: 171 | tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of 172 | each array is (num_scales, m). 
173 | """ 174 | # an indicator of ignored gts 175 | gt_ignore_inds = np.concatenate( 176 | (np.zeros(gt_bboxes.shape[0], dtype=np.bool), 177 | np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) 178 | # stack gt_bboxes and gt_bboxes_ignore for convenience 179 | gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) 180 | 181 | num_dets = det_bboxes.shape[0] 182 | num_gts = gt_bboxes.shape[0] 183 | if area_ranges is None: 184 | area_ranges = [(None, None)] 185 | num_scales = len(area_ranges) 186 | # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of 187 | # a certain scale 188 | tp = np.zeros((num_scales, num_dets), dtype=np.float32) 189 | fp = np.zeros((num_scales, num_dets), dtype=np.float32) 190 | 191 | # if there is no gt bboxes in this image, then all det bboxes 192 | # within area range are false positives 193 | if gt_bboxes.shape[0] == 0: 194 | if area_ranges == [(None, None)]: 195 | fp[...] = 1 196 | else: 197 | det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0] + 1) * ( 198 | det_bboxes[:, 3] - det_bboxes[:, 1] + 1) 199 | for i, (min_area, max_area) in enumerate(area_ranges): 200 | fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 201 | return tp, fp 202 | 203 | ious = bbox_overlaps(det_bboxes, gt_bboxes) 204 | # for each det, the max iou with all gts 205 | ious_max = ious.max(axis=1) 206 | # for each det, which gt overlaps most with it 207 | ious_argmax = ious.argmax(axis=1) 208 | # sort all dets in descending order by scores 209 | sort_inds = np.argsort(-det_bboxes[:, -1]) 210 | for k, (min_area, max_area) in enumerate(area_ranges): 211 | gt_covered = np.zeros(num_gts, dtype=bool) 212 | # if no area range is specified, gt_area_ignore is all False 213 | if min_area is None: 214 | gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) 215 | else: 216 | gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * ( 217 | gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1) 218 | gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) 219 | for i in sort_inds: 220 | if ious_max[i] >= iou_thr: 221 | matched_gt = ious_argmax[i] 222 | if not (gt_ignore_inds[matched_gt] 223 | or gt_area_ignore[matched_gt]): 224 | if not gt_covered[matched_gt]: 225 | gt_covered[matched_gt] = True 226 | tp[k, i] = 1 227 | else: 228 | fp[k, i] = 1 229 | # otherwise ignore this detected bbox, tp = 0, fp = 0 230 | elif min_area is None: 231 | fp[k, i] = 1 232 | else: 233 | bbox = det_bboxes[i, :4] 234 | area = (bbox[2] - bbox[0] + 1) * (bbox[3] - bbox[1] + 1) 235 | if area >= min_area and area < max_area: 236 | fp[k, i] = 1 237 | return tp, fp 238 | 239 | 240 | def get_cls_results(det_results, annotations, class_id): 241 | """Get det results and gt information of a certain class. 242 | 243 | Args: 244 | det_results (list[list]): Same as `eval_map()`. 245 | annotations (list[dict]): Same as `eval_map()`. 
246 | 247 | Returns: 248 | tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes 249 | """ 250 | cls_dets = [img_res[class_id] for img_res in det_results] 251 | cls_gts = [] 252 | cls_gts_ignore = [] 253 | for ann in annotations: 254 | gt_inds = ann['labels'] == (class_id + 1) 255 | cls_gts.append(ann['bboxes'][gt_inds, :]) 256 | 257 | if ann.get('labels_ignore', None) is not None: 258 | ignore_inds = ann['labels_ignore'] == (class_id + 1) 259 | cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :]) 260 | else: 261 | cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32)) 262 | 263 | return cls_dets, cls_gts, cls_gts_ignore 264 | 265 | 266 | def eval_map(det_results, 267 | annotations, 268 | scale_ranges=None, 269 | iou_thr=0.5, 270 | dataset=None, 271 | logger=None, 272 | nproc=4): 273 | """Evaluate mAP of a dataset. 274 | 275 | Args: 276 | det_results (list[list]): [[cls1_det, cls2_det, ...], ...]. 277 | The outer list indicates images, and the inner list indicates 278 | per-class detected bboxes. 279 | annotations (list[dict]): Ground truth annotations where each item of 280 | the list indicates an image. Keys of annotations are: 281 | - "bboxes": numpy array of shape (n, 4) 282 | - "labels": numpy array of shape (n, ) 283 | - "bboxes_ignore" (optional): numpy array of shape (k, 4) 284 | - "labels_ignore" (optional): numpy array of shape (k, ) 285 | scale_ranges (list[tuple] | None): Range of scales to be evaluated, 286 | in the format [(min1, max1), (min2, max2), ...]. A range of 287 | (32, 64) means the area range between (32**2, 64**2). 288 | Default: None. 289 | iou_thr (float): IoU threshold to be considered as matched. 290 | Default: 0.5. 291 | dataset (list[str] | str | None): Dataset name or dataset classes, 292 | there are minor differences in metrics for different datsets, e.g. 293 | "voc07", "imagenet_det", etc. Default: None. 294 | logger (logging.Logger | str | None): The way to print the mAP 295 | summary. See `mmdet.utils.print_log()` for details. Default: None. 296 | nproc (int): Processes used for computing TP and FP. 297 | Default: 4. 
298 | 299 | Returns: 300 | tuple: (mAP, [dict, dict, ...]) 301 | """ 302 | assert len(det_results) == len(annotations) 303 | 304 | num_imgs = len(det_results) 305 | num_scales = len(scale_ranges) if scale_ranges is not None else 1 306 | num_classes = len(det_results[0]) # positive class num 307 | area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges] 308 | if scale_ranges is not None else None) 309 | 310 | pool = Pool(nproc) 311 | eval_results = [] 312 | for i in range(num_classes): 313 | # get gt and det bboxes of this class 314 | cls_dets, cls_gts, cls_gts_ignore = get_cls_results( 315 | det_results, annotations, i) 316 | # choose proper function according to datasets to compute tp and fp 317 | if dataset in ['det', 'vid']: 318 | tpfp_func = tpfp_imagenet 319 | else: 320 | tpfp_func = tpfp_default 321 | # compute tp and fp for each image with multiple processes 322 | tpfp = pool.starmap( 323 | tpfp_func, 324 | zip(cls_dets, cls_gts, cls_gts_ignore, 325 | [iou_thr for _ in range(num_imgs)], 326 | [area_ranges for _ in range(num_imgs)])) 327 | tp, fp = tuple(zip(*tpfp)) 328 | # calculate gt number of each scale 329 | # ignored gts or gts beyond the specific scale are not counted 330 | num_gts = np.zeros(num_scales, dtype=int) 331 | for j, bbox in enumerate(cls_gts): 332 | if area_ranges is None: 333 | num_gts[0] += bbox.shape[0] 334 | else: 335 | gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * ( 336 | bbox[:, 3] - bbox[:, 1] + 1) 337 | for k, (min_area, max_area) in enumerate(area_ranges): 338 | num_gts[k] += np.sum((gt_areas >= min_area) 339 | & (gt_areas < max_area)) 340 | # sort all det bboxes by score, also sort tp and fp 341 | cls_dets = np.vstack(cls_dets) 342 | num_dets = cls_dets.shape[0] 343 | sort_inds = np.argsort(-cls_dets[:, -1]) 344 | tp = np.hstack(tp)[:, sort_inds] 345 | fp = np.hstack(fp)[:, sort_inds] 346 | # calculate recall and precision with tp and fp 347 | tp = np.cumsum(tp, axis=1) 348 | fp = np.cumsum(fp, axis=1) 349 | eps = np.finfo(np.float32).eps 350 | recalls = tp / np.maximum(num_gts[:, np.newaxis], eps) 351 | precisions = tp / np.maximum((tp + fp), eps) 352 | # calculate AP 353 | if scale_ranges is None: 354 | recalls = recalls[0, :] 355 | precisions = precisions[0, :] 356 | num_gts = num_gts.item() 357 | mode = 'area' if dataset != 'voc07' else '11points' 358 | ap = average_precision(recalls, precisions, mode) 359 | eval_results.append({ 360 | 'num_gts': num_gts, 361 | 'num_dets': num_dets, 362 | 'recall': recalls, 363 | 'precision': precisions, 364 | 'ap': ap 365 | }) 366 | if scale_ranges is not None: 367 | # shape (num_classes, num_scales) 368 | all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results]) 369 | all_num_gts = np.vstack( 370 | [cls_result['num_gts'] for cls_result in eval_results]) 371 | mean_ap = [] 372 | for i in range(num_scales): 373 | if np.any(all_num_gts[:, i] > 0): 374 | mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean()) 375 | else: 376 | mean_ap.append(0.0) 377 | else: 378 | aps = [] 379 | for cls_result in eval_results: 380 | if cls_result['num_gts'] > 0: 381 | aps.append(cls_result['ap']) 382 | mean_ap = np.array(aps).mean().item() if aps else 0.0 383 | 384 | print_map_summary( 385 | mean_ap, eval_results, dataset, area_ranges, logger=logger) 386 | 387 | return mean_ap, eval_results 388 | 389 | 390 | def print_map_summary(mean_ap, 391 | results, 392 | dataset=None, 393 | scale_ranges=None, 394 | logger=None): 395 | """Print mAP and results of each class. 
396 | 397 | A table will be printed to show the gts/dets/recall/AP of each class and 398 | the mAP. 399 | 400 | Args: 401 | mean_ap (float): Calculated from `eval_map()`. 402 | results (list[dict]): Calculated from `eval_map()`. 403 | dataset (list[str] | str | None): Dataset name or dataset classes. 404 | scale_ranges (list[tuple] | None): Range of scales to be evaluated. 405 | logger (logging.Logger | str | None): The way to print the mAP 406 | summary. See `mmdet.utils.print_log()` for details. Default: None. 407 | """ 408 | 409 | if logger == 'silent': 410 | return 411 | 412 | if isinstance(results[0]['ap'], np.ndarray): 413 | num_scales = len(results[0]['ap']) 414 | else: 415 | num_scales = 1 416 | 417 | if scale_ranges is not None: 418 | assert len(scale_ranges) == num_scales 419 | 420 | num_classes = len(results) 421 | 422 | recalls = np.zeros((num_scales, num_classes), dtype=np.float32) 423 | aps = np.zeros((num_scales, num_classes), dtype=np.float32) 424 | num_gts = np.zeros((num_scales, num_classes), dtype=int) 425 | for i, cls_result in enumerate(results): 426 | if cls_result['recall'].size > 0: 427 | recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] 428 | aps[:, i] = cls_result['ap'] 429 | num_gts[:, i] = cls_result['num_gts'] 430 | 431 | if dataset is None: 432 | label_names = [str(i) for i in range(1, num_classes + 1)] 433 | elif mmcv.is_str(dataset): 434 | label_names = get_classes(dataset) 435 | else: 436 | label_names = dataset 437 | 438 | if not isinstance(mean_ap, list): 439 | mean_ap = [mean_ap] 440 | 441 | header = ['class', 'gts', 'dets', 'recall', 'ap'] 442 | for i in range(num_scales): 443 | if scale_ranges is not None: 444 | print_log('Scale range {}'.format(scale_ranges[i]), logger=logger) 445 | table_data = [header] 446 | for j in range(num_classes): 447 | try: 448 | row_data = [ 449 | label_names[j], num_gts[i, j], results[j]['num_dets'], 450 | '{:.3f}'.format(recalls[i, j]), '{:.3f}'.format(aps[i, j]) 451 | ] 452 | table_data.append(row_data) 453 | except: 454 | pass 455 | table_data.append(['mAP', '', '', '', '{:.3f}'.format(mean_ap[i])]) 456 | table = AsciiTable(table_data) 457 | table.inner_footing_row_border = True 458 | print_log('\n' + table.table, logger=logger) 459 | -------------------------------------------------------------------------------- /dla/src/projection_histogram.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | import os 18 | import cv2 19 | import time 20 | import numpy as np 21 | import argparse 22 | 23 | from scipy import ndimage 24 | from PIL import Image 25 | from sklearn.cluster import KMeans 26 | 27 | try: 28 | from dla.src.image_utils import divide_cells_or_characters, find_background_contours, cc_extraction 29 | from dla.src.xml_utils import save_ICDAR_xml, load_ICDAR_xml 30 | except: 31 | from image_utils import divide_cells_or_characters, find_background_contours, cc_extraction 32 | from xml_utils import save_ICDAR_xml, load_ICDAR_xml 33 | 34 | BIN_WIDTH = 10 35 | 
PEAK_THRESHOLD = 0.2
36 | 
37 | 
38 | def run_kmeans(points,mode):
39 |     kmeans = KMeans(n_clusters=2, random_state=0,max_iter=1).fit(points)
40 | 
41 |     if mode=="min":
42 |         return kmeans.labels_==np.argmin(kmeans.cluster_centers_)
43 |     elif mode=="max":
44 |         return kmeans.labels_==np.argmax(kmeans.cluster_centers_)
45 | 
46 | def historgram_partitioning(text_regions,image_shape,axis,plot=False,median_filter=True,invert=False,clustering=True):
47 |     #Implemented algorithm described here: http://www.dlib.org/dlib/november14/klampfl/11klampfl.html
48 | 
49 |     projection = []
50 | 
51 |     for region in text_regions:
52 |         if axis=="x":
53 |             projection += [i for i in range(region[0],region[2]+1)]
54 |         elif axis=="y":
55 |             projection += [i for i in range(region[1],region[3]+1)]
56 | 
57 |     histogram,bins = np.histogram(projection,bins=np.arange(image_shape[1]//BIN_WIDTH) * BIN_WIDTH)
58 | 
59 |     if median_filter:
60 |         histogram = ndimage.median_filter(histogram,size=5)
61 | 
62 |     maxima = []
63 |     minima = []
64 | 
65 |     max_ = max(histogram)
66 | 
67 |     diff = np.diff(histogram)
68 | 
69 |     edges = np.where(diff!=0)[0]
70 | 
71 |     for i in range(len(edges)-1):
72 |         if diff[edges[i]] * diff[edges[i+1]] >= 0:
73 |             diff[edges[i+1]] += diff[edges[i]]
74 | 
75 |         elif abs(diff[edges[i]])>PEAK_THRESHOLD * max_ or abs(diff[edges[i+1]])>PEAK_THRESHOLD * max_:
76 |             if diff[edges[i]]<0 and diff[edges[i+1]] > 0 :
77 |                 minima.append((edges[i] + edges[i+1])//2)
78 | 
79 |             if diff[edges[i]]>0 and diff[edges[i+1]] < 0 :
80 |                 maxima.append((edges[i] + edges[i+1])//2)
81 | 
82 | 
83 |     minima = np.array(minima)
84 |     if np.unique(minima).size >3 and clustering:
85 |         minima = minima[run_kmeans(histogram[minima].reshape(-1,1),mode="min")]
86 | 
87 |     maxima = np.array(maxima)
88 |     if np.unique(maxima).size >3 and clustering:
89 |         maxima = maxima[run_kmeans(histogram[maxima].reshape(-1,1),mode="max")]
90 | 
91 |     if invert:
92 |         return (np.array(maxima) * BIN_WIDTH).tolist()
93 | 
94 |     return (np.array(minima)*BIN_WIDTH).tolist()
95 | 
96 | 
97 | def table_analysis(image,cc_cols=False,cc_rows=False):
98 |     cells_n_characters = find_background_contours(image)
99 |     _, characters = divide_cells_or_characters(cells_n_characters,image.shape)
100 | 
101 |     columns = historgram_partitioning(characters,image.shape,axis="x")
102 |     rows = historgram_partitioning(characters,image.shape,axis="y")
103 | 
104 |     cc = cc_extraction(image)
105 |     v_boundaries = [[c[0],c[1],c[0]+1,c[3]] for c in cc] + [[c[2],c[1],c[2]+1,c[3]] for c in cc]
106 |     h_boundaries = [[c[0],c[1],c[2],c[1]+1] for c in cc] + [[c[0],c[3],c[2],c[3]+1] for c in cc]
107 | 
108 |     v_edges = historgram_partitioning(v_boundaries,image.shape,axis="x",median_filter=False,invert=True,clustering=False)
109 |     h_edges = historgram_partitioning(h_boundaries,image.shape,axis="y",median_filter=False,invert=True,clustering=False)
110 | 
111 |     return columns + (v_edges if cc_cols else []), rows + (h_edges if cc_rows else [])
112 | 
113 | 
114 | if __name__ == "__main__":
115 | 
116 |     parser = argparse.ArgumentParser()
117 |     parser.add_argument('image_path')
118 |     parser.add_argument('table_path')
119 |     parser.add_argument('save_to')
120 |     parser.add_argument('--cc_rows',dest="use_cc_rows",action="store_true")
121 |     parser.add_argument('--cc_cols',dest="use_cc_cols",action="store_true")
122 |     args = parser.parse_args()
123 | 
124 |     image_path = args.image_path
125 |     table_path = args.table_path
126 |     save_to = args.save_to
127 | 
128 |     use_cc_cols = args.use_cc_cols
129 |     use_cc_rows = args.use_cc_rows
130 | 
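    # Example command-line invocation, as a rough sketch (the directory names
    # below are hypothetical). The loop that follows joins paths by plain string
    # concatenation, so each of the three positional arguments is expected to end
    # with a trailing slash, and the table regions are read from ICDAR-style XML
    # files named after each image:
    #
    #   python projection_histogram.py ./images/ ./table_regions/ ./output/ --cc_rows --cc_cols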
131 | for image_name in os.listdir(image_path): 132 | if image_name.endswith(".xml"): 133 | continue 134 | if image_name.split(".")[-2] + ".xml" in os.listdir(table_path): 135 | tables = load_ICDAR_xml(table_path + image_name.split(".")[-2] + ".xml") 136 | else: 137 | continue 138 | 139 | image = cv2.imread(image_path + image_name) 140 | 141 | cols = [] 142 | rows = [] 143 | tabs = [] 144 | 145 | for table in tables: 146 | table = tuple(max(pos,0) for pos in table["region"]) 147 | tabs.append(table) 148 | 149 | image_cut = image[(int)(table[1]):(int)(table[3]),(int)(table[0]):(int)(table[2])] 150 | output = table_analysis(cv2.cvtColor(image_cut,cv2.COLOR_BGR2GRAY),cc_rows=use_cc_rows,cc_cols=use_cc_cols) 151 | 152 | cols.append([col + table[0] for col in output[0]] + [table[0],table[2]]) 153 | rows.append([row + table[1] for row in output[1]] + [table[1],table[3]]) 154 | 155 | save_ICDAR_xml(tabs,cols,rows,save_to + image_name.split(".")[-2] + ".xml") -------------------------------------------------------------------------------- /dla/src/table_structure_analysis.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ###################################################################### 5 | # 6 | # (c) Copyright University of Southampton, 2020 7 | # 8 | # Copyright in this software belongs to University of Southampton, 9 | # Highfield, University Road, Southampton SO17 1BJ 10 | # 11 | # Created By : Juliusz Ziomek 12 | # Created Date : 2020/09/09 13 | # Project : GloSAT 14 | # 15 | ###################################################################### 16 | 17 | import numpy as np 18 | import warnings 19 | from sklearn.cluster import DBSCAN 20 | 21 | area = lambda box: (box[2]-box[0]) * (box[3] - box[1]) if box[2]>=box[0] and box[3]>=box[1] else 0 22 | 23 | def run_dbs_1D(cells:list, eps:int,include_outliers=True,min_samples=2) -> list: 24 | ''' 25 | Runs DBScan in 1d and returns the average values for each label. 26 | If outliers are detected (label = -1), each of them is appended to the average values. 27 | ''' 28 | 29 | centers = np.array([cells]).reshape(-1,1) 30 | labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(centers) 31 | 32 | examples_by_label = {label:[] for label in labels} 33 | mean_by_label = dict() 34 | 35 | for no, center in enumerate(centers): 36 | examples_by_label[labels[no]].append(center) 37 | 38 | for label in examples_by_label: 39 | if label!= -1 : 40 | mean_by_label[label] = sum(examples_by_label[label])/len(examples_by_label[label]) 41 | 42 | 43 | return list(mean_by_label.values()) + (examples_by_label[-1] if -1 in examples_by_label.keys() and include_outliers else []) 44 | 45 | 46 | def reconstruct_table(cells:list,table:list,eps:int) -> (list,list): 47 | ''' 48 | Reconstructs the cells, given the table region using DBScan with hyperparmeter eps. 
49 | 50 | ''' 51 | table_width = table[2] - table[0] 52 | table_height = table[3] - table[1] 53 | 54 | #Normalise cells 55 | cells = [[(cell[0]-table[0])/table_width,(cell[1]-table[1])/table_height,(cell[2]-table[0])/table_width,(cell[3]-table[1])/table_height] for cell in cells] 56 | 57 | cells_x = [0,1] 58 | cells_y = [0,1] 59 | for cell in cells: 60 | cells_x += [cell[0], cell[2]] 61 | cells_y += [cell[1], cell[3]] 62 | 63 | eps_x, eps_y = check_hyperparams(cells,eps) 64 | 65 | rows = run_dbs_1D(cells_y,eps_y) 66 | cols = run_dbs_1D(cells_x,eps_x) 67 | 68 | rows = [(int)(row * table_height) + table[1] for row in rows] 69 | cols = [(int)(col * table_width) + table[0] for col in cols] 70 | 71 | return rows,cols 72 | 73 | def reconstruct_table_coarse_and_fine(coarse_cells:list,fine_cells:list,table:list,eps:int) -> (list,list): 74 | ''' 75 | Reconstructs the cells, given the table region using DBScan with hyperparmeter eps. 76 | 77 | ''' 78 | table_width = table[2] - table[0] 79 | table_height = table[3] - table[1] 80 | 81 | rows = [] 82 | cols = [] 83 | 84 | if fine_cells!=[]: 85 | #Normalise cells 86 | fine_cells = [[(cell[0]-table[0])/table_width,(cell[1]-table[1])/table_height,(cell[2]-table[0])/table_width,(cell[3]-table[1])/table_height] for cell in fine_cells] 87 | 88 | cells_x = [0,1] 89 | cells_y = [0,1] 90 | for cell in fine_cells: 91 | cells_x += [cell[0], cell[2]] 92 | cells_y += [cell[1], cell[3]] 93 | 94 | fine_eps_x, fine_eps_y = check_hyperparams(fine_cells,eps) 95 | 96 | rows += run_dbs_1D(cells_y,fine_eps_y) 97 | cols += run_dbs_1D(cells_x,fine_eps_x) 98 | 99 | if coarse_cells!=[]: 100 | coarse_cells = [[(cell[0]-table[0])/table_width,(cell[1]-table[1])/table_height,(cell[2]-table[0])/table_width,(cell[3]-table[1])/table_height] for cell in coarse_cells] 101 | 102 | cells_x = [0,1] 103 | cells_y = [0,1] 104 | for cell in coarse_cells: 105 | cells_x += [cell[0], cell[2]] 106 | cells_y += [cell[1], cell[3]] 107 | 108 | eps_x, eps_y = check_hyperparams(coarse_cells,eps) 109 | 110 | rows += run_dbs_1D(cells_y,eps_y) 111 | cols += run_dbs_1D(cells_x,eps_x) 112 | 113 | if fine_cells!=[]: 114 | rows = run_dbs_1D(rows,fine_eps_y) 115 | cols = run_dbs_1D(cols,fine_eps_x) 116 | 117 | elif coarse_cells!=[]: 118 | rows = run_dbs_1D(rows,eps_y) 119 | cols = run_dbs_1D(cols,eps_x) 120 | 121 | rows = [(int)(row * table_height) + table[1] for row in rows] 122 | cols = [(int)(col * table_width) + table[0] for col in cols] 123 | 124 | return rows,cols 125 | 126 | def check_hyperparams(cells:list,eps:int) -> (int,int): 127 | ''' 128 | Check whether the eps paramter is smaller than avarega width and height of cell. 129 | If one of those conditions is violated, prints a warning. 130 | Returns adjusted hyperparameters for x and y. 131 | ''' 132 | diff_x, diff_y = [], [] 133 | for cell in cells: 134 | diff_x.append(cell[2] - cell[0]) 135 | diff_y.append(cell[3] - cell[1]) 136 | 137 | avg_diff_x = sum(diff_x)/len(diff_x) 138 | avg_diff_y = sum(diff_y)/len(diff_y) 139 | 140 | if avg_diff_x/2 int: 155 | ''' 156 | Checks how much of the first box lies inside the second one. 
126 | def check_hyperparams(cells:list, eps:int) -> (int,int):
127 |     '''
128 |     Checks whether the eps parameter is smaller than half the average cell width and height.
129 |     If either of those conditions is violated, prints a warning.
130 |     Returns adjusted hyperparameters for x and y.
131 |     '''
132 |     diff_x, diff_y = [], []
133 |     for cell in cells:
134 |         diff_x.append(cell[2] - cell[0])
135 |         diff_y.append(cell[3] - cell[1])
136 | 
137 |     avg_diff_x = sum(diff_x)/len(diff_x)
138 |     avg_diff_y = sum(diff_y)/len(diff_y)
139 | 
140 |     if avg_diff_x/2 < eps:
141 |         warnings.warn("eps is larger than half the average cell width - clipping eps_x")
142 |         eps_x = avg_diff_x/2
143 |     else:
144 |         eps_x = eps
145 | 
146 |     if avg_diff_y/2 < eps:
147 |         warnings.warn("eps is larger than half the average cell height - clipping eps_y")
148 |         eps_y = avg_diff_y/2
149 |     else:
150 |         eps_y = eps
151 | 
152 |     return eps_x, eps_y
153 | 
154 | def how_much_contained(box1:list, box2:list) -> float:
155 |     '''
156 |     Checks how much of the first box lies inside the second one.
157 |     '''
158 | 
159 |     area1 = area(box1)
160 | 
161 |     intersection_box = [max(box1[0],box2[0]),
162 |                         max(box1[1],box2[1]),
163 |                         min(box1[2],box2[2]),
164 |                         min(box1[3],box2[3])]
165 | 
166 |     intersection_area = area(intersection_box)
167 | 
168 |     return intersection_area/(area1)
169 | 
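# Illustrative example: how_much_contained returns the fraction of box1's area that falls
# inside box2, e.g. how_much_contained([0, 0, 10, 10], [5, 0, 20, 10]) evaluates to 0.5.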
--------------------------------------------------------------------------------
/dla/src/visualise_annotations.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | 
4 | try:
5 |     from dla.src.image_utils import put_box
6 |     from dla.src.xml_utils import load_VOC_xml
7 | except ModuleNotFoundError:
8 |     from image_utils import put_box
9 |     from xml_utils import load_VOC_xml
10 | 
11 | if __name__ == "__main__":
12 | 
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument("--image",type=str,default="")
15 |     parser.add_argument("--annotation",type=str,default="")
16 | 
17 |     args = parser.parse_args()
18 | 
19 |     image_file = args.image
20 |     annotation_file = args.annotation
21 | 
22 |     image = cv2.imread(image_file)
23 |     #image = cv2.imread("segmentation_.jpg")
24 | 
25 |     objects = load_VOC_xml(annotation_file)
26 | 
27 | 
28 |     for object_ in objects:
29 |         name,box = object_["name"], object_["bbox"]
30 |         #if name=="header":
31 |         #    put_box(image,box,colour=(0,0,0),thickness=10)
32 |         #if name=="heading":
33 |         #    put_box(image,box,colour=(0,0,0),thickness=10)
34 |         #if name=="full_table":
35 |         #    put_box(image,box,colour=(0,0,0),thickness=10)
36 |         if name=="table_body":
37 |             put_box(image,box,colour=(25,25,25))
38 |         if name=="cell":
39 |             put_box(image,box,colour=(25,25,25))
40 | 
41 |     cv2.imwrite("vis_90.jpg",image)
42 | 
--------------------------------------------------------------------------------
/dla/src/xml_utils.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | ######################################################################
5 | #
6 | # (c) Copyright University of Southampton, 2020
7 | #
8 | # Copyright in this software belongs to University of Southampton,
9 | # Highfield, University Road, Southampton SO17 1BJ
10 | #
11 | # Created By : Juliusz Ziomek
12 | # Created Date : 2020/09/09
13 | # Project : GloSAT
14 | #
15 | ######################################################################
16 | 
17 | import xml.etree.ElementTree as ET
18 | import time
19 | 
20 | def save_ICDAR_xml(tables:list, cols:list, rows:list, filename:str):
21 |     '''
22 |     Saves the output in ICDAR-style xml.
23 |     - tables, cols, rows - per-table lists of table bounding boxes, column separator positions and row separator positions (in pixels)
24 |     - filename - string containing the filename under which the xml should be saved
25 |     '''
26 | 
27 |     root = ET.Element('document')
28 | 
29 |     for no,table in enumerate(tables):
30 | 
31 |         tab = xml_add_table(root,table)
32 | 
33 |         cells, row_by_cell, col_by_cell = cells_from_lines(rows[no],cols[no])
34 | 
35 |         for cell_no,cell in enumerate(cells):
36 |             xml_add_cell(tab,cell,row_by_cell[cell_no],col_by_cell[cell_no],rowSpan=1,colSpan=1)
37 | 
38 |     tree = ET.ElementTree(root)
39 |     tree.write(filename)
40 | 
41 | 
42 | def box_to_item(box):
43 |     x0, y0, x1, y1 = box
44 |     x0, y0, x1, y1 = (int)(x0), (int)(y0), (int)(x1), (int)(y1)
45 | 
46 |     result = (str)(x0) + "," + (str)(y0)
47 |     result += " " + (str)(x0) + "," + (str)(y1)
48 |     result += " " + (str)(x1) + "," + (str)(y0)
49 |     result += " " + (str)(x1) + "," + (str)(y1)
50 | 
51 |     return result
52 | 
53 | def item_to_box(item):
54 | 
55 |     for info in item:
56 |         if info.tag.endswith("Coords"):
57 |             coords = info
58 |     try:
59 |         points = coords.attrib["points"].split(" ")
60 |     except KeyError:
61 |         points = coords.attrib["point"].split(" ")
62 | 
63 |     points = [((float)(p.split(",")[0]),(float)(p.split(",")[1])) for p in points]
64 | 
65 |     x = [point[0] for point in points]
66 |     y = [point[1] for point in points]
67 | 
68 |     return min(x),min(y),max(x),max(y)
69 | 
70 | def cells_from_lines(rows,cols):
71 | 
72 |     junctions = []
73 | 
74 |     for row in sorted(rows):
75 |         junctions_in_row = []
76 |         for col in sorted(cols):
77 |             junctions_in_row.append([col,row])
78 | 
79 |         junctions.append(junctions_in_row)
80 | 
81 |     cells = []
82 |     row_by_cell = []
83 |     col_by_cell = []
84 | 
85 |     for i in range(len(junctions)-1):
86 |         for j in range(len(junctions[i])-1):
87 |             cells += [junctions[i][j] + junctions[i+1][j+1]]
88 |             row_by_cell.append(i)
89 |             col_by_cell.append(j)
90 | 
91 |     return cells, row_by_cell, col_by_cell
92 | 
93 | def cell_to_lines(cells):
94 |     rows_by_pixels = dict()
95 |     cols_by_pixels = dict()
96 | 
97 |     for cell in cells:
98 |         if cell.attrib["start-row"] in rows_by_pixels:
99 |             rows_by_pixels[cell.attrib["start-row"]].append(item_to_box(cell)[1])
100 |         else:
101 |             rows_by_pixels[cell.attrib["start-row"]] = [item_to_box(cell)[1]]
102 | 
103 |         if cell.attrib["start-col"] in cols_by_pixels:
104 |             cols_by_pixels[cell.attrib["start-col"]].append(item_to_box(cell)[0])
105 |         else:
106 |             cols_by_pixels[cell.attrib["start-col"]] = [item_to_box(cell)[0]]
107 | 
108 |     return [sum(rows_by_pixels[row])/len(rows_by_pixels[row]) for row in rows_by_pixels],[sum(cols_by_pixels[col])/len(cols_by_pixels[col]) for col in cols_by_pixels]
109 | 
110 | 
111 | def xml_add_table(root,table):
112 |     tab = ET.SubElement(root,"table")
113 |     tab.attrib["id"] = "Table_" + (str)(time.time()*1000000)
114 | 
115 |     coords = ET.SubElement(tab,"Coords")
116 |     coords.attrib["points"] = box_to_item(table)
117 | 
118 |     return tab
119 | 
120 | def xml_add_cell(table,cell,row,col,rowSpan=1,colSpan=1):
121 |     c = ET.SubElement(table,"cell")
122 |     c_coords = ET.SubElement(c,"Coords")
123 |     c_coords.attrib["points"] = box_to_item(cell)
124 |     c.attrib["id"] = "TableCell_" + (str)(time.time()*1000000)
125 |     c.attrib["start-row"] = (str)(row)
126 |     c.attrib["start-col"] = (str)(col)
127 |     c.attrib["end-row"] = (str)(row + rowSpan)
128 |     c.attrib["end-col"] = (str)(col + colSpan)
129 | 
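# For reference, the ICDAR-style XML produced by xml_add_table / xml_add_cell has roughly this
# shape (illustrative values; "points" lists the four box corners as produced by box_to_item):
#
#   <document>
#     <table id="Table_...">
#       <Coords points="x0,y0 x0,y1 x1,y0 x1,y1"/>
#       <cell id="TableCell_..." start-row="0" start-col="0" end-row="1" end-col="1">
#         <Coords points="..."/>
#       </cell>
#     </table>
#   </document>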
130 | def get_ICDAR_summary(filename):
131 |     root = ET.parse(filename).getroot()
132 | 
133 |     doc_type = root.attrib["type"]
134 | 
135 |     tables = []
136 | 
137 |     for element in root:
138 |         if element.tag == "table":
139 |             table_type = element.attrib["type"]
140 |             table_region = item_to_box(element)
141 |             cells = []
142 |             header_cells = 0
143 |             for child in element:
144 |                 if child.tag.endswith("cell"):
145 |                     if "header" in child.attrib:
146 |                         header_cells += 1
147 |                     cells.append(item_to_box(child))
148 | 
149 |             tables.append({"cells":cells,"region":table_region,"type":table_type.replace("_","-"),"header_no":header_cells})
150 | 
151 |     return tables, doc_type
152 | 
153 | def load_ICDAR_xml(filename):
154 |     root = ET.parse(filename).getroot()
155 | 
156 |     tables = []
157 | 
158 |     for element in root:
159 |         if element.tag == "table":
160 |             table_region = item_to_box(element)
161 |             cells = []
162 |             for child in element:
163 |                 if child.tag.endswith("cell"):
164 |                     cells.append(item_to_box(child))
165 | 
166 |             tables.append({"cells":cells,"region":table_region})
167 | 
168 |     return tables
169 | 
170 | def load_ICDAR_xml_lines(filename):
171 |     root = ET.parse(filename).getroot()
172 | 
173 |     tables = []
174 | 
175 |     for element in root:
176 |         if element.tag == "table":
177 |             table_region = item_to_box(element)
178 |             cells = []
179 |             for child in element:
180 |                 if child.tag.endswith("cell"):
181 |                     cells.append(child)
182 | 
183 |             rows,cols = cell_to_lines(cells)
184 |             tables.append({"rows":rows,"cols":cols,"region":table_region})
185 | 
186 |     return tables
187 | 
188 | def load_VOC_xml(filename):
189 |     root = ET.parse(filename).getroot()
190 | 
191 |     objects = []
192 | 
193 |     for element in root:
194 |         if element.tag == "object":
195 |             for child in element:
196 |                 if child.tag == "name":
197 |                     name = child.text
198 | 
199 |                 if child.tag == "bndbox":
200 |                     for dim in child:
201 |                         if dim.tag == "xmin":
202 |                             xmin = (int)(dim.text)
203 | 
204 |                         if dim.tag == "xmax":
205 |                             xmax = (int)(dim.text)
206 | 
207 |                         if dim.tag == "ymin":
208 |                             ymin = (int)(dim.text)
209 | 
210 |                         if dim.tag == "ymax":
211 |                             ymax = (int)(dim.text)
212 | 
213 |             objects.append({"name":name,"bbox":[xmin,ymin,xmax,ymax]})
214 | 
215 |     return objects
216 | 
217 | def add_VOC_object(bbox,name_text,root):
218 |     obj = ET.SubElement(root,"object")
219 |     name = ET.SubElement(obj,"name")
220 |     pose = ET.SubElement(obj,"pose")
221 |     truncated = ET.SubElement(obj,"truncated")
222 |     difficult = ET.SubElement(obj,"difficult")
223 | 
224 |     name.text = name_text
225 |     pose.text = "Unspecified"
226 |     truncated.text = "0"
227 |     difficult.text = "0"
228 |     bndbox = ET.SubElement(obj,"bndbox")
229 | 
230 |     xmin = ET.SubElement(bndbox,"xmin")
231 |     ymin = ET.SubElement(bndbox,"ymin")
232 |     xmax = ET.SubElement(bndbox,"xmax")
233 |     ymax = ET.SubElement(bndbox,"ymax")
234 | 
235 |     xmin.text = (str)((int)(bbox[0]))
236 |     ymin.text = (str)((int)(bbox[1]))
237 |     xmax.text = (str)((int)(bbox[2]))
238 |     ymax.text = (str)((int)(bbox[3]))
239 | 
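# For reference, load_VOC_xml returns one dict per annotated object, e.g. (illustrative values):
#   [{"name": "full_table", "bbox": [xmin, ymin, xmax, ymax]},
#    {"name": "cell", "bbox": [xmin, ymin, xmax, ymax]}, ...]
# and add_VOC_object writes the matching <object><name>...</name><bndbox>...</bndbox></object>
# entry into a Pascal VOC annotation.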
240 | 
241 | def add_VOC_intro(root,width_,height_,filename):
242 |     folder = ET.SubElement(root,"folder")
243 |     folder.text = "JPEGImages"
244 |     filename_node = ET.SubElement(root,"filename")
245 |     filename_node.text = filename
246 |     path = ET.SubElement(root,"path")
247 |     path.text = "VOC/"
248 |     source = ET.SubElement(root,"source")
249 |     database = ET.SubElement(source,"database")
250 |     database.text = "GloSAT"
251 |     size = ET.SubElement(root,"size")
252 |     width = ET.SubElement(size,"width")
253 |     height = ET.SubElement(size,"height")
254 |     depth = ET.SubElement(size,"depth")
255 | 
256 |     width.text = (str)(width_)
257 |     height.text = (str)(height_)
258 |     depth.text = "3"
259 | 
260 |     segmented = ET.SubElement(root,"segmented")
261 |     segmented.text = "0"
262 | 
263 | 
264 | def save_VOC_xml(headings:list,headers:list,bodies:list,full_tables:list,cols:list,rows:list,filename:str,width:int,height:int):
265 |     root = ET.Element('annotation')
266 |     add_VOC_intro(root,width,height,filename)
267 | 
268 |     for heading in headings:
269 |         add_VOC_object(heading,"heading",root)
270 | 
271 |     for header in headers:
272 |         add_VOC_object(header,"header",root)
273 | 
274 |     for body in full_tables:
275 |         add_VOC_object(body,"full_table",root)
276 | 
277 |     for no,body in enumerate(bodies):
278 |         add_VOC_object(body,"table_body",root)
279 | 
280 |         cells, _, _ = cells_from_lines(rows[no],cols[no])
281 | 
282 |         for cell in cells:
283 |             add_VOC_object(cell,"cell",root)
284 | 
285 |     tree = ET.ElementTree(root)
286 |     tree.write(filename)
287 | 
288 | def save_VOC_xml_from_cells(headings:list,headers:list,bodies:list,full_tables:list,cells:list,filename:str,width:int,height:int):
289 |     root = ET.Element('annotation')
290 |     add_VOC_intro(root,width,height,filename)
291 | 
292 |     for heading in headings:
293 |         add_VOC_object(heading,"heading",root)
294 | 
295 |     for header in headers:
296 |         add_VOC_object(header,"header",root)
297 | 
298 |     for body in full_tables:
299 |         add_VOC_object(body,"full_table",root)
300 | 
301 |     for body in bodies:
302 |         add_VOC_object(body,"table_body",root)
303 | 
304 | 
305 |     for cell in cells:
306 |         add_VOC_object(cell,"cell",root)
307 | 
308 |     tree = ET.ElementTree(root)
309 |     tree.write(filename)
--------------------------------------------------------------------------------
/examples/table_detection_105.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_detection_105.jpg
--------------------------------------------------------------------------------
/examples/table_detection_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_detection_3.jpg
--------------------------------------------------------------------------------
/examples/table_detection_300.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_detection_300.jpg
--------------------------------------------------------------------------------
/examples/table_detection_enhanced_105.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_detection_enhanced_105.jpg
--------------------------------------------------------------------------------
/examples/table_detection_enhanced_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_detection_enhanced_3.jpg
--------------------------------------------------------------------------------
/examples/table_detection_enhanced_300.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_detection_enhanced_300.jpg -------------------------------------------------------------------------------- /examples/table_struct_recog_coarse_105.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_struct_recog_coarse_105.jpg -------------------------------------------------------------------------------- /examples/table_struct_recog_coarse_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_struct_recog_coarse_3.jpg -------------------------------------------------------------------------------- /examples/table_struct_recog_coarse_300.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_struct_recog_coarse_300.jpg -------------------------------------------------------------------------------- /examples/table_struct_recog_fine_105.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_struct_recog_fine_105.jpg -------------------------------------------------------------------------------- /examples/table_struct_recog_fine_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_struct_recog_fine_3.jpg -------------------------------------------------------------------------------- /examples/table_struct_recog_fine_300.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stuartemiddleton/glosat_table_dataset/fed2f0c6c75a567476e734cb878348cf6fdbae62/examples/table_struct_recog_fine_300.jpg --------------------------------------------------------------------------------