├── adet ├── utils │ ├── __init__.py │ ├── comm.py │ ├── structure_utils.py │ ├── misc.py │ ├── curve_utils.py │ ├── polygon_utils.py │ ├── hooks.py │ └── visualizer.py ├── modeling │ ├── model │ │ ├── __init__.py │ │ └── utils.py │ ├── utils │ │ ├── __init__.py │ │ ├── build_semi.py │ │ ├── dist_utils.py │ │ └── boundary_utils.py │ ├── __init__.py │ └── semi_text_spotter.py ├── __init__.py ├── evaluation │ └── __init__.py ├── data │ ├── samplers │ │ └── __init__.py │ ├── __init__.py │ ├── geo_utils.py │ ├── detection_utils.py │ └── builtin.py ├── checkpoint │ ├── __init__.py │ └── adet_checkpoint.py ├── config │ ├── __init__.py │ ├── config.py │ ├── defaults.py │ └── semi_defaults.py └── layers │ ├── __init__.py │ ├── csrc │ ├── cuda_version.cu │ ├── DeformAttn │ │ ├── ms_deform_attn_cuda.h │ │ ├── ms_deform_attn_cpu.h │ │ ├── ms_deform_attn_cpu.cpp │ │ ├── ms_deform_attn.h │ │ └── ms_deform_attn_cuda.cu │ └── vision.cpp │ └── pos_encoding.py ├── figs └── framework.jpg ├── requirements.txt ├── configs └── R_50 │ ├── pretrain │ └── 150k.yaml │ ├── Base_det.yaml │ ├── TotalText │ └── SemiETS │ │ ├── SemiETS_1s.yaml │ │ ├── SemiETS_2s.yaml │ │ ├── SemiETS_5s.yaml │ │ ├── SemiETS_10s.yaml │ │ └── SemiETS_0.5s.yaml │ ├── CTW1500 │ └── SemiETS │ │ ├── SemiETS_1s.yaml │ │ ├── SemiETS_2s.yaml │ │ ├── SemiETS_5s.yaml │ │ ├── SemiETS_10s.yaml │ │ └── SemiETS_0.5s.yaml │ └── IC15 │ └── SemiETS │ ├── SemiETS_2s.yaml │ ├── SemiETS_1s.yaml │ ├── SemiETS_0.5s.yaml │ ├── SemiETS_5s.yaml │ └── SemiETS_10s.yaml ├── tools ├── convert.py └── train_net.py ├── LICENSE ├── setup.py ├── .gitignore ├── README.md └── datasets └── ic15 ├── train_37voc_0.5_labeled.json └── train_37voc_1_labeled.json /adet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /adet/modeling/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /adet/__init__.py: -------------------------------------------------------------------------------- 1 | from adet import modeling 2 | 3 | __version__ = "0.1.1" 4 | -------------------------------------------------------------------------------- /adet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_evaluation_all import TextEvaluator 2 | -------------------------------------------------------------------------------- /figs/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrLuo/SemiETS/HEAD/figs/framework.jpg -------------------------------------------------------------------------------- /adet/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .semi_samplers import MultiSourceSampler, GroupMultiSourceSampler 2 | -------------------------------------------------------------------------------- /adet/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .adet_checkpoint import AdetCheckpointer 2 | 3 | __all__ = ["AdetCheckpointer"] 4 | -------------------------------------------------------------------------------- /adet/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import get_cfg, get_cfg_semi 2 
| 3 | __all__ = [ 4 | "get_cfg", "get_cfg_semi" 5 | ] 6 | -------------------------------------------------------------------------------- /adet/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ms_deform_attn import MSDeformAttn 2 | 3 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /adet/layers/csrc/cuda_version.cu: -------------------------------------------------------------------------------- 1 | #include <cuda_runtime_api.h> 2 | 3 | namespace adet { 4 | int get_cudart_version() { 5 | return CUDART_VERSION; 6 | } 7 | } // namespace adet 8 | -------------------------------------------------------------------------------- /adet/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import builtin  # ensure the builtin datasets are registered 2 | from .dataset_mapper import DatasetMapperWithBasis 3 | 4 | 5 | __all__ = ["DatasetMapperWithBasis"] 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools==59.5.0 2 | editdistance==0.6.2 3 | matplotlib==3.3.3 4 | numba==0.51.2 5 | numpy==1.23.5 6 | opencv-python==4.5.5.62 7 | pillow==9.0.1 8 | polygon3==3.0.9.1 9 | rapidfuzz==2.13.7 10 | scipy==1.5.2 11 | scikit-image==0.15.0 12 | scikit-learn==0.23.2 13 | shapely==2.0.0 14 | timm==0.5.4 15 | tqdm==4.53.0 -------------------------------------------------------------------------------- /adet/modeling/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # from .bbox_utils import Transform2D, filter_invalid, filter_invalid_class_wise, filter_ignore, filter_ignore_class_wise, filter_invalid_soft_label, filter_invalid_with_index 2 | from .dist_utils import concat_all_gather, concat_all_gather_equal_size 3 | from .build_semi import META_ARCH_REGISTRY, build_semi_wrapper 4 | -------------------------------------------------------------------------------- /adet/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .text_spotter import TransformerPureDetector, TransformerPureDetectorV2 3 | from .SemiETS import SemiETSTextSpotter 4 | 5 | 6 | _EXCLUDE = {"torch", "ShapeSpec"} 7 | __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] 8 | -------------------------------------------------------------------------------- /adet/config/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode 2 | 3 | 4 | def get_cfg() -> CfgNode: 5 | """ 6 | Get a copy of the default config. 7 | 8 | Returns: 9 | a detectron2 CfgNode instance.
10 | """ 11 | from .defaults import _C 12 | 13 | return _C.clone() 14 | 15 | def get_cfg_semi() -> CfgNode: 16 | from .semi_defaults import _C 17 | return _C.clone() -------------------------------------------------------------------------------- /configs/R_50/pretrain/150k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | 7 | DATASETS: 8 | TRAIN: ("syntext1","syntext2",) 9 | TEST: ("totaltext_test",) 10 | 11 | SOLVER: 12 | IMS_PER_BATCH: 8 13 | BASE_LR: 1e-4 14 | LR_BACKBONE: 1e-5 15 | WARMUP_ITERS: 0 16 | STEPS: (300000,) 17 | MAX_ITER: 350000 18 | CHECKPOINT_PERIOD: 1000 19 | 20 | TEST: 21 | EVAL_PERIOD: 10000 22 | 23 | OUTPUT_DIR: "output/R50/150k/pretrain" -------------------------------------------------------------------------------- /adet/modeling/utils/build_semi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | 4 | from detectron2.modeling import META_ARCH_REGISTRY 5 | from detectron2.utils.logger import _log_api_usage 6 | # from detectron2.utils.registry import Registry 7 | # 8 | # META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip 9 | # META_ARCH_REGISTRY.__doc__ = """ 10 | # Registry for meta-architectures, i.e. the whole model. 11 | # 12 | # The registered object will be called with `obj(cfg)` 13 | # and expected to return a `nn.Module` object. 14 | # """ 15 | 16 | 17 | def build_semi_wrapper(cfg): 18 | """ 19 | Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. 20 | Note that it does not load any weights from ``cfg``. 21 | """ 22 | meta_arch = cfg.SSL.SEMI_WRAPPER 23 | model = META_ARCH_REGISTRY.get(meta_arch)(cfg) 24 | model.to(torch.device(cfg.MODEL.DEVICE)) 25 | _log_api_usage("modeling.semi_supervised_method." + meta_arch) 26 | return model -------------------------------------------------------------------------------- /tools/convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | def convert_ckpt_semi(source_path, dest_path): 5 | 6 | source_weights = torch.load(source_path, map_location="cpu")['model'] 7 | converted_weights = {} 8 | keys = list(source_weights.keys()) 9 | 10 | for key in keys: 11 | key_s = 'student.' + key 12 | key_t = 'teacher.' + key 13 | converted_weights[key_s] = source_weights[key] 14 | converted_weights[key_t] = source_weights[key] 15 | 16 | torch.save(converted_weights, dest_path) 17 | 18 | def parse_args(): 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser( 22 | description="Convert pre-trained checkpoint to initialize Teacher-Student", 23 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 24 | ) 25 | 26 | parser.add_argument("--source_ckpt", type=str, default="model_final.pth") 27 | parser.add_argument("--dest_ckpt", type=str, default="model_ts_final.pth") 28 | args = parser.parse_args() 29 | 30 | return args 31 | 32 | if __name__ == "__main__": 33 | args = parse_args() 34 | convert_ckpt_semi(args.source_ckpt, args.dest_ckpt) -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 |     const at::Tensor &value, 16 |     const at::Tensor &spatial_shapes, 17 |     const at::Tensor &level_start_index, 18 |     const at::Tensor &sampling_loc, 19 |     const at::Tensor &attn_weight, 20 |     const int im2col_step); 21 | 22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 23 |     const at::Tensor &value, 24 |     const at::Tensor &spatial_shapes, 25 |     const at::Tensor &level_start_index, 26 |     const at::Tensor &sampling_loc, 27 |     const at::Tensor &attn_weight, 28 |     const at::Tensor &grad_output, 29 |     const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 |     const at::Tensor &value, 17 |     const at::Tensor &spatial_shapes, 18 |     const at::Tensor &level_start_index, 19 |     const at::Tensor &sampling_loc, 20 |     const at::Tensor &attn_weight, 21 |     const int im2col_step); 22 | 23 | std::vector<at::Tensor> 24 | ms_deform_attn_cpu_backward( 25 |     const at::Tensor &value, 26 |     const at::Tensor &spatial_shapes, 27 |     const at::Tensor &level_start_index, 28 |     const at::Tensor &sampling_loc, 29 |     const at::Tensor &attn_weight, 30 |     const at::Tensor &grad_output, 31 |     const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | 13 | #include <ATen/ATen.h> 14 | #include <ATen/cuda/CUDAContext.h> 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 |     const at::Tensor &value, 20 |     const at::Tensor &spatial_shapes, 21 |     const at::Tensor &level_start_index, 22 |     const at::Tensor &sampling_loc, 23 |     const at::Tensor &attn_weight, 24 |     const int im2col_step) 25 | { 26 |     AT_ERROR("Not implemented on the CPU"); 27 | } 28 | 29 | std::vector<at::Tensor> 30 | ms_deform_attn_cpu_backward( 31 |     const at::Tensor &value, 32 |     const at::Tensor &spatial_shapes, 33 |     const at::Tensor &level_start_index, 34 |     const at::Tensor &sampling_loc, 35 |     const at::Tensor &attn_weight, 36 |     const at::Tensor &grad_output, 37 |     const int im2col_step) 38 | { 39 |     AT_ERROR("Not implemented on the CPU"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | AdelaiDet for non-commercial purposes 2 | (For commercial use, contact chhshen@gmail.com for obtaining a commercial license.) 3 | 4 | Copyright (c) 2019 the authors 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 |   list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 |   this list of conditions and the following disclaimer in the documentation 15 |   and/or other materials provided with the distribution. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 | -------------------------------------------------------------------------------- /adet/modeling/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class MLP(nn.Module): 8 | """ Very simple multi-layer perceptron (also called FFN)""" 9 | 10 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 11 | super().__init__() 12 | self.num_layers = num_layers 13 | h = [hidden_dim] * (num_layers - 1) 14 | self.layers = nn.ModuleList(nn.Linear(n, k) 15 | for n, k in zip([input_dim] + h, h + [output_dim]) 16 | ) 17 | 18 | def forward(self, x): 19 | for i, layer in enumerate(self.layers): 20 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 21 | return x 22 | 23 | 24 | def gen_point_pos_embed(pts_tensor, d_model, temp): 25 | # pts_tensor: bs, nq, n_pts, 2 26 | scale = 2 * math.pi 27 | dim = d_model // 2 28 | dim_t = torch.arange(dim, dtype=torch.float32, device=pts_tensor.device) 29 | dim_t = temp ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / dim) 30 | x_embed = pts_tensor[:, :, :, 0] * scale 31 | y_embed = pts_tensor[:, :, :, 1] * scale 32 | pos_x = x_embed[:, :, :, None] / dim_t 33 | pos_y = y_embed[:, :, :, None] / dim_t 34 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 35 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 36 | pos = torch.cat((pos_x, pos_y), dim=-1) 37 | return pos -------------------------------------------------------------------------------- /adet/layers/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | #include "DeformAttn/ms_deform_attn.h" 3 | 4 | namespace adet { 5 | 6 | #ifdef WITH_CUDA 7 | extern int get_cudart_version(); 8 | #endif 9 | 10 | std::string get_cuda_version() { 11 | #ifdef WITH_CUDA 12 | std::ostringstream oss; 13 | 14 | // copied from 15 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 16 | auto printCudaStyleVersion = [&](int v) { 17 | oss << (v / 1000) << "." << (v / 10 % 100); 18 | if (v % 10 != 0) { 19 | oss << "." << (v % 10); 20 | } 21 | }; 22 | printCudaStyleVersion(get_cudart_version()); 23 | return oss.str(); 24 | #else 25 | return std::string("not available"); 26 | #endif 27 | } 28 | 29 | // similar to 30 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp 31 | std::string get_compiler_version() { 32 | std::ostringstream ss; 33 | #if defined(__GNUC__) 34 | #ifndef __clang__ 35 | { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } 36 | #endif 37 | #endif 38 | 39 | #if defined(__clang_major__) 40 | { 41 | ss << "clang " << __clang_major__ << "." << __clang_minor__ << "." 
42 | << __clang_patchlevel__; 43 | } 44 | #endif 45 | 46 | #if defined(_MSC_VER) 47 | { ss << "MSVC " << _MSC_FULL_VER; } 48 | #endif 49 | return ss.str(); 50 | } 51 | 52 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 53 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 54 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 55 | } 56 | 57 | } // namespace adet 58 | -------------------------------------------------------------------------------- /configs/R_50/Base_det.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "TransformerPureDetector" 3 | MASK_ON: False 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | BACKBONE: 7 | NAME: "build_resnet_backbone" 8 | RESNETS: 9 | DEPTH: 50 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res3", "res4", "res5"] 12 | TRANSFORMER: 13 | ENABLED: True 14 | NUM_FEATURE_LEVELS: 4 15 | TEMPERATURE: 10000 16 | ENC_LAYERS: 6 17 | DEC_LAYERS: 6 18 | EMB_LAYERS: 3 19 | DIM_FEEDFORWARD: 1024 20 | HIDDEN_DIM: 256 21 | DROPOUT: 0.0 22 | NHEADS: 8 23 | NUM_QUERIES: 100 24 | ENC_N_POINTS: 4 25 | DEC_N_POINTS: 4 26 | NUM_POINTS: 25 27 | INFERENCE_TH_TEST: 0.4 28 | LOSS: 29 | BEZIER_SAMPLE_POINTS: 25 30 | BEZIER_CLASS_WEIGHT: 1.0 31 | BEZIER_COORD_WEIGHT: 1.0 32 | POINT_CLASS_WEIGHT: 1.0 33 | POINT_COORD_WEIGHT: 1.0 34 | POINT_TEXT_WEIGHT: 0.5 35 | BOUNDARY_WEIGHT: 0.5 36 | 37 | 38 | 39 | SOLVER: 40 | WEIGHT_DECAY: 1e-4 41 | OPTIMIZER: "ADAMW" 42 | LR_BACKBONE_NAMES: ['backbone.0'] 43 | LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets'] 44 | LR_LINEAR_PROJ_MULT: 1. 45 | CLIP_GRADIENTS: 46 | ENABLED: True 47 | CLIP_TYPE: "full_model" 48 | CLIP_VALUE: 0.1 49 | NORM_TYPE: 2.0 50 | 51 | INPUT: 52 | HFLIP_TRAIN: False 53 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896) 54 | MAX_SIZE_TRAIN: 1600 55 | MIN_SIZE_TEST: 1000 56 | MAX_SIZE_TEST: 1892 57 | CROP: 58 | ENABLED: True 59 | CROP_INSTANCE: False 60 | SIZE: [0.1, 0.1] 61 | FORMAT: "RGB" 62 | 63 | DATALOADER: 64 | NUM_WORKERS: 8 65 | 66 | VERSION: 2 67 | SEED: 42 -------------------------------------------------------------------------------- /configs/R_50/TotalText/SemiETS/SemiETS_1s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | LOSS: 8 | USE_DYNAMIC_K: False 9 | O2M_MATCH_NUM: 13 10 | LEVEN_ALPHA : 20 11 | ADP_POINT_COORD_WEIGHT: 1.0 12 | ADP_POINT_TEXT_WEIGHT: 0.5 13 | ADP_BOUNDARY_WEIGHT: 0.5 14 | DET_ADAPTIVE_TYPE: 'edit_distance' 15 | 16 | DATASETS: 17 | TRAIN: ("totaltext_train_1_label", "totaltext_train_1_unlabel") 18 | TEST: ("totaltext_test",) 19 | 20 | SOLVER: 21 | IMS_PER_BATCH: 12 #debug 22 | SOURCE_RATIO: (1,2) 23 | BASE_LR: 1e-5 24 | LR_BACKBONE: 1e-6 25 | WARMUP_ITERS: 0 26 | STEPS: (100000,) # no step 27 | MAX_ITER: 20000 28 | CHECKPOINT_PERIOD: 1000 29 | FIND_UNUSED_PARAMETERS: True 30 | 31 | SSL: 32 | MODE: "mean-teacher" 33 | SEMI_WRAPPER: "SemiETSTextSpotter" 34 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 35 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 36 | USE_SPOTTING_NMS: True 37 | USE_SEPERATE_MATCHER: False 38 | USE_COMBINED_THR: False 39 | O2M_TEXT_O2O: False 40 | USE_O2M_ENC : False 41 | MIN_PSEDO_BOX_SIZE: 0 42 | UNSUP_WEIGHT: 
2.0 43 | CONSISTENCY_WEIGHT: 1.0 44 | AUG_QUERY: False 45 | INFERENCE_ON: "teacher" 46 | STEP_HOOK: True 47 | WARM_UP: 1000 #debug 48 | STAGE_WARM_UP: 10000 49 | EXTRA_STUDENT_INFO: True 50 | USE_CONSISTENCY: False 51 | EMA: 52 | WARM_UP: 0 53 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 54 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 55 | #ctc hard mining + o2o rcs weighting 56 | 57 | TEST: 58 | EVAL_PERIOD: 1000 59 | 60 | OUTPUT_DIR: "output/R50/150k_tt/Finetune/semi_1s/SemiETS_1s" -------------------------------------------------------------------------------- /configs/R_50/TotalText/SemiETS/SemiETS_2s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | LOSS: 8 | USE_DYNAMIC_K: False 9 | O2M_MATCH_NUM: 5 10 | LEVEN_ALPHA : 20 11 | ADP_POINT_COORD_WEIGHT: 1.0 12 | ADP_POINT_TEXT_WEIGHT: 0.5 13 | ADP_BOUNDARY_WEIGHT: 0.5 14 | DET_ADAPTIVE_TYPE: 'edit_distance' 15 | 16 | DATASETS: 17 | TRAIN: ("totaltext_train_2_label", "totaltext_train_2_unlabel") 18 | TEST: ("totaltext_test",) 19 | 20 | SOLVER: 21 | IMS_PER_BATCH: 12 #debug 22 | SOURCE_RATIO: (1,2) 23 | BASE_LR: 1e-5 24 | LR_BACKBONE: 1e-6 25 | WARMUP_ITERS: 0 26 | STEPS: (100000,) # no step 27 | MAX_ITER: 20000 28 | CHECKPOINT_PERIOD: 1000 29 | FIND_UNUSED_PARAMETERS: True 30 | 31 | SSL: 32 | MODE: "mean-teacher" 33 | SEMI_WRAPPER: "SemiETSTextSpotter" 34 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 35 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 36 | USE_SPOTTING_NMS: True 37 | USE_SEPERATE_MATCHER: False 38 | USE_COMBINED_THR: False 39 | O2M_TEXT_O2O: False 40 | USE_O2M_ENC : False 41 | MIN_PSEDO_BOX_SIZE: 0 42 | UNSUP_WEIGHT: 2.0 43 | CONSISTENCY_WEIGHT: 1.0 44 | AUG_QUERY: False 45 | INFERENCE_ON: "teacher" 46 | STEP_HOOK: True 47 | WARM_UP: 1000 #debug 48 | STAGE_WARM_UP: 10000 49 | EXTRA_STUDENT_INFO: True 50 | USE_CONSISTENCY: False 51 | EMA: 52 | WARM_UP: 0 53 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 54 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 55 | #ctc hard mining + o2o rcs weighting 56 | 57 | TEST: 58 | EVAL_PERIOD: 1000 59 | 60 | OUTPUT_DIR: "output/R50/150k_tt/Finetune/semi_2s/SemiETS_full_2s" -------------------------------------------------------------------------------- /configs/R_50/TotalText/SemiETS/SemiETS_5s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | LOSS: 8 | USE_DYNAMIC_K: False 9 | O2M_MATCH_NUM: 5 10 | LEVEN_ALPHA : 20 11 | COST_ALPHA : 20 #rec 12 | ADP_POINT_COORD_WEIGHT: 1.0 13 | ADP_POINT_TEXT_WEIGHT: 0.5 14 | ADP_BOUNDARY_WEIGHT: 0.5 15 | DET_ADAPTIVE_TYPE: 'edit_distance' 16 | 17 | DATASETS: 18 | TRAIN: ("totaltext_train_5_label", "totaltext_train_5_unlabel") 19 | TEST: ("totaltext_test",) 20 | 21 | SOLVER: 22 | IMS_PER_BATCH: 12 #debug 23 | SOURCE_RATIO: (1,2) 24 | BASE_LR: 1e-5 25 | LR_BACKBONE: 1e-6 26 | WARMUP_ITERS: 0 27 | STEPS: (100000,) # no step 28 | MAX_ITER: 20000 29 | CHECKPOINT_PERIOD: 1000 30 | FIND_UNUSED_PARAMETERS: True 31 | 
32 | SSL: 33 | MODE: "mean-teacher" 34 | SEMI_WRAPPER: "SemiETSTextSpotter" 35 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 36 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 37 | USE_SPOTTING_NMS: True 38 | USE_SEPERATE_MATCHER: False 39 | USE_COMBINED_THR: False 40 | O2M_TEXT_O2O: False 41 | USE_O2M_ENC : False 42 | MIN_PSEDO_BOX_SIZE: 0 43 | UNSUP_WEIGHT: 2.0 44 | CONSISTENCY_WEIGHT: 1.0 45 | AUG_QUERY: False 46 | INFERENCE_ON: "teacher" 47 | STEP_HOOK: True 48 | WARM_UP: 1000 #debug 49 | STAGE_WARM_UP: 10000 50 | EXTRA_STUDENT_INFO: True 51 | USE_CONSISTENCY: False 52 | EMA: 53 | WARM_UP: 0 54 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 55 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 56 | #ctc hard mining + o2o rcs weighting 57 | 58 | TEST: 59 | EVAL_PERIOD: 1000 60 | 61 | OUTPUT_DIR: "output/R50/150k_tt/Finetune/semi_5s/SemiETS_5s" -------------------------------------------------------------------------------- /configs/R_50/TotalText/SemiETS/SemiETS_10s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | LOSS: 8 | USE_DYNAMIC_K: False 9 | O2M_MATCH_NUM: 5 10 | LEVEN_ALPHA : 20 11 | COST_ALPHA : 20 #rec 12 | ADP_POINT_COORD_WEIGHT: 1.0 13 | ADP_POINT_TEXT_WEIGHT: 0.5 14 | ADP_BOUNDARY_WEIGHT: 0.5 15 | DET_ADAPTIVE_TYPE: 'edit_distance' 16 | 17 | DATASETS: 18 | TRAIN: ("totaltext_train_10_label", "totaltext_train_10_unlabel") 19 | TEST: ("totaltext_test",) 20 | 21 | SOLVER: 22 | IMS_PER_BATCH: 12 #debug 23 | SOURCE_RATIO: (1,2) 24 | BASE_LR: 1e-5 25 | LR_BACKBONE: 1e-6 26 | WARMUP_ITERS: 0 27 | STEPS: (100000,) # no step 28 | MAX_ITER: 20000 29 | CHECKPOINT_PERIOD: 1000 30 | FIND_UNUSED_PARAMETERS: True 31 | 32 | SSL: 33 | MODE: "mean-teacher" 34 | SEMI_WRAPPER: "SemiETSTextSpotter" 35 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 36 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 37 | USE_SPOTTING_NMS: True 38 | USE_SEPERATE_MATCHER: False 39 | USE_COMBINED_THR: False 40 | O2M_TEXT_O2O: False 41 | USE_O2M_ENC : False 42 | MIN_PSEDO_BOX_SIZE: 0 43 | UNSUP_WEIGHT: 2.0 44 | CONSISTENCY_WEIGHT: 1.0 45 | AUG_QUERY: False 46 | INFERENCE_ON: "teacher" 47 | STEP_HOOK: True 48 | WARM_UP: 1000 #debug 49 | STAGE_WARM_UP: 10000 50 | EXTRA_STUDENT_INFO: True 51 | USE_CONSISTENCY: False 52 | EMA: 53 | WARM_UP: 0 54 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 55 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 56 | #ctc hard mining + o2o rcs weighting 57 | 58 | TEST: 59 | EVAL_PERIOD: 1000 60 | 61 | OUTPUT_DIR: "output/R50/150k_tt/Finetune/semi_10s/SemiETS_10s" -------------------------------------------------------------------------------- /configs/R_50/TotalText/SemiETS/SemiETS_0.5s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | LOSS: 8 | USE_DYNAMIC_K: False 9 | O2M_MATCH_NUM: 13 10 | LEVEN_ALPHA : 20 11 | ADP_POINT_COORD_WEIGHT: 1.0 12 | ADP_POINT_TEXT_WEIGHT: 0.5 13 | ADP_BOUNDARY_WEIGHT: 0.5 14 | DET_ADAPTIVE_TYPE: 'edit_distance' 15 | PRECISE_TEACHER: True 16 | 
17 | DATASETS: 18 |   TRAIN: ("totaltext_train_0.5_label", "totaltext_train_0.5_unlabel") 19 |   TEST: ("totaltext_test",) 20 | 21 | SOLVER: 22 |   IMS_PER_BATCH: 12 #debug 23 |   SOURCE_RATIO: (1,2) 24 |   BASE_LR: 1e-5 25 |   LR_BACKBONE: 1e-6 26 |   WARMUP_ITERS: 0 27 |   STEPS: (100000,) # no step 28 |   MAX_ITER: 20000 29 |   CHECKPOINT_PERIOD: 1000 30 |   FIND_UNUSED_PARAMETERS: True 31 | 32 | SSL: 33 |   MODE: "mean-teacher" 34 |   SEMI_WRAPPER: "SemiETSTextSpotter" 35 |   PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 36 |   PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 37 |   USE_SPOTTING_NMS: True 38 |   USE_SEPERATE_MATCHER: False 39 |   USE_COMBINED_THR: False 40 |   O2M_TEXT_O2O: False 41 |   USE_O2M_ENC : False 42 |   MIN_PSEDO_BOX_SIZE: 0 43 |   UNSUP_WEIGHT: 2.0 44 |   CONSISTENCY_WEIGHT: 1.0 45 |   AUG_QUERY: False 46 |   INFERENCE_ON: "teacher" 47 |   STEP_HOOK: True 48 |   WARM_UP: 1000 #debug 49 |   STAGE_WARM_UP: 10000 50 |   EXTRA_STUDENT_INFO: True 51 |   USE_CONSISTENCY: False 52 |   EMA: 53 |     WARM_UP: 0 54 |   DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 55 |   O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 56 |   #ctc hard mining + o2o rcs weighting 57 | 58 | TEST: 59 |   EVAL_PERIOD: 1000 60 | 61 | OUTPUT_DIR: "output/R50/150k_tt/Finetune/semi_0.5s/SemiETS_0.5s" -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 |     const at::Tensor &value, 23 |     const at::Tensor &spatial_shapes, 24 |     const at::Tensor &level_start_index, 25 |     const at::Tensor &sampling_loc, 26 |     const at::Tensor &attn_weight, 27 |     const int im2col_step) 28 | { 29 |     if (value.type().is_cuda()) 30 |     { 31 | #ifdef WITH_CUDA 32 |         return ms_deform_attn_cuda_forward( 33 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 |         AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 |     } 38 |     AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector<at::Tensor> 42 | ms_deform_attn_backward( 43 |     const at::Tensor &value, 44 |     const at::Tensor &spatial_shapes, 45 |     const at::Tensor &level_start_index, 46 |     const at::Tensor &sampling_loc, 47 |     const at::Tensor &attn_weight, 48 |     const at::Tensor &grad_output, 49 |     const int im2col_step) 50 | { 51 |     if (value.type().is_cuda()) 52 |     { 53 | #ifdef WITH_CUDA 54 |         return ms_deform_attn_cuda_backward( 55 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 |         AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 |     } 60 |     AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 |
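For orientation, a minimal sketch (not repository code) of driving the compiled extension directly from Python. The module name adet._C and the two binding names come from setup.py and vision.cpp elsewhere in this tree; the tensor shapes follow the Deformable DETR convention and are assumptions here:

import torch
from adet import _C  # built as "adet._C" by setup.py

N, n_heads, d_head = 2, 8, 32                      # batch, heads, channels per head
shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long, device="cuda")
level_start_index = torch.cat((shapes.new_zeros(1), shapes.prod(1).cumsum(0)[:-1]))
S = int(shapes.prod(1).sum())                      # keys summed over all feature levels
Lq, n_levels, n_points = 100, 2, 4                 # queries, levels, sampled points

value = torch.rand(N, S, n_heads, d_head, device="cuda")
sampling_loc = torch.rand(N, Lq, n_heads, n_levels, n_points, 2, device="cuda")
attn_weight = torch.rand(N, Lq, n_heads, n_levels, n_points, device="cuda")
attn_weight = attn_weight / attn_weight.sum((-2, -1), keepdim=True)  # weights sum to 1 per query and head

out = _C.ms_deform_attn_forward(
    value, shapes, level_start_index, sampling_loc, attn_weight, 64)  # im2col_step=64
# out: (N, Lq, n_heads * d_head); CPU tensors hit the AT_ERROR branch in the dispatch above.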
-------------------------------------------------------------------------------- /configs/R_50/CTW1500/SemiETS/SemiETS_1s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 4 | WEIGHTS: "output/R50/150k/pretrain/model_ts_96_final.pth" 5 | TRANSFORMER: 6 | VOC_SIZE: 96 7 | NUM_POINTS: 50 8 | LOSS: 9 | BEZIER_SAMPLE_POINTS: 50 10 | BEZIER_CLASS_WEIGHT: 1.0 11 | BEZIER_COORD_WEIGHT: 0.5 12 | POINT_CLASS_WEIGHT: 1.0 13 | POINT_COORD_WEIGHT: 0.5 14 | POINT_TEXT_WEIGHT: 1.0 #0.5 15 | BOUNDARY_WEIGHT: 0.25 16 | ######################### 17 | USE_DYNAMIC_K: False 18 | O2M_MATCH_NUM: 5 19 | LEVEN_ALPHA : 20 #det 20 | DET_ADAPTIVE_TYPE: 'edit_distance' 21 | 22 | DATASETS: 23 | TRAIN: ("ctw1500_train_1_label", "ctw1500_train_1_unlabel") 24 | TEST: ("ctw1500_test",) 25 | 26 | 27 | INPUT: 28 | ROTATE: False 29 | MIN_SIZE_TEST: 1000 30 | MAX_SIZE_TEST: 1200 31 | 32 | 33 | SOLVER: 34 | IMS_PER_BATCH: 12 35 | SOURCE_RATIO: (1,2) 36 | BASE_LR: 5e-5 37 | LR_BACKBONE: 5e-6 38 | WARMUP_ITERS: 0 39 | STEPS: (16000,) #8000 40 | MAX_ITER: 24000 #12000 41 | CHECKPOINT_PERIOD: 1000 42 | FIND_UNUSED_PARAMETERS: True 43 | 44 | 45 | SSL: 46 | MODE: "mean-teacher" 47 | SEMI_WRAPPER: "SemiETSTextSpotter" 48 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 49 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 50 | USE_SPOTTING_NMS: True 51 | USE_SEPERATE_MATCHER: False 52 | USE_COMBINED_THR: False 53 | O2M_TEXT_O2O: False 54 | USE_O2M_ENC : False 55 | MIN_PSEDO_BOX_SIZE: 0 56 | UNSUP_WEIGHT: 2.0 57 | CONSISTENCY_WEIGHT: 1.0 58 | AUG_QUERY: False 59 | INFERENCE_ON: "teacher" 60 | STEP_HOOK: True 61 | WARM_UP: 1000 #debug 62 | STAGE_WARM_UP: 4000 63 | EXTRA_STUDENT_INFO: True 64 | USE_CONSISTENCY: False 65 | EMA: 66 | WARM_UP: 0 67 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 68 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 69 | #ctc hard mining + o2o rcs weighting 70 | 71 | TEST: 72 | EVAL_PERIOD: 1000 73 | 74 | 75 | OUTPUT_DIR: "output/R50/ctw1500/Finetune/semi_1s/SemiETS_1s" -------------------------------------------------------------------------------- /configs/R_50/CTW1500/SemiETS/SemiETS_2s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 4 | WEIGHTS: "output/R50/150k/pretrain/model_ts_96_final.pth" 5 | TRANSFORMER: 6 | VOC_SIZE: 96 7 | NUM_POINTS: 50 8 | LOSS: 9 | BEZIER_SAMPLE_POINTS: 50 10 | BEZIER_CLASS_WEIGHT: 1.0 11 | BEZIER_COORD_WEIGHT: 0.5 12 | POINT_CLASS_WEIGHT: 1.0 13 | POINT_COORD_WEIGHT: 0.5 14 | POINT_TEXT_WEIGHT: 1.0 #0.5 15 | BOUNDARY_WEIGHT: 0.25 16 | ######################### 17 | USE_DYNAMIC_K: False 18 | O2M_MATCH_NUM: 5 19 | LEVEN_ALPHA : 20 #det 20 | DET_ADAPTIVE_TYPE: 'edit_distance' 21 | 22 | DATASETS: 23 | TRAIN: ("ctw1500_train_2_label", "ctw1500_train_2_unlabel") 24 | TEST: ("ctw1500_test",) 25 | 26 | 27 | INPUT: 28 | ROTATE: False 29 | MIN_SIZE_TEST: 1000 30 | MAX_SIZE_TEST: 1200 31 | 32 | 33 | SOLVER: 34 | IMS_PER_BATCH: 12 35 | SOURCE_RATIO: (1,2) 36 | BASE_LR: 5e-5 37 | LR_BACKBONE: 5e-6 38 | WARMUP_ITERS: 0 39 | STEPS: (16000,) #8000 40 | MAX_ITER: 24000 #12000 41 | CHECKPOINT_PERIOD: 1000 42 | FIND_UNUSED_PARAMETERS: True 43 | 44 | 45 | SSL: 46 | MODE: "mean-teacher" 47 | SEMI_WRAPPER: 
"SemiETSTextSpotter" 48 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 49 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 50 | USE_SPOTTING_NMS: True 51 | USE_SEPERATE_MATCHER: False 52 | USE_COMBINED_THR: False 53 | O2M_TEXT_O2O: False 54 | USE_O2M_ENC : False 55 | MIN_PSEDO_BOX_SIZE: 0 56 | UNSUP_WEIGHT: 2.0 57 | CONSISTENCY_WEIGHT: 1.0 58 | AUG_QUERY: False 59 | INFERENCE_ON: "teacher" 60 | STEP_HOOK: True 61 | WARM_UP: 1000 #debug 62 | STAGE_WARM_UP: 4000 63 | EXTRA_STUDENT_INFO: True 64 | USE_CONSISTENCY: False 65 | EMA: 66 | WARM_UP: 0 67 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 68 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 69 | #ctc hard mining + o2o rcs weighting 70 | 71 | TEST: 72 | EVAL_PERIOD: 1000 73 | 74 | 75 | OUTPUT_DIR: "output/R50/ctw1500/Finetune/semi_2s/SemiETS_2s" -------------------------------------------------------------------------------- /configs/R_50/CTW1500/SemiETS/SemiETS_5s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 4 | WEIGHTS: "output/R50/150k/pretrain/model_ts_96_final.pth" 5 | TRANSFORMER: 6 | VOC_SIZE: 96 7 | NUM_POINTS: 50 8 | LOSS: 9 | BEZIER_SAMPLE_POINTS: 50 10 | BEZIER_CLASS_WEIGHT: 1.0 11 | BEZIER_COORD_WEIGHT: 0.5 12 | POINT_CLASS_WEIGHT: 1.0 13 | POINT_COORD_WEIGHT: 0.5 14 | POINT_TEXT_WEIGHT: 1.0 #0.5 15 | BOUNDARY_WEIGHT: 0.25 16 | ######################### 17 | USE_DYNAMIC_K: False 18 | O2M_MATCH_NUM: 5 19 | LEVEN_ALPHA : 20 #det 20 | DET_ADAPTIVE_TYPE: 'edit_distance' 21 | 22 | DATASETS: 23 | TRAIN: ("ctw1500_train_5_label", "ctw1500_train_5_unlabel") 24 | TEST: ("ctw1500_test",) 25 | 26 | 27 | INPUT: 28 | ROTATE: False 29 | MIN_SIZE_TEST: 1000 30 | MAX_SIZE_TEST: 1200 31 | 32 | 33 | SOLVER: 34 | IMS_PER_BATCH: 12 35 | SOURCE_RATIO: (1,2) 36 | BASE_LR: 5e-5 37 | LR_BACKBONE: 5e-6 38 | WARMUP_ITERS: 0 39 | STEPS: (16000,) #8000 40 | MAX_ITER: 24000 #12000 41 | CHECKPOINT_PERIOD: 1000 42 | FIND_UNUSED_PARAMETERS: True 43 | 44 | 45 | SSL: 46 | MODE: "mean-teacher" 47 | SEMI_WRAPPER: "SemiETSTextSpotter" 48 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 49 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 50 | USE_SPOTTING_NMS: True 51 | USE_SEPERATE_MATCHER: False 52 | USE_COMBINED_THR: False 53 | O2M_TEXT_O2O: False 54 | USE_O2M_ENC : False 55 | MIN_PSEDO_BOX_SIZE: 0 56 | UNSUP_WEIGHT: 2.0 57 | CONSISTENCY_WEIGHT: 1.0 58 | AUG_QUERY: False 59 | INFERENCE_ON: "teacher" 60 | STEP_HOOK: True 61 | WARM_UP: 1000 #debug 62 | STAGE_WARM_UP: 4000 63 | EXTRA_STUDENT_INFO: True 64 | USE_CONSISTENCY: False 65 | EMA: 66 | WARM_UP: 0 67 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 68 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 69 | #ctc hard mining + o2o rcs weighting 70 | 71 | TEST: 72 | EVAL_PERIOD: 1000 73 | 74 | 75 | OUTPUT_DIR: "output/R50/ctw1500/Finetune/semi_5s/SemiETS_5s" -------------------------------------------------------------------------------- /configs/R_50/CTW1500/SemiETS/SemiETS_10s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 4 | WEIGHTS: "output/R50/150k/pretrain/model_ts_96_final.pth" 5 | TRANSFORMER: 6 | VOC_SIZE: 96 7 | NUM_POINTS: 50 8 | LOSS: 9 | 
BEZIER_SAMPLE_POINTS: 50 10 | BEZIER_CLASS_WEIGHT: 1.0 11 | BEZIER_COORD_WEIGHT: 0.5 12 | POINT_CLASS_WEIGHT: 1.0 13 | POINT_COORD_WEIGHT: 0.5 14 | POINT_TEXT_WEIGHT: 1.0 #0.5 15 | BOUNDARY_WEIGHT: 0.25 16 | ######################### 17 | USE_DYNAMIC_K: False 18 | O2M_MATCH_NUM: 5 19 | LEVEN_ALPHA : 20 #det 20 | DET_ADAPTIVE_TYPE: 'edit_distance' 21 | 22 | DATASETS: 23 | TRAIN: ("ctw1500_train_10_label", "ctw1500_train_10_unlabel") 24 | TEST: ("ctw1500_test",) 25 | 26 | 27 | INPUT: 28 | ROTATE: False 29 | MIN_SIZE_TEST: 1000 30 | MAX_SIZE_TEST: 1200 31 | 32 | 33 | SOLVER: 34 | IMS_PER_BATCH: 12 35 | SOURCE_RATIO: (1,2) 36 | BASE_LR: 5e-5 37 | LR_BACKBONE: 5e-6 38 | WARMUP_ITERS: 0 39 | STEPS: (16000,) #8000 40 | MAX_ITER: 24000 #12000 41 | CHECKPOINT_PERIOD: 1000 42 | FIND_UNUSED_PARAMETERS: True 43 | 44 | 45 | SSL: 46 | MODE: "mean-teacher" 47 | SEMI_WRAPPER: "SemiETSTextSpotter" 48 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 49 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 50 | USE_SPOTTING_NMS: True 51 | USE_SEPERATE_MATCHER: False 52 | USE_COMBINED_THR: False 53 | O2M_TEXT_O2O: False 54 | USE_O2M_ENC : False 55 | MIN_PSEDO_BOX_SIZE: 0 56 | UNSUP_WEIGHT: 2.0 57 | CONSISTENCY_WEIGHT: 1.0 58 | AUG_QUERY: False 59 | INFERENCE_ON: "teacher" 60 | STEP_HOOK: True 61 | WARM_UP: 1000 #debug 62 | STAGE_WARM_UP: 4000 63 | EXTRA_STUDENT_INFO: True 64 | USE_CONSISTENCY: False 65 | EMA: 66 | WARM_UP: 0 67 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 68 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 69 | #ctc hard mining + o2o rcs weighting 70 | 71 | TEST: 72 | EVAL_PERIOD: 1000 73 | 74 | 75 | OUTPUT_DIR: "output/R50/ctw1500/Finetune/semi_10s/SemiETS_10s" -------------------------------------------------------------------------------- /configs/R_50/CTW1500/SemiETS/SemiETS_0.5s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 4 | WEIGHTS: "output/R50/150k/pretrain/model_ts_96_final.pth" 5 | TRANSFORMER: 6 | VOC_SIZE: 96 7 | NUM_POINTS: 50 8 | LOSS: 9 | BEZIER_SAMPLE_POINTS: 50 10 | BEZIER_CLASS_WEIGHT: 1.0 11 | BEZIER_COORD_WEIGHT: 0.5 12 | POINT_CLASS_WEIGHT: 1.0 13 | POINT_COORD_WEIGHT: 0.5 14 | POINT_TEXT_WEIGHT: 1.0 #0.5 15 | BOUNDARY_WEIGHT: 0.25 16 | ######################### 17 | USE_DYNAMIC_K: False 18 | O2M_MATCH_NUM: 5 19 | LEVEN_ALPHA : 20 #det 20 | DET_ADAPTIVE_TYPE: 'edit_distance' 21 | 22 | DATASETS: 23 | TRAIN: ("ctw1500_train_0.5_label", "ctw1500_train_0.5_unlabel") 24 | TEST: ("ctw1500_test",) 25 | 26 | 27 | INPUT: 28 | ROTATE: False 29 | MIN_SIZE_TEST: 1000 30 | MAX_SIZE_TEST: 1200 31 | 32 | 33 | SOLVER: 34 | IMS_PER_BATCH: 12 35 | SOURCE_RATIO: (1,2) 36 | BASE_LR: 5e-5 37 | LR_BACKBONE: 5e-6 38 | WARMUP_ITERS: 0 39 | STEPS: (16000,) #8000 40 | MAX_ITER: 24000 #12000 41 | CHECKPOINT_PERIOD: 1000 42 | FIND_UNUSED_PARAMETERS: True 43 | 44 | 45 | SSL: 46 | MODE: "mean-teacher" 47 | SEMI_WRAPPER: "SemiETSTextSpotter" 48 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.4 49 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 50 | USE_SPOTTING_NMS: True 51 | USE_SEPERATE_MATCHER: False 52 | USE_COMBINED_THR: False 53 | O2M_TEXT_O2O: False 54 | USE_O2M_ENC : False 55 | MIN_PSEDO_BOX_SIZE: 0 56 | UNSUP_WEIGHT: 2.0 57 | CONSISTENCY_WEIGHT: 1.0 58 | AUG_QUERY: False 59 | INFERENCE_ON: "teacher" 60 | STEP_HOOK: True 61 | WARM_UP: 1000 #debug 62 | STAGE_WARM_UP: 4000 63 | 
EXTRA_STUDENT_INFO: True 64 | USE_CONSISTENCY: False 65 | EMA: 66 | WARM_UP: 0 67 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 68 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 69 | #ctc hard mining + o2o rcs weighting 70 | 71 | TEST: 72 | EVAL_PERIOD: 1000 73 | 74 | 75 | OUTPUT_DIR: "output/R50/ctw1500/Finetune/semi_0.5s/SemiETS_0.5s" -------------------------------------------------------------------------------- /configs/R_50/IC15/SemiETS/SemiETS_2s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | INFERENCE_TH_TEST: 0.3 8 | LOSS: 9 | USE_DYNAMIC_K: False 10 | O2M_MATCH_NUM: 5 11 | LEVEN_ALPHA : 20 12 | ADP_POINT_COORD_WEIGHT: 1.0 13 | ADP_POINT_TEXT_WEIGHT: 0.5 14 | ADP_BOUNDARY_WEIGHT: 0.5 15 | DET_ADAPTIVE_TYPE: 'edit_distance' 16 | 17 | DATASETS: 18 | TRAIN: ("ic15_train_2_label", "ic15_train_2_unlabel") 19 | TEST: ("ic15_test",) 20 | 21 | INPUT: 22 | MIN_SIZE_TRAIN: (800,900,1000,1100,1200,1300,1400) 23 | MAX_SIZE_TRAIN: 3000 24 | MIN_SIZE_TEST: 1440 25 | MAX_SIZE_TEST: 4000 26 | CROP: 27 | ENABLED: False 28 | ROTATE: False 29 | 30 | SOLVER: 31 | IMS_PER_BATCH: 8 #debug 32 | SOURCE_RATIO: (1,1) 33 | BASE_LR: 1e-5 34 | LR_BACKBONE: 1e-6 35 | WARMUP_ITERS: 0 36 | STEPS: (100000,) # no step 37 | MAX_ITER: 8000 38 | CHECKPOINT_PERIOD: 1000 39 | FIND_UNUSED_PARAMETERS: True 40 | AMP: 41 | ENABLED: True 42 | 43 | SSL: 44 | MODE: "mean-teacher" 45 | SEMI_WRAPPER: "SemiETSTextSpotter" 46 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.3 47 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 48 | USE_SPOTTING_NMS: True 49 | USE_SEPERATE_MATCHER: False 50 | USE_COMBINED_THR: False 51 | O2M_TEXT_O2O: False 52 | USE_O2M_ENC : False 53 | MIN_PSEDO_BOX_SIZE: 0 54 | UNSUP_WEIGHT: 2.0 55 | CONSISTENCY_WEIGHT: 1.0 56 | AUG_QUERY: False 57 | INFERENCE_ON: "teacher" 58 | STEP_HOOK: True 59 | WARM_UP: 1000 60 | STAGE_WARM_UP: 2500 61 | EXTRA_STUDENT_INFO: True 62 | USE_CONSISTENCY: False 63 | EMA: 64 | WARM_UP: 0 65 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 66 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 67 | #ctc hard mining + o2o rcs weighting 68 | 69 | TEST: 70 | EVAL_PERIOD: 1000 71 | # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015) 4 - training setting 72 | LEXICON_TYPE: 4 73 | 74 | OUTPUT_DIR: "output/R50/150k_tt/IC15/Finetune/semi_2s/SemiETS_2s" -------------------------------------------------------------------------------- /configs/R_50/IC15/SemiETS/SemiETS_1s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | INFERENCE_TH_TEST: 0.3 8 | LOSS: 9 | USE_DYNAMIC_K: False 10 | O2M_MATCH_NUM: 5 11 | LEVEN_ALPHA : 20 12 | ADP_POINT_COORD_WEIGHT: 1.0 13 | ADP_POINT_TEXT_WEIGHT: 0.5 14 | ADP_BOUNDARY_WEIGHT: 0.5 15 | DET_ADAPTIVE_TYPE: 'edit_distance' 16 | 17 | DATASETS: 18 | TRAIN: ("ic15_train_1_label", "ic15_train_1_unlabel") 19 | TEST: ("ic15_test",) 20 | 21 | INPUT: 22 | MIN_SIZE_TRAIN: (800,900,1000,1100,1200,1300,1400) 23 | MAX_SIZE_TRAIN: 3000 24 
| MIN_SIZE_TEST: 1440 25 | MAX_SIZE_TEST: 4000 26 | CROP: 27 | ENABLED: False 28 | ROTATE: False 29 | 30 | SOLVER: 31 | IMS_PER_BATCH: 8 #debug 32 | SOURCE_RATIO: (1,1) 33 | BASE_LR: 1e-5 34 | LR_BACKBONE: 1e-6 35 | WARMUP_ITERS: 0 36 | STEPS: (100000,) # no step 37 | MAX_ITER: 8000 38 | CHECKPOINT_PERIOD: 1000 39 | FIND_UNUSED_PARAMETERS: True 40 | AMP: 41 | ENABLED: True 42 | 43 | SSL: 44 | MODE: "mean-teacher" 45 | SEMI_WRAPPER: "SemiETSTextSpotter" 46 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.3 47 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 48 | USE_SPOTTING_NMS: True 49 | USE_SEPERATE_MATCHER: False 50 | USE_COMBINED_THR: False 51 | O2M_TEXT_O2O: False 52 | USE_O2M_ENC : False 53 | MIN_PSEDO_BOX_SIZE: 0 54 | UNSUP_WEIGHT: 2.0 55 | CONSISTENCY_WEIGHT: 1.0 56 | AUG_QUERY: False 57 | INFERENCE_ON: "teacher" 58 | STEP_HOOK: True 59 | WARM_UP: 1000 #debug 60 | STAGE_WARM_UP: 2500 61 | EXTRA_STUDENT_INFO: True 62 | USE_CONSISTENCY: False 63 | EMA: 64 | WARM_UP: 0 65 | DECODER_LOSS: [ "labels", "texts_psa", "ctrl_points", "bd_points" ] 66 | O2O_DECODER_LOSS: [ "labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc" ] 67 | #ctc hard mining + o2o rcs weighting 68 | 69 | TEST: 70 | EVAL_PERIOD: 1000 71 | # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015) 4 - training setting 72 | LEXICON_TYPE: 4 73 | 74 | OUTPUT_DIR: "output/R50/150k_tt/IC15/Finetune/semi_1s/SemiETS_1s" -------------------------------------------------------------------------------- /configs/R_50/IC15/SemiETS/SemiETS_0.5s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | INFERENCE_TH_TEST: 0.3 8 | LOSS: 9 | USE_DYNAMIC_K: False 10 | O2M_MATCH_NUM: 5 11 | LEVEN_ALPHA : 20 12 | ADP_POINT_COORD_WEIGHT: 1.0 13 | ADP_POINT_TEXT_WEIGHT: 0.5 14 | ADP_BOUNDARY_WEIGHT: 0.5 15 | DET_ADAPTIVE_TYPE: 'edit_distance' 16 | 17 | DATASETS: 18 | TRAIN: ("ic15_train_0.5_label", "ic15_train_0.5_unlabel") 19 | TEST: ("ic15_test",) 20 | 21 | INPUT: 22 | MIN_SIZE_TRAIN: (800,900,1000,1100,1200,1300,1400) 23 | MAX_SIZE_TRAIN: 3000 24 | MIN_SIZE_TEST: 1440 25 | MAX_SIZE_TEST: 4000 26 | CROP: 27 | ENABLED: False 28 | ROTATE: False 29 | 30 | SOLVER: 31 | IMS_PER_BATCH: 8 #debug 32 | SOURCE_RATIO: (1,1) 33 | BASE_LR: 1e-5 34 | LR_BACKBONE: 1e-6 35 | WARMUP_ITERS: 0 36 | STEPS: (100000,) # no step 37 | MAX_ITER: 8000 38 | CHECKPOINT_PERIOD: 1000 39 | FIND_UNUSED_PARAMETERS: True 40 | AMP: 41 | ENABLED: True 42 | 43 | SSL: 44 | MODE: "mean-teacher" 45 | SEMI_WRAPPER: "SemiETSTextSpotter" 46 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.3 47 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 48 | USE_SPOTTING_NMS: True 49 | USE_SEPERATE_MATCHER: False 50 | USE_COMBINED_THR: False 51 | O2M_TEXT_O2O: False 52 | USE_O2M_ENC : False 53 | MIN_PSEDO_BOX_SIZE: 0 54 | UNSUP_WEIGHT: 2.0 55 | CONSISTENCY_WEIGHT: 1.0 56 | AUG_QUERY: False 57 | INFERENCE_ON: "teacher" 58 | STEP_HOOK: True 59 | WARM_UP: 1000 #debug 60 | STAGE_WARM_UP: 2500 61 | EXTRA_STUDENT_INFO: True 62 | USE_CONSISTENCY: False 63 | EMA: 64 | WARM_UP: 0 65 | DECODER_LOSS: [ "labels", "texts_psa", "ctrl_points", "bd_points" ] 66 | O2O_DECODER_LOSS: [ "labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc" ] 67 | #ctc hard mining + o2o rcs weighting 68 | 69 | TEST: 70 | EVAL_PERIOD: 1000 71 | # 1 - Generic, 2 - Weak, 3 - Strong (for 
icdar2015) 4 - training setting 72 | LEXICON_TYPE: 4 73 | 74 | OUTPUT_DIR: "output/R50/150k_tt/IC15/Finetune/semi_0.5s/SemiETS_0.5s" -------------------------------------------------------------------------------- /configs/R_50/IC15/SemiETS/SemiETS_5s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | INFERENCE_TH_TEST: 0.3 8 | LOSS: 9 | USE_DYNAMIC_K: False 10 | O2M_MATCH_NUM: 5 11 | LEVEN_ALPHA : 20 12 | ADP_POINT_COORD_WEIGHT: 1.0 13 | ADP_POINT_TEXT_WEIGHT: 0.5 14 | ADP_BOUNDARY_WEIGHT: 0.5 15 | DET_ADAPTIVE_TYPE: 'edit_distance' 16 | 17 | DATASETS: 18 | TRAIN: ("ic15_train_5_label", "ic15_train_5_unlabel") 19 | TEST: ("ic15_test",) 20 | 21 | INPUT: 22 | MIN_SIZE_TRAIN: (800,900,1000,1100,1200,1300,1400) 23 | MAX_SIZE_TRAIN: 3000 24 | MIN_SIZE_TEST: 1440 25 | MAX_SIZE_TEST: 4000 26 | CROP: 27 | ENABLED: False 28 | ROTATE: False 29 | 30 | 31 | SOLVER: 32 | IMS_PER_BATCH: 8 #debug 33 | SOURCE_RATIO: (1,1) 34 | BASE_LR: 1e-5 35 | LR_BACKBONE: 1e-6 36 | WARMUP_ITERS: 0 37 | STEPS: (100000,) # no step 38 | MAX_ITER: 8000 39 | CHECKPOINT_PERIOD: 1000 40 | FIND_UNUSED_PARAMETERS: True 41 | AMP: 42 | ENABLED: True 43 | 44 | 45 | SSL: 46 | MODE: "mean-teacher" 47 | SEMI_WRAPPER: "SemiETSTextSpotter" 48 | PSEUDO_LABEL_INITIAL_SCORE_THR: 0.3 49 | PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 50 | USE_SPOTTING_NMS: True 51 | USE_SEPERATE_MATCHER: False 52 | USE_COMBINED_THR: False 53 | O2M_TEXT_O2O: False 54 | USE_O2M_ENC : False 55 | MIN_PSEDO_BOX_SIZE: 0 56 | UNSUP_WEIGHT: 2.0 57 | CONSISTENCY_WEIGHT: 1.0 58 | AUG_QUERY: False 59 | INFERENCE_ON: "teacher" 60 | STEP_HOOK: True 61 | WARM_UP: 1000 #debug 62 | STAGE_WARM_UP: 2500 63 | EXTRA_STUDENT_INFO: True 64 | USE_CONSISTENCY: False 65 | EMA: 66 | WARM_UP: 0 67 | DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 68 | O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 69 | #ctc hard mining + o2o rcs weighting 70 | 71 | TEST: 72 | EVAL_PERIOD: 1000 73 | # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015) 4 - training setting 74 | LEXICON_TYPE: 4 75 | 76 | OUTPUT_DIR: "output/R50/150k_tt/IC15/Finetune/semi_5s/SemiETS_5s" -------------------------------------------------------------------------------- /configs/R_50/IC15/SemiETS/SemiETS_10s.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../Base_det.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "TransformerPureDetectorV2" #no o2m in sup:student.use_o2m =False 5 | WEIGHTS: "output/R50/150k/pretrain/model_ts_final.pth" 6 | TRANSFORMER: 7 | INFERENCE_TH_TEST: 0.3 8 | LOSS: 9 | USE_DYNAMIC_K: False 10 | O2M_MATCH_NUM: 5 11 | LEVEN_ALPHA : 20 12 | ADP_POINT_COORD_WEIGHT: 1.0 13 | ADP_POINT_TEXT_WEIGHT: 0.5 14 | ADP_BOUNDARY_WEIGHT: 0.5 15 | DET_ADAPTIVE_TYPE: 'edit_distance' 16 | 17 | DATASETS: 18 | TRAIN: ("ic15_train_10_label", "ic15_train_10_unlabel") 19 | TEST: ("ic15_test",) 20 | 21 | INPUT: 22 | MIN_SIZE_TRAIN: (800,900,1000,1100,1200,1300,1400) 23 | MAX_SIZE_TRAIN: 3000 24 | MIN_SIZE_TEST: 1440 25 | MAX_SIZE_TEST: 4000 26 | CROP: 27 | ENABLED: False 28 | ROTATE: False 29 | 30 | 31 | SOLVER: 32 | IMS_PER_BATCH: 8 #debug 33 | SOURCE_RATIO: (1,1) 34 | BASE_LR: 1e-5 35 | LR_BACKBONE: 1e-6 36 | WARMUP_ITERS: 0 37 | STEPS: (100000,) # no step 
38 |   MAX_ITER: 8000 39 |   CHECKPOINT_PERIOD: 1000 40 |   FIND_UNUSED_PARAMETERS: True 41 |   AMP: 42 |     ENABLED: True 43 | 44 | 45 | SSL: 46 |   MODE: "mean-teacher" 47 |   SEMI_WRAPPER: "SemiETSTextSpotter" 48 |   PSEUDO_LABEL_INITIAL_SCORE_THR: 0.3 49 |   PSEUDO_LABEL_FINAL_SCORE_THR: 0.7 50 |   USE_SPOTTING_NMS: True 51 |   USE_SEPERATE_MATCHER: False 52 |   USE_COMBINED_THR: False 53 |   O2M_TEXT_O2O: False 54 |   USE_O2M_ENC : False 55 |   MIN_PSEDO_BOX_SIZE: 0 56 |   UNSUP_WEIGHT: 2.0 57 |   CONSISTENCY_WEIGHT: 1.0 58 |   AUG_QUERY: False 59 |   INFERENCE_ON: "teacher" 60 |   STEP_HOOK: True 61 |   WARM_UP: 1000 #debug 62 |   STAGE_WARM_UP: 2500 63 |   EXTRA_STUDENT_INFO: True 64 |   USE_CONSISTENCY: False 65 |   EMA: 66 |     WARM_UP: 0 67 |   DECODER_LOSS : ["labels", "texts_psa", "ctrl_points", "bd_points"] 68 |   O2O_DECODER_LOSS : ["labels", "texts_adaptive_sci", "ctrl_points_adaptive_crc", "bd_points_adaptive_crc"] 69 |   #ctc hard mining + o2o rcs weighting 70 | 71 | TEST: 72 |   EVAL_PERIOD: 1000 73 |   # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015) 4 - training setting 74 |   LEXICON_TYPE: 4 75 | 76 | OUTPUT_DIR: "output/R50/150k_tt/IC15/Finetune/semi_10s/SemiETS_10s" -------------------------------------------------------------------------------- /adet/modeling/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from detectron2.utils import comm 4 | 5 | 6 | @torch.no_grad() 7 | def concat_all_gather_with_various_shape(tensor): 8 |     """ 9 |     Performs all_gather operation on the provided tensors. 10 |     *** Warning ***: torch.distributed.all_gather has no gradient. 11 |     """ 12 |     # current_rank = torch.distributed.get_rank() 13 |     # print('current_rank: ', current_rank) 14 |     # if len(tensor.size()) == 1: 15 |     #     tensor = tensor.view(-1, 1) 16 | 17 |     # tensor_size = torch.tensor(tensor.size()).to(tensor.device) 18 |     # device = tensor.device 19 |     # dtype = tensor.dtype 20 | 21 |     # size_gather = [torch.zeros_like(tensor_size) for _ in range(torch.distributed.get_world_size())] 22 |     # torch.distributed.all_gather(size_gather, tensor_size, async_op=False) 23 |     # tensors_gather = [torch.zeros(torch.Size(_size), dtype=dtype).to(device) for _size in size_gather] 24 |     # torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 25 | 26 |     tensors_gather = comm.all_gather(tensor) 27 |     tensors_gather = [ t.detach().cpu() for t in tensors_gather] 28 | 29 |     # tensors_gather.pop(current_rank) 30 |     # print('>>>>> Right local?',(tensor == tensors_gather[current_rank]).all()) 31 |     tensors_gather = torch.cat(tensors_gather) 32 | 33 |     # output = torch.cat(tensors_gather, dim=0) 34 |     return tensors_gather#, current_rank 35 | 36 | 37 | @torch.no_grad() 38 | def concat_all_gather(tensor): 39 |     # gather all tensor shape 40 |     shape_tensor = torch.tensor(tensor.shape, device='cuda') 41 |     shape_list = [shape_tensor.clone() for _ in range(comm.get_world_size())] 42 |     dist.all_gather(shape_list, shape_tensor)  # in-place gather; detectron2's comm.all_gather takes a single data argument and returns a list 43 | 44 |     # padding tensor to the max length 45 |     if shape_list[0].numel() > 1: 46 |         max_shape = torch.tensor([_[0] for _ in shape_list]).max() 47 |         padding_tensor = torch.zeros((max_shape, shape_tensor[1]), device='cuda').type_as(tensor) 48 |     else: 49 |         max_shape = torch.tensor(shape_list).max() 50 |         padding_tensor = torch.zeros(max_shape, device='cuda').type_as(tensor) 51 | 52 |     padding_tensor[:shape_tensor[0]] = tensor 53 | 54 |     tensor_list = [torch.zeros_like(padding_tensor) for _ in range(comm.get_world_size())] 55 |
dist.all_gather(tensor_list, padding_tensor) 56 | 57 | sub_tensor_list = [] 58 | for sub_tensor, sub_shape in zip(tensor_list, shape_list): 59 | sub_tensor_list.append(sub_tensor[:sub_shape[0]]) 60 | output = torch.cat(sub_tensor_list, dim=0) 61 | 62 | return output 63 | 64 | 65 | @torch.no_grad() 66 | def concat_all_gather_equal_size(tensor, dim=0): 67 | """Performs all_gather operation on the provided tensors. 68 | 69 | *** Warning ***: torch.distributed.all_gather has no gradient. 70 | """ 71 | tensors_gather = [ 72 | torch.ones_like(tensor) 73 | for _ in range(torch.distributed.get_world_size()) 74 | ] 75 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 76 | 77 | output = torch.cat(tensors_gather, dim=dim) 78 | return output 79 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import glob 5 | import os 6 | from setuptools import find_packages, setup 7 | import torch 8 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 9 | 10 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 11 | assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3" 12 | 13 | 14 | def get_extensions(): 15 | this_dir = os.path.dirname(os.path.abspath(__file__)) 16 | extensions_dir = os.path.join(this_dir, "adet", "layers", "csrc") 17 | 18 | main_source = os.path.join(extensions_dir, "vision.cpp") 19 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp")) 20 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob( 21 | os.path.join(extensions_dir, "*.cu") 22 | ) 23 | 24 | sources = [main_source] + sources 25 | 26 | extension = CppExtension 27 | 28 | extra_compile_args = {"cxx": []} 29 | define_macros = [] 30 | 31 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 32 | extension = CUDAExtension 33 | sources += source_cuda 34 | define_macros += [("WITH_CUDA", None)] 35 | extra_compile_args["nvcc"] = [ 36 | "-DCUDA_HAS_FP16=1", 37 | "-D__CUDA_NO_HALF_OPERATORS__", 38 | "-D__CUDA_NO_HALF_CONVERSIONS__", 39 | "-D__CUDA_NO_HALF2_OPERATORS__", 40 | ] 41 | 42 | if torch_ver < [1, 7]: 43 | # supported by https://github.com/pytorch/pytorch/pull/43931 44 | CC = os.environ.get("CC", None) 45 | if CC is not None: 46 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC)) 47 | 48 | sources = [os.path.join(extensions_dir, s) for s in sources] 49 | 50 | include_dirs = [extensions_dir] 51 | 52 | ext_modules = [ 53 | extension( 54 | "adet._C", 55 | sources, 56 | include_dirs=include_dirs, 57 | define_macros=define_macros, 58 | extra_compile_args=extra_compile_args, 59 | ) 60 | ] 61 | 62 | return ext_modules 63 | 64 | 65 | setup( 66 | name="AdelaiDet", 67 | version="0.2.0", 68 | author="Adelaide Intelligent Machines", 69 | url="https://github.com/stanstarks/AdelaiDet", 70 | description="AdelaiDet is AIM's research " 71 | "platform for instance-level detection tasks based on Detectron2.", 72 | packages=find_packages(exclude=("configs", "tests")), 73 | python_requires=">=3.6", 74 | install_requires=[ 75 | "termcolor>=1.1", 76 | "Pillow>=6.0", 77 | "yacs>=0.1.6", 78 | "tabulate", 79 | "cloudpickle", 80 | "matplotlib", 81 | "tqdm>4.29.0", 82 | "tensorboard", 83 | "rapidfuzz", 84 | "Polygon3", 85 | "shapely", 86 | "scikit-image", 87 | "editdistance", 88 |
"opencv-python", 89 | "numba", 90 | ], 91 | extras_require={"all": ["psutil"]}, 92 | ext_modules=get_extensions(), 93 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 94 | ) 95 | -------------------------------------------------------------------------------- /adet/utils/comm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.distributed as dist 4 | 5 | from detectron2.utils.comm import get_world_size 6 | 7 | 8 | def reduce_sum(tensor): 9 | world_size = get_world_size() 10 | if world_size < 2: 11 | return tensor 12 | tensor = tensor.clone() 13 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 14 | return tensor 15 | 16 | 17 | def reduce_mean(tensor): 18 | num_gpus = get_world_size() 19 | total = reduce_sum(tensor) 20 | return total.float() / num_gpus 21 | 22 | 23 | def aligned_bilinear(tensor, factor): 24 | assert tensor.dim() == 4 25 | assert factor >= 1 26 | assert int(factor) == factor 27 | 28 | if factor == 1: 29 | return tensor 30 | 31 | h, w = tensor.size()[2:] 32 | tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate") 33 | oh = factor * h + 1 34 | ow = factor * w + 1 35 | tensor = F.interpolate( 36 | tensor, size=(oh, ow), 37 | mode='bilinear', 38 | align_corners=True 39 | ) 40 | tensor = F.pad( 41 | tensor, pad=(factor // 2, 0, factor // 2, 0), 42 | mode="replicate" 43 | ) 44 | 45 | return tensor[:, :, :oh - 1, :ow - 1] 46 | 47 | 48 | def compute_locations(h, w, stride, device): 49 | shifts_x = torch.arange( 50 | 0, w * stride, step=stride, 51 | dtype=torch.float32, device=device 52 | ) 53 | shifts_y = torch.arange( 54 | 0, h * stride, step=stride, 55 | dtype=torch.float32, device=device 56 | ) 57 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 58 | shift_x = shift_x.reshape(-1) 59 | shift_y = shift_y.reshape(-1) 60 | locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 61 | return locations 62 | 63 | 64 | def compute_ious(pred, target): 65 | """ 66 | Args: 67 | pred: Nx4 predicted bounding boxes 68 | target: Nx4 target bounding boxes 69 | Both are in the form of FCOS prediction (l, t, r, b) 70 | """ 71 | pred_left = pred[:, 0] 72 | pred_top = pred[:, 1] 73 | pred_right = pred[:, 2] 74 | pred_bottom = pred[:, 3] 75 | 76 | target_left = target[:, 0] 77 | target_top = target[:, 1] 78 | target_right = target[:, 2] 79 | target_bottom = target[:, 3] 80 | 81 | target_aera = (target_left + target_right) * \ 82 | (target_top + target_bottom) 83 | pred_aera = (pred_left + pred_right) * \ 84 | (pred_top + pred_bottom) 85 | 86 | w_intersect = torch.min(pred_left, target_left) + \ 87 | torch.min(pred_right, target_right) 88 | h_intersect = torch.min(pred_bottom, target_bottom) + \ 89 | torch.min(pred_top, target_top) 90 | 91 | g_w_intersect = torch.max(pred_left, target_left) + \ 92 | torch.max(pred_right, target_right) 93 | g_h_intersect = torch.max(pred_bottom, target_bottom) + \ 94 | torch.max(pred_top, target_top) 95 | ac_uion = g_w_intersect * g_h_intersect 96 | 97 | area_intersect = w_intersect * h_intersect 98 | area_union = target_aera + pred_aera - area_intersect 99 | 100 | ious = (area_intersect + 1.0) / (area_union + 1.0) 101 | gious = ious - (ac_uion - area_union) / ac_uion 102 | 103 | return ious, gious 104 | -------------------------------------------------------------------------------- /adet/layers/pos_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 
3 | import torch.nn as nn 4 | 5 | class PositionalEncoding1D(nn.Module): 6 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 7 | """ 8 | :param channels: The last dimension of the tensor you want to apply pos emb to. 9 | """ 10 | super().__init__() 11 | self.channels = num_pos_feats 12 | dim_t = torch.arange(0, self.channels, 2).float() 13 | if scale is not None and normalize is False: 14 | raise ValueError("normalize should be True if scale is passed") 15 | if scale is None: 16 | scale = 2 * np.pi 17 | self.scale = scale 18 | self.normalize = normalize 19 | inv_freq = 1. / (temperature ** (dim_t / self.channels)) 20 | self.register_buffer('inv_freq', inv_freq) 21 | 22 | def forward(self, tensor): 23 | """ 24 | :param tensor: A 2d tensor of size (len, c) 25 | :return: Positional Encoding Matrix of size (len, c) 26 | """ 27 | if tensor.ndim != 2: 28 | raise RuntimeError("The input tensor has to be 2D!") 29 | x, orig_ch = tensor.shape 30 | pos_x = torch.arange( 31 | 1, x + 1, device=tensor.device).type(self.inv_freq.type()) 32 | 33 | if self.normalize: 34 | eps = 1e-6 35 | pos_x = pos_x / (pos_x[-1:] + eps) * self.scale 36 | 37 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 38 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1) 39 | emb = torch.zeros((x, self.channels), 40 | device=tensor.device).type(tensor.type()) 41 | emb[:, :self.channels] = emb_x 42 | 43 | return emb[:, :orig_ch] 44 | 45 | 46 | class PositionalEncoding2D(nn.Module): 47 | """ 48 | This is a more standard version of the position embedding, very similar to the one 49 | used by the Attention is all you need paper, generalized to work on images. 50 | """ 51 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 52 | super().__init__() 53 | self.num_pos_feats = num_pos_feats 54 | self.temperature = temperature 55 | self.normalize = normalize 56 | if scale is not None and normalize is False: 57 | raise ValueError("normalize should be True if scale is passed") 58 | if scale is None: 59 | scale = 2 * np.pi 60 | self.scale = scale 61 | 62 | def forward(self, tensors): 63 | x = tensors.tensors 64 | mask = tensors.mask 65 | assert mask is not None 66 | not_mask = ~mask 67 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 68 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 69 | if self.normalize: 70 | eps = 1e-6 71 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 72 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 73 | 74 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 75 | dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats) 76 | 77 | pos_x = x_embed[:, :, :, None] / dim_t 78 | pos_y = y_embed[:, :, :, None] / dim_t 79 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 80 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 81 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 82 | return pos 83 | -------------------------------------------------------------------------------- /adet/data/geo_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Record the geometric transformation information used in the augmentation in a transformation matrix. 
3 | """ 4 | import numpy as np 5 | 6 | 7 | class GeometricTransformationBase(object): 8 | @classmethod 9 | def inverse(cls, results): 10 | # compute the inverse 11 | return results["transform_matrix"].I # 3x3 12 | 13 | @classmethod 14 | def apply(self, results, operator, **kwargs): 15 | trans_matrix = getattr(self, f"_get_{operator}_matrix")(**kwargs) 16 | if "transform_matrix" not in results: 17 | results["transform_matrix"] = trans_matrix 18 | else: 19 | base_transformation = results["transform_matrix"] 20 | results["transform_matrix"] = np.dot(trans_matrix, base_transformation) 21 | 22 | @classmethod 23 | def apply_cv2_matrix(self, results, cv2_matrix): 24 | if cv2_matrix.shape[0] == 2: 25 | mat = np.concatenate( 26 | [cv2_matrix, np.array([0, 0, 1]).reshape((1, 3))], axis=0 27 | ) 28 | else: 29 | mat = cv2_matrix 30 | base_transformation = results["transform_matrix"] 31 | results["transform_matrix"] = np.dot(mat, base_transformation) 32 | return results 33 | 34 | @classmethod 35 | def _get_rotate_matrix(cls, degree=None, cv2_rotation_matrix=None, inverse=False): 36 | # TODO: this is rotated by zero point 37 | if degree is None and cv2_rotation_matrix is None: 38 | raise ValueError( 39 | "At least one of degree or rotation matrix should be provided" 40 | ) 41 | if degree: 42 | if inverse: 43 | degree = -degree 44 | rad = degree * np.pi / 180 45 | sin_a = np.sin(rad) 46 | cos_a = np.cos(rad) 47 | return np.array([[cos_a, sin_a, 0], [-sin_a, cos_a, 0], [0, 0, 1]]) # 2x3 48 | else: 49 | mat = np.concatenate( 50 | [cv2_rotation_matrix, np.array([0, 0, 1]).reshape((1, 3))], axis=0 51 | ) 52 | if inverse: 53 | mat = mat * np.array([[1, -1, -1], [-1, 1, -1], [1, 1, 1]]) 54 | return mat 55 | 56 | @classmethod 57 | def _get_shift_matrix(cls, dx=0, dy=0, inverse=False): 58 | if inverse: 59 | dx = -dx 60 | dy = -dy 61 | return np.array([[1, 0, dx], [0, 1, dy], [0, 0, 1]]) 62 | 63 | @classmethod 64 | def _get_shear_matrix( 65 | cls, degree=None, magnitude=None, direction="horizontal", inverse=False 66 | ): 67 | if magnitude is None: 68 | assert degree is not None 69 | rad = degree * np.pi / 180 70 | magnitude = np.tan(rad) 71 | 72 | if inverse: 73 | magnitude = -magnitude 74 | if direction == "horizontal": 75 | shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0], [0, 0, 1]]) 76 | else: 77 | shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0], [0, 0, 1]]) 78 | return shear_matrix 79 | 80 | @classmethod 81 | def _get_flip_matrix(cls, shape, direction="horizontal", inverse=False): 82 | h, w = shape 83 | if direction == "horizontal": 84 | flip_matrix = np.float32([[-1, 0, w], [0, 1, 0], [0, 0, 1]]) 85 | else: 86 | flip_matrix = np.float32([[1, 0, 0], [0, h - 1, 0], [0, 0, 1]]) 87 | return flip_matrix 88 | 89 | @classmethod 90 | def _get_scale_matrix(cls, sx, sy, inverse=False): 91 | if inverse: 92 | sx = 1 / sx 93 | sy = 1 / sy 94 | return np.float32([[sx, 0, 0], [0, sy, 0], [0, 0, 1]]) 95 | -------------------------------------------------------------------------------- /adet/modeling/semi_text_spotter.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import time 3 | import logging 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | import pickle 8 | import copy 9 | import torch.nn.functional as F 10 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY 11 | from detectron2.modeling import build_backbone, build_model 12 | from detectron2.structures import ImageList, Instances 13 | from 
adet.layers.pos_encoding import PositionalEncoding2D 14 | # from adet.modeling.model.losses import SetCriterion, SetAdaptiveO2MCriterionFull 15 | from adet.modeling.model.matcher import build_matcher, build_matcher_o2m, CtrlPointCost, build_matcher_o2m_full 16 | from adet.modeling.utils.dist_utils import concat_all_gather_with_various_shape 17 | 18 | try: 19 | from ctcdecode import CTCBeamDecoder 20 | except ImportError: 21 | CTCBeamDecoder = None 22 | 23 | 24 | @META_ARCH_REGISTRY.register() 25 | class MultiStreamSpotter(nn.Module): 26 | def __init__(self, cfg): 27 | super(MultiStreamSpotter, self).__init__() 28 | 29 | self.device = torch.device(cfg.MODEL.DEVICE) 30 | 31 | self.submodules = ['student', 'teacher'] 32 | 33 | # create student model 34 | self.student = self.build_model(cfg) 35 | 36 | # create teacher model 37 | self.teacher = self.build_model(cfg) 38 | 39 | # inference model 40 | self.inference_on = 'teacher' 41 | 42 | # warm up using only labeled data 43 | self.label_warm_up = 0 44 | 45 | 46 | def build_model(self, cfg): 47 | model = build_model(cfg) 48 | return model 49 | 50 | # select the model to forward 51 | def model(self, **kwargs): 52 | if "submodule" in kwargs: 53 | assert ( 54 | kwargs["submodule"] in self.submodules 55 | ), "Detector does not contain submodule {}".format(kwargs["submodule"]) 56 | model = getattr(self, kwargs["submodule"]) 57 | else: 58 | model = getattr(self, self.inference_on) 59 | return model 60 | 61 | # freeze the sub-model during training 62 | def freeze(self, model_ref: str): 63 | assert model_ref in self.submodules 64 | model = getattr(self, model_ref) 65 | model.eval() 66 | for param in model.parameters(): 67 | param.requires_grad = False 68 | 69 | 70 | def split_ssl_batch(self, batched_inputs): 71 | """ 72 | Split batched inputs into labeled and unlabeled samples. 73 | Args: 74 | batched_inputs (list[dict]): same as in :meth:`forward` 75 | Returns: 76 | labeled_batched_inputs (list[dict]): same as in :meth:`forward` 77 | unlabeled_batched_inputs (list[dict]): same as in :meth:`forward` 78 | """ 79 | results = {'sup': [], 'unsup_teacher': [], 'unsup_student': []} 80 | for d in batched_inputs: 81 | semi_flag = d['semi'] 82 | if semi_flag == 'sup': 83 | results['sup'].append(d) 84 | elif semi_flag == 'unsup': 85 | results['unsup_teacher'].append(d['weak']) 86 | results['unsup_student'].append(d['strong']) 87 | 88 | return results 89 | 90 | def inference(self, batched_inputs): 91 | """ 92 | Run inference on the given inputs. 93 | Args: 94 | batched_inputs (list[dict]): same as in :meth:`forward` 95 | Returns: 96 | same as in :meth:`forward`. 
97 | """ 98 | if self.inference_on == 'student': 99 | return self.student.forward(batched_inputs) 100 | elif self.inference_on == 'teacher': 101 | return self.teacher.forward(batched_inputs) 102 | else: 103 | raise NotImplementedError 104 | 105 | 106 | def reverse_sigmoid(y): 107 | return -torch.log(1 / y - 1) 108 | -------------------------------------------------------------------------------- /adet/modeling/utils/boundary_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.nn.functional as F 5 | 6 | 7 | def extract_curved_roi_features(features, center_points, boundary_points, aggregation='concat'): 8 | assert aggregation in ['sum', 'concat'] 9 | 10 | # poi_feats = [[] for _ in range(len(features))] 11 | bs = features[0].shape[0] 12 | poi_feats = [[] for _ in range(bs)] 13 | 14 | 15 | 16 | for lvl, feats in enumerate(features): 17 | 18 | for batch_id, (feat, ref_pts) in enumerate(zip(feats, center_points)): 19 | sampled_features = F.grid_sample(feat[None,:,:,:], ref_pts[None,:,:,:], align_corners=True) 20 | # level_roi_feats.append(sampled_features) 21 | 22 | poi_feats[batch_id].append(sampled_features) 23 | 24 | for batch_id, feats in enumerate(poi_feats): 25 | if aggregation == 'sum': 26 | poi_feats[batch_id] = torch.stack(feats, dim=0).sum(dim=0).squeeze(0).permute(1,0,2) # [num_proposal, C, points] 27 | elif aggregation == 'concat': 28 | poi_feats[batch_id] = torch.cat(feats, dim=1).squeeze(0).permute(1,0,2) 29 | else: 30 | raise NotImplementedError 31 | 32 | return poi_feats 33 | 34 | 35 | class CurvedRoIExtractor(nn.Module): 36 | """ 37 | 38 | """ 39 | def __init__(self, out_channels, out_height=None, sample_center=True, aggregation='sum', mode='align'): 40 | super(CurvedRoIExtractor, self).__init__() 41 | self.out_channels = out_channels 42 | self.mode = mode 43 | self.sample_center = sample_center 44 | if out_height is None: 45 | self.out_height = 3 46 | else: 47 | self.out_height = out_height 48 | 49 | if sample_center: 50 | assert self.out_height % 2 == 1 51 | 52 | assert aggregation in ['sum', 'concat'] 53 | self.aggregation = aggregation 54 | 55 | 56 | def forward(self, features, center_points, boundary_points): 57 | 58 | bs = features[0].shape[0] 59 | 60 | if self.sample_center: 61 | assert center_points is not None 62 | 63 | if center_points is not None: 64 | assert len(boundary_points) == len(center_points) 65 | 66 | 67 | roi_feats_list = [[] for _ in range(bs)] 68 | t = torch.linspace(0, 1, self.out_height).to(features[0].device) 69 | t = t.reshape(self.out_height, 1, 1) 70 | t = t[None, :, :, :] 71 | 72 | for lvl, feats in enumerate(features): 73 | 74 | for batch_id in range(len(boundary_points)): 75 | 76 | # merge sampled coords 77 | up_points = boundary_points[batch_id][:, :, :2] 78 | down_points = boundary_points[batch_id][:, :, 2:] 79 | 80 | # for instance_id in range(boundary_points.size(0)): 81 | 82 | upts = up_points[:, None,:,:].repeat(1, self.out_height, 1, 1) 83 | dpts = down_points[:, None,:,:].repeat(1, self.out_height, 1, 1) 84 | sample_points = upts + (dpts - upts) * t.repeat(upts.size(0), 1, 1, 1) 85 | 86 | if self.sample_center: 87 | sample_points = sample_points.transpose(0, 1) 88 | sample_points[self.out_height // 2] = center_points[batch_id] 89 | sample_points = sample_points.transpose(0, 1) 90 | 91 | encoded_feats = feats[batch_id] 92 | 93 | sampled_feats = F.grid_sample(encoded_feats[None, :, :, :].repeat(sample_points.size(0), 1, 1, 1), 94 | 
sample_points, align_corners=True) 95 | roi_feats_list[batch_id].append(sampled_feats) 96 | 97 | for batch_id, feats in enumerate(roi_feats_list): 98 | if self.aggregation == 'sum': 99 | roi_feats_list[batch_id] = torch.stack(feats, dim=0).sum(dim=0) # [num_proposal, channel , height, points] 100 | elif self.aggregation == 'concat': 101 | roi_feats_list[batch_id] = torch.cat(feats, dim=1) 102 | else: 103 | raise NotImplementedError 104 | 105 | return roi_feats_list 106 | 107 | 108 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /adet/utils/structure_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import Counter, Mapping, Sequence 3 | from numbers import Number 4 | from typing import Dict, List 5 | 6 | import numpy as np 7 | import torch 8 | # from mmdet.core.mask.structures import BitmapMasks 9 | from torch.nn import functional as F 10 | 11 | _step_counter = Counter() 12 | 13 | 14 | def list_concat(data_list: List[list]): 15 | if isinstance(data_list[0], torch.Tensor): 16 | return torch.cat(data_list) 17 | else: 18 | endpoint = [d for d in data_list[0]] 19 | 20 | for i in range(1, len(data_list)): 21 | endpoint.extend(data_list[i]) 22 | return endpoint 23 | 24 | 25 | def sequence_concat(a, b): 26 | if isinstance(a, Sequence) and isinstance(b, Sequence): 27 | return a + b 28 | else: 29 | return None 30 | 31 | 32 | def dict_concat(dicts: List[Dict[str, list]]): 33 | return {k: list_concat([d[k] for d in dicts]) for k in dicts[0].keys()} 34 | 35 | 36 | def dict_fuse(obj_list, reference_obj): 37 | if isinstance(reference_obj, torch.Tensor): 38 | return torch.stack(obj_list) 39 | return obj_list 40 | 41 | 42 | def dict_select(dict1: Dict[str, list], key: str, value: str): 43 | flag = [v == value for v in dict1[key]] 44 | return { 45 | k: dict_fuse([vv for vv, ff in zip(v, flag) if ff], v) for k, v in dict1.items() 46 | } 47 | 48 | 49 | def dict_split(dict1, key): 50 | group_names = list(set(dict1[key])) 51 | dict_groups = {k: dict_select(dict1, key, k) for k in group_names} 52 | 53 | return dict_groups 54 | 55 | 56 | def dict_sum(a, b): 57 | if isinstance(a, dict): 58 | assert isinstance(b, dict) 59 | return {k: dict_sum(v, b[k]) for k, v in a.items()} 60 | elif isinstance(a, list): 61 | assert len(a) == len(b) 62 | return [dict_sum(aa, bb) for aa, bb in zip(a, b)] 63 | else: 64 | return a + b 65 | 66 | 67 | def zero_like(tensor_pack, prefix=""): 68 | if isinstance(tensor_pack, Sequence): 69 | return [zero_like(t) for t in tensor_pack] 70 | elif 
isinstance(tensor_pack, Mapping): 71 | return {prefix + k: zero_like(v) for k, v in tensor_pack.items()} 72 | elif isinstance(tensor_pack, torch.Tensor): 73 | return tensor_pack.new_zeros(tensor_pack.shape) 74 | elif isinstance(tensor_pack, np.ndarray): 75 | return np.zeros_like(tensor_pack) 76 | else: 77 | warnings.warn("Unexpected data type {}".format(type(tensor_pack))) 78 | return 0 79 | 80 | 81 | def pad_stack(tensors, shape, pad_value=255): 82 | tensors = torch.stack( 83 | [ 84 | F.pad( 85 | tensor, 86 | pad=[0, shape[1] - tensor.shape[1], 0, shape[0] - tensor.shape[0]], 87 | value=pad_value, 88 | ) 89 | for tensor in tensors 90 | ] 91 | ) 92 | return tensors 93 | 94 | 95 | def result2bbox(result): 96 | num_class = len(result) 97 | 98 | bbox = np.concatenate(result) 99 | if bbox.shape[0] == 0: 100 | label = np.zeros(0, dtype=np.uint8) 101 | else: 102 | label = np.concatenate( 103 | [[i] * len(result[i]) for i in range(num_class) if len(result[i]) > 0] 104 | ).reshape((-1,)) 105 | return bbox, label 106 | 107 | 108 | # def result2mask(result): 109 | # num_class = len(result) 110 | # mask = [np.stack(result[i]) for i in range(num_class) if len(result[i]) > 0] 111 | # if len(mask) > 0: 112 | # mask = np.concatenate(mask) 113 | # else: 114 | # mask = np.zeros((0, 1, 1)) 115 | # return BitmapMasks(mask, mask.shape[1], mask.shape[2]), None 116 | 117 | 118 | def sequence_mul(obj, multiplier): 119 | if isinstance(obj, Sequence): 120 | return [o * multiplier for o in obj] 121 | else: 122 | return obj * multiplier 123 | 124 | 125 | def is_match(word, word_list): 126 | for keyword in word_list: 127 | if keyword in word: 128 | return True 129 | return False 130 | 131 | 132 | def weighted_loss(loss: dict, weight, ignore_keys=[], warmup=0): 133 | _step_counter["weight"] += 1 134 | lambda_weight = ( 135 | lambda x: x * (_step_counter["weight"] - 1) / warmup 136 | if _step_counter["weight"] <= warmup 137 | else x 138 | ) 139 | if isinstance(weight, Mapping): 140 | for k, v in weight.items(): 141 | for name, loss_item in loss.items(): 142 | if (k in name) and ("loss" in name): 143 | loss[name] = sequence_mul(loss[name], lambda_weight(v)) 144 | elif isinstance(weight, Number): 145 | for name, loss_item in loss.items(): 146 | if "loss" in name: 147 | if not is_match(name, ignore_keys): 148 | loss[name] = sequence_mul(loss[name], lambda_weight(weight)) 149 | else: 150 | loss[name] = sequence_mul(loss[name], 0.0) 151 | else: 152 | raise NotImplementedError() 153 | return loss 154 | -------------------------------------------------------------------------------- /adet/config/defaults.py: -------------------------------------------------------------------------------- 1 | from detectron2.config.defaults import _C 2 | from detectron2.config import CfgNode as CN 3 | 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # Additional Configs 7 | # ---------------------------------------------------------------------------- # 8 | _C.MODEL.MOBILENET = False 9 | _C.MODEL.BACKBONE.ANTI_ALIAS = False 10 | _C.MODEL.RESNETS.DEFORM_INTERVAL = 1 11 | _C.INPUT.HFLIP_TRAIN = False 12 | _C.INPUT.CROP.CROP_INSTANCE = True 13 | _C.INPUT.ROTATE = True 14 | 15 | _C.MODEL.BASIS_MODULE = CN() 16 | _C.MODEL.BASIS_MODULE.NAME = "ProtoNet" 17 | _C.MODEL.BASIS_MODULE.NUM_BASES = 4 18 | _C.MODEL.BASIS_MODULE.LOSS_ON = False 19 | _C.MODEL.BASIS_MODULE.ANN_SET = "coco" 20 | _C.MODEL.BASIS_MODULE.CONVS_DIM = 128 21 | _C.MODEL.BASIS_MODULE.IN_FEATURES = ["p3", "p4", "p5"] 22 | 
_C.MODEL.BASIS_MODULE.NORM = "SyncBN" 23 | _C.MODEL.BASIS_MODULE.NUM_CONVS = 3 24 | _C.MODEL.BASIS_MODULE.COMMON_STRIDE = 8 25 | _C.MODEL.BASIS_MODULE.NUM_CLASSES = 80 26 | _C.MODEL.BASIS_MODULE.LOSS_WEIGHT = 0.3 27 | 28 | _C.MODEL.TOP_MODULE = CN() 29 | _C.MODEL.TOP_MODULE.NAME = "conv" 30 | _C.MODEL.TOP_MODULE.DIM = 16 31 | 32 | 33 | # ---------------------------------------------------------------------------- # 34 | # BAText Options 35 | # ---------------------------------------------------------------------------- # 36 | _C.MODEL.BATEXT = CN() 37 | _C.MODEL.BATEXT.VOC_SIZE = 96 38 | _C.MODEL.BATEXT.NUM_CHARS = 25 39 | _C.MODEL.BATEXT.POOLER_RESOLUTION = (8, 32) 40 | _C.MODEL.BATEXT.IN_FEATURES = ["p2", "p3", "p4"] 41 | _C.MODEL.BATEXT.POOLER_SCALES = (0.25, 0.125, 0.0625) 42 | _C.MODEL.BATEXT.SAMPLING_RATIO = 1 43 | _C.MODEL.BATEXT.CONV_DIM = 256 44 | _C.MODEL.BATEXT.NUM_CONV = 2 45 | _C.MODEL.BATEXT.RECOGNITION_LOSS = "ctc" 46 | _C.MODEL.BATEXT.RECOGNIZER = "attn" 47 | _C.MODEL.BATEXT.CANONICAL_SIZE = 96 # largest min_size for level 3 (stride=8) 48 | _C.MODEL.BATEXT.USE_COORDCONV = False 49 | _C.MODEL.BATEXT.USE_AET = False 50 | _C.MODEL.BATEXT.CUSTOM_DICT = "" # Path to the class file. 51 | 52 | 53 | # ---------------------------------------------------------------------------- # 54 | # SwinTransformer Options 55 | # ---------------------------------------------------------------------------- # 56 | _C.MODEL.SWIN = CN() 57 | _C.MODEL.SWIN.TYPE = 'tiny' 58 | _C.MODEL.SWIN.DROP_PATH_RATE = 0.2 59 | 60 | # ---------------------------------------------------------------------------- # 61 | # ViTAE-v2 Options 62 | # ---------------------------------------------------------------------------- # 63 | _C.MODEL.ViTAEv2 = CN() 64 | _C.MODEL.ViTAEv2.TYPE = 'vitaev2_s' 65 | _C.MODEL.ViTAEv2.DROP_PATH_RATE = 0.2 66 | 67 | # ---------------------------------------------------------------------------- # 68 | # (Deformable) Transformer Options 69 | # ---------------------------------------------------------------------------- # 70 | _C.MODEL.TRANSFORMER = CN() 71 | _C.MODEL.TRANSFORMER.ENABLED = False 72 | _C.MODEL.TRANSFORMER.INFERENCE_TH_TEST = 0.4 73 | _C.MODEL.TRANSFORMER.AUX_LOSS = True 74 | _C.MODEL.TRANSFORMER.ENC_LAYERS = 6 75 | _C.MODEL.TRANSFORMER.DEC_LAYERS = 6 76 | _C.MODEL.TRANSFORMER.EMB_LAYERS = 3 77 | _C.MODEL.TRANSFORMER.DIM_FEEDFORWARD = 1024 78 | _C.MODEL.TRANSFORMER.HIDDEN_DIM = 256 79 | _C.MODEL.TRANSFORMER.DROPOUT = 0.0 80 | _C.MODEL.TRANSFORMER.NHEADS = 8 81 | _C.MODEL.TRANSFORMER.NUM_QUERIES = 100 82 | _C.MODEL.TRANSFORMER.ENC_N_POINTS = 4 83 | _C.MODEL.TRANSFORMER.DEC_N_POINTS = 4 84 | _C.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE = 6.283185307179586 # 2 PI 85 | _C.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS = 4 86 | _C.MODEL.TRANSFORMER.VOC_SIZE = 37 # a-z + 0-9 + unknown 87 | _C.MODEL.TRANSFORMER.CUSTOM_DICT = "" # Path to the character class file. 
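# These defaults are typically overridden per experiment from the YAML configs; e.g. the CTW1500 splits ship with 96-char annotations (train_96voc*.json), which implies a larger VOC_SIZE there, while the IC15/TotalText splits use the 37-char vocabulary.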
88 | _C.MODEL.TRANSFORMER.NUM_POINTS = 25 # the number of point queries for each instance 89 | _C.MODEL.TRANSFORMER.TEMPERATURE = 10000 90 | _C.MODEL.TRANSFORMER.BOUNDARY_HEAD = True # True: with boundary predictions, False: only with center lines 91 | _C.MODEL.TRANSFORMER.SFEM = False 92 | _C.MODEL.TRANSFORMER.SFEM_DECODER_SA_TYPE = 'sa' 93 | _C.MODEL.TRANSFORMER.SFEM_DECODER_MODULE_SEQ = ['sa', 'ca', 'ffn'] 94 | 95 | 96 | _C.MODEL.TRANSFORMER.LOSS = CN() 97 | _C.MODEL.TRANSFORMER.LOSS.AUX_LOSS = True 98 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_ALPHA = 0.25 99 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_GAMMA = 2.0 100 | # bezier proposal loss 101 | _C.MODEL.TRANSFORMER.LOSS.BEZIER_CLASS_WEIGHT = 1.0 102 | _C.MODEL.TRANSFORMER.LOSS.BEZIER_COORD_WEIGHT = 1.0 103 | _C.MODEL.TRANSFORMER.LOSS.BEZIER_SAMPLE_POINTS = 25 104 | # supervise the sampled on-curve points but not 4 Bezier control points 105 | 106 | 107 | 108 | # target loss 109 | _C.MODEL.TRANSFORMER.LOSS.POINT_CLASS_WEIGHT = 1.0 110 | _C.MODEL.TRANSFORMER.LOSS.POINT_COORD_WEIGHT = 1.0 111 | _C.MODEL.TRANSFORMER.LOSS.POINT_TEXT_WEIGHT = 0.5 112 | _C.MODEL.TRANSFORMER.LOSS.BOUNDARY_WEIGHT = 0.5 113 | 114 | # instance confidence loss 115 | _C.MODEL.TRANSFORMER.LOSS.INSTANCE_CLASS_WEIGHT = 1.0 116 | _C.MODEL.TRANSFORMER.LOSS.INSTANCE_REC_WEIGHT = 1.0 117 | # Hard Sample Mining loss 118 | _C.MODEL.TRANSFORMER.LOSS.USE_DYNAMIC_K = False 119 | _C.MODEL.TRANSFORMER.LOSS.LEVEN_ALPHA = 20 120 | _C.MODEL.TRANSFORMER.LOSS.COST_ALPHA = 20 121 | 122 | _C.SOLVER.OPTIMIZER = "ADAMW" 123 | _C.SOLVER.LR_BACKBONE = 1e-5 124 | _C.SOLVER.LR_BACKBONE_NAMES = [] 125 | _C.SOLVER.LR_LINEAR_PROJ_NAMES = [] 126 | _C.SOLVER.LR_LINEAR_PROJ_MULT = 0.1 127 | _C.SOLVER.SOURCE_RATIO = [1,1] 128 | 129 | # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015) 130 | # 1 - Full lexicon (for totaltext) 131 | _C.TEST.LEXICON_TYPE = 1 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SemiETS 2 | 3 | 👋 Welcome to the official code of [SemiETS: Integrating Spatial and Content Consistencies for Semi-Supervised End-to-end Text Spotting](https://arxiv.org/abs/2504.09966) (CVPR 2025) 4 | 5 | This work explores semi-supervised text spotting (SSTS) to reduce the expensive annotation cost of text spotting. We observe two challenges in SSTS: 1) inconsistent pseudo labels between the detection and recognition tasks, and 2) sub-optimal supervision caused by inconsistency between the teacher and student. To address them, we propose SemiETS. It gradually generates reliable hierarchical pseudo labels for each task, thereby reducing label noise. Meanwhile, it extracts important information on text locations and transcriptions from bidirectional flows to improve consistency. 6 | 7 |
8 | ![SemiETS framework](figs/framework.jpg) 9 |
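The `SSL` block of every config under `configs/R_50/*/SemiETS/` exposes the two confidence thresholds that gate pseudo labels (`PSEUDO_LABEL_INITIAL_SCORE_THR: 0.3` and `PSEUDO_LABEL_FINAL_SCORE_THR: 0.7`). A minimal sketch of this kind of two-stage confidence filtering is shown below; the helper and its tie to `STAGE_WARM_UP` are illustrative assumptions, not the repo's actual API:

```
import torch

def filter_pseudo_labels(scores, instances, cur_iter,
                         initial_thr=0.3, final_thr=0.7, stage_warm_up=2500):
    # loose threshold early on; stricter once the teacher has stabilized
    thr = initial_thr if cur_iter < stage_warm_up else final_thr
    keep = scores > thr
    return [inst for inst, k in zip(instances, keep) if k]
```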
10 | 11 | 12 | ## 📖 Usage 13 | 14 | ### 🛠️ Dependencies and Installation 15 | 16 | * **Environment** 17 | 18 | ``` 19 | Python 3.8 + PyTorch 1.9.0 + CUDA 11.1 + Detectron2 (v0.6) + ctcdecode 20 | ``` 21 | 22 | 1. **Install SemiETS** 23 | 24 | ``` 25 | # 1. Clone repository 26 | git clone git@github.com:DrLuo/SemiETS.git 27 | cd SemiETS 28 | 29 | # 2. Create conda environment 30 | conda create -n semiets python=3.8 -y 31 | conda activate semiets 32 | 33 | # 3. Install PyTorch and other dependencies using pip 34 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 35 | pip install -r requirements.txt 36 | python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html 37 | python setup.py build develop 38 | ``` 39 | 40 | 2. **Install ctcdecode** from [source](https://github.com/parlance/ctcdecode) 41 | 42 | ``` 43 | git clone --recursive https://github.com/parlance/ctcdecode.git 44 | cd ctcdecode 45 | pip install . 46 | ``` 47 | 48 | 49 | ### 🧱 Preparation 50 | 51 | 1. **Download datasets** from [here](https://github.com/ViTAE-Transformer/DeepSolo/blob/main/DeepSolo/README.md#preparation). Data splits are in ```SemiETS/datasets```. 52 | 53 |
54 | Dataset Organization 55 | 56 | *Some image files need to be renamed.* Organize them as follows (lexicon files are not listed here): 57 | 58 | ``` 59 | |- ./datasets 60 | |- syntext1 61 | | |- train_images 62 | | └ annotations 63 | | |- train_37voc.json 64 | | └ train_96voc.json 65 | |- syntext2 66 | | |- train_images 67 | | └ annotations 68 | | |- train_37voc.json 69 | | └ train_96voc.json 70 | |- totaltext 71 | | |- train_images 72 | | |- test_images 73 | | |- train_37voc.json 74 | | |- train_96voc.json 75 | | |- train_37voc_0.5_labeled.json 76 | | |- train_37voc_0.5_unlabeled.json 77 | | |- train_37voc_1_labeled.json 78 | | |- train_37voc_1_unlabeled.json 79 | | |- train_37voc_2_labeled.json 80 | | |- train_37voc_2_unlabeled.json 81 | | |- train_37voc_5_labeled.json 82 | | |- train_37voc_5_unlabeled.json 83 | | |- train_37voc_10_labeled.json 84 | | |- train_37voc_10_unlabeled.json 85 | | └ test.json 86 | |- ic15 87 | | |- train_images 88 | | |- test_images 89 | | |- train_37voc.json 90 | | |- train_96voc.json 91 | | |- train_37voc_0.5_labeled.json 92 | | |- train_37voc_0.5_unlabeled.json 93 | | |- train_37voc_1_labeled.json 94 | | |- train_37voc_1_unlabeled.json 95 | | |- train_37voc_2_labeled.json 96 | | |- train_37voc_2_unlabeled.json 97 | | |- train_37voc_5_labeled.json 98 | | |- train_37voc_5_unlabeled.json 99 | | |- train_37voc_10_labeled.json 100 | | |- train_37voc_10_unlabeled.json 101 | | └ test.json 102 | |- ctw1500 103 | | |- train_images 104 | | |- test_images 105 | | |- train_96voc.json 106 | | |- train_96voc_0.5_labeled.json 107 | | |- train_96voc_0.5_unlabeled.json 108 | | |- train_96voc_1_labeled.json 109 | | |- train_96voc_1_unlabeled.json 110 | | |- train_96voc_2_labeled.json 111 | | |- train_96voc_2_unlabeled.json 112 | | |- train_96voc_5_labeled.json 113 | | |- train_96voc_5_unlabeled.json 114 | | |- train_96voc_10_labeled.json 115 | | |- train_96voc_10_unlabeled.json 116 | | └ test.json 117 | |- evaluation 118 | | |- gt_*.zip 119 | ``` 120 |
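The split files are COCO-style: each annotation carries `bezier_pts` (16 floats, i.e. 8 Bezier control points for the top and bottom curves) and `rec` (25 character indices, padded with 37). A quick way to peek at one entry is sketched below; the `CTLABELS` table is an assumed a-z/0-9 ordering for the 37-char vocabulary, the real mapping ships with the codebase:

```
import json

# hypothetical index-to-character table ("a-z + 0-9 + unknown"); 37 = padding
CTLABELS = list("abcdefghijklmnopqrstuvwxyz0123456789") + ["?"]

with open("datasets/ic15/train_37voc_1_labeled.json") as f:
    coco = json.load(f)

ann = coco["annotations"][0]
text = "".join(CTLABELS[i] for i in ann["rec"] if i < 37)
print(ann["image_id"], ann["bbox"], text)
```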
121 | 122 | 123 | 124 | 2. **Download pretrained weights** for initialization from [Google Drive](https://drive.google.com/drive/folders/1ix416PtjenJxvDm_2KlS6z1vo6z5gI1K?usp=drive_link) 125 | 126 | The checkpoints were pretrained using only Synth150K. 127 | Place them under the folder ```./output/R50/150k_tt/pretrain/```. 128 | 129 | 130 | 131 | 132 | ### 🚀 Training 133 | 134 | ``` 135 | python tools/train_semi.py --config-file ${CONFIG_FILE} --num-gpus 4 --dist-url 'auto' 136 | ``` 137 | 138 | For example: 139 | ``` 140 | python tools/train_semi.py --config-file configs/R_50/TotalText/SemiETS/SemiETS_2s.yaml --num-gpus 4 --dist-url 'auto' 141 | ``` 142 | 143 | The configuration files are named following the format: ```SemiETS_{DATA_PROPORTION}s.yaml``` 144 | 145 | 146 | 147 | ### 📈 Evaluation 148 | 149 | ``` 150 | python tools/train_semi.py --config-file ${CONFIG_FILE} --eval-only MODEL.WEIGHTS ${MODEL_PATH} 151 | ``` 152 | 153 | 154 | ## 🔗 Citation 155 | If you find [SemiETS](https://arxiv.org/abs/2504.09966) useful for your research and applications, please cite using this BibTeX: 156 | 157 | ``` 158 | @article{luo2025semiets, 159 | title={SemiETS: Integrating Spatial and Content Consistencies for Semi-Supervised End-to-end Text Spotting}, 160 | author={Luo, Dongliang and Zhu, Hanshen and Zhang, Ziyang and Liang, Dingkang and Xie, Xudong and Liu, Yuliang and Bai, Xiang}, 161 | journal={CVPR}, 162 | year={2025} 163 | } 164 | ``` 165 | 166 | ## Acknowledgement 167 | This project is based on [DeepSolo](https://github.com/ViTAE-Transformer/DeepSolo) and [AdelaiDet](https://github.com/aim-uofa/AdelaiDet). We appreciate their wonderful codebase. For academic use, this project is licensed under the 2-clause BSD License. 168 | 169 | -------------------------------------------------------------------------------- /adet/utils/misc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import torch 3 | from torch.functional import Tensor 4 | from torchvision.ops.boxes import box_area 5 | import torch.distributed as dist 6 | 7 | 8 | def is_dist_avail_and_initialized(): 9 | if not dist.is_available(): 10 | return False 11 | if not dist.is_initialized(): 12 | return False 13 | return True 14 | 15 | 16 | @torch.no_grad() 17 | def accuracy(output, target, topk=(1,)): 18 | """Computes the precision@k for the specified values of k""" 19 | if target.numel() == 0: 20 | return [torch.zeros([], device=output.device)] 21 | if target.ndim == 2: 22 | assert output.ndim == 3 23 | output = output.mean(1) 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | 27 | _, pred = output.topk(maxk, -1) 28 | pred = pred.t() 29 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].view(-1).float().sum(0) 34 | res.append(correct_k.mul_(100.0 / batch_size)) 35 | return res 36 | 37 | 38 | def box_cxcywh_to_xyxy(x): 39 | x_c, y_c, w, h = x.unbind(-1) 40 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 41 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 42 | return torch.stack(b, dim=-1) 43 | 44 | 45 | def box_xyxy_to_cxcywh(x): 46 | x0, y0, x1, y1 = x.unbind(-1) 47 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 48 | (x1 - x0), (y1 - y0)] 49 | return torch.stack(b, dim=-1) 50 | 51 | 52 | # modified from torchvision to also return the union 53 | def box_iou(boxes1, boxes2): 54 | area1 = box_area(boxes1) 55 | area2 = box_area(boxes2) 56 | 57 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 58 | rb
= torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 59 | 60 | wh = (rb - lt).clamp(min=0) # [N,M,2] 61 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 62 | 63 | union = area1[:, None] + area2 - inter 64 | 65 | iou = inter / union 66 | return iou, union 67 | 68 | 69 | def generalized_box_iou(boxes1, boxes2): 70 | """ 71 | Generalized IoU from https://giou.stanford.edu/ 72 | The boxes should be in [x0, y0, x1, y1] format 73 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 74 | and M = len(boxes2) 75 | """ 76 | # degenerate boxes gives inf / nan results 77 | # so do an early check 78 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 79 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 80 | iou, union = box_iou(boxes1, boxes2) 81 | 82 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 83 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 84 | 85 | wh = (rb - lt).clamp(min=0) # [N,M,2] 86 | area = wh[:, :, 0] * wh[:, :, 1] 87 | 88 | return iou - (area - union) / area 89 | 90 | 91 | def masks_to_boxes(masks): 92 | """Compute the bounding boxes around the provided masks 93 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 94 | Returns a [N, 4] tensors, with the boxes in xyxy format 95 | """ 96 | if masks.numel() == 0: 97 | return torch.zeros((0, 4), device=masks.device) 98 | 99 | h, w = masks.shape[-2:] 100 | 101 | y = torch.arange(0, h, dtype=torch.float) 102 | x = torch.arange(0, w, dtype=torch.float) 103 | y, x = torch.meshgrid(y, x) 104 | 105 | x_mask = (masks * x.unsqueeze(0)) 106 | x_max = x_mask.flatten(1).max(-1)[0] 107 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 108 | 109 | y_mask = (masks * y.unsqueeze(0)) 110 | y_max = y_mask.flatten(1).max(-1)[0] 111 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 112 | 113 | return torch.stack([x_min, y_min, x_max, y_max], 1) 114 | 115 | def inverse_sigmoid(x, eps=1e-5): 116 | x = x.clamp(min=0, max=1) 117 | x1 = x.clamp(min=eps) 118 | x2 = (1 - x).clamp(min=eps) 119 | return torch.log(x1/x2) 120 | 121 | def sigmoid_offset(x, offset=True): 122 | # modified sigmoid for range [-0.5, 1.5] 123 | if offset: 124 | return x.sigmoid() * 2 - 0.5 125 | else: 126 | return x.sigmoid() 127 | 128 | def inverse_sigmoid_offset(x, eps=1e-5, offset=True): 129 | if offset: 130 | x = (x + 0.5) / 2.0 131 | return inverse_sigmoid(x, eps) 132 | 133 | def _max_by_axis(the_list): 134 | # type: (List[List[int]]) -> List[int] 135 | maxes = the_list[0] 136 | for sublist in the_list[1:]: 137 | for index, item in enumerate(sublist): 138 | maxes[index] = max(maxes[index], item) 139 | return maxes 140 | 141 | 142 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 143 | # make this more general 144 | if tensor_list[0].ndim == 3: 145 | # make it support different-sized images 146 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 147 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 148 | batch_shape = [len(tensor_list)] + max_size 149 | b, c, h, w = batch_shape 150 | dtype = tensor_list[0].dtype 151 | device = tensor_list[0].device 152 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 153 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 154 | for img, pad_img, m in zip(tensor_list, tensor, mask): 155 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 156 | m[: img.shape[1], :img.shape[2]] = False 157 | else: 158 | raise ValueError('not supported') 159 | 
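# mask is True over padded pixels and False over each image's valid region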
return NestedTensor(tensor, mask) 160 | 161 | 162 | class NestedTensor(object): 163 | def __init__(self, tensors, mask: Optional[Tensor]): 164 | self.tensors = tensors 165 | self.mask = mask 166 | 167 | def to(self, device): 168 | # type: (Device) -> NestedTensor # noqa 169 | cast_tensor = self.tensors.to(device) 170 | mask = self.mask 171 | if mask is not None: 172 | assert mask is not None 173 | cast_mask = mask.to(device) 174 | else: 175 | cast_mask = None 176 | return NestedTensor(cast_tensor, cast_mask) 177 | 178 | def decompose(self): 179 | return self.tensors, self.mask 180 | 181 | def __repr__(self): 182 | return str(self.tensors) 183 | -------------------------------------------------------------------------------- /adet/checkpoint/adet_checkpoint.py: -------------------------------------------------------------------------------- 1 | import pickle, os 2 | from fvcore.common.file_io import PathManager 3 | from detectron2.checkpoint import DetectionCheckpointer 4 | from detectron2.checkpoint.c2_model_loading import align_and_update_state_dicts 5 | from fvcore.common.checkpoint import _strip_prefix_if_present, _IncompatibleKeys 6 | 7 | 8 | class AdetCheckpointer(DetectionCheckpointer): 9 | """ 10 | Same as :class:`DetectionCheckpointer`, but is able to convert models 11 | in AdelaiDet, such as LPF backbone. 12 | """ 13 | def _load_file(self, filename): 14 | if filename.endswith(".pkl"): 15 | with PathManager.open(filename, "rb") as f: 16 | data = pickle.load(f, encoding="latin1") 17 | if "model" in data and "__author__" in data: 18 | # file is in Detectron2 model zoo format 19 | self.logger.info("Reading a file from '{}'".format(data["__author__"])) 20 | return data 21 | else: 22 | # assume file is from Caffe2 / Detectron1 model zoo 23 | if "blobs" in data: 24 | # Detection models have "blobs", but ImageNet models don't 25 | data = data["blobs"] 26 | data = {k: v for k, v in data.items() if not k.endswith("_momentum")} 27 | if "weight_order" in data: 28 | del data["weight_order"] 29 | return {"model": data, "__author__": "Caffe2", "matching_heuristics": True} 30 | 31 | loaded = super()._load_file(filename) # load native pth checkpoint 32 | if "model" not in loaded: 33 | loaded = {"model": loaded} 34 | 35 | basename = os.path.basename(filename).lower() 36 | if "lpf" in basename or "dla" in basename: 37 | loaded["matching_heuristics"] = True 38 | return loaded 39 | 40 | 41 | 42 | class AdetTSCheckpointer(AdetCheckpointer): 43 | """ 44 | Same as :class:`AdetCheckpointer`, 45 | but can load either the whole teacher-student model or only the student sub-model 46 | """ 47 | def _load_model(self, checkpoint): 48 | if checkpoint.get("student", None) == True: 49 | # pretrained model weight: only update student model 50 | if checkpoint.get("matching_heuristics", False): 51 | self._convert_ndarray_to_tensor(checkpoint["model"]) 52 | # convert weights by name-matching heuristics 53 | checkpoint["model"] = align_and_update_state_dicts( 54 | self.model.student.state_dict(), 55 | checkpoint["model"], 56 | c2_conversion=checkpoint.get("__author__", None) == "Caffe2", 57 | ) 58 | 59 | # for non-caffe2 models, use standard ways to load it 60 | incompatible = self._load_student_model(checkpoint) 61 | 62 | model_buffers = dict(self.model.student.named_buffers(recurse=False)) 63 | for k in ["pixel_mean", "pixel_std"]: 64 | # Ignore missing key message about pixel_mean/std. 65 | # Though they may be missing in old checkpoints, they will be correctly 66 | # initialized from config anyway.
67 | if k in model_buffers: 68 | try: 69 | incompatible.missing_keys.remove(k) 70 | except ValueError: 71 | pass 72 | for k in incompatible.unexpected_keys[:]: 73 | # Ignore unexpected keys about cell anchors. They exist in old checkpoints 74 | # but now they are non-persistent buffers and will not be in new checkpoints. 75 | if "anchor_generator.cell_anchors" in k: 76 | incompatible.unexpected_keys.remove(k) 77 | return incompatible 78 | 79 | 80 | 81 | else: # whole model 82 | if checkpoint.get("matching_heuristics", False): 83 | self._convert_ndarray_to_tensor(checkpoint["model"]) 84 | # convert weights by name-matching heuristics 85 | checkpoint["model"] = align_and_update_state_dicts( 86 | self.model.state_dict(), 87 | checkpoint["model"], 88 | c2_conversion=checkpoint.get("__author__", None) == "Caffe2", 89 | ) 90 | # for non-caffe2 models, use standard ways to load it 91 | incompatible = super()._load_model(checkpoint) 92 | 93 | model_buffers = dict(self.model.named_buffers(recurse=False)) 94 | for k in ["pixel_mean", "pixel_std"]: 95 | # Ignore missing key message about pixel_mean/std. 96 | # Though they may be missing in old checkpoints, they will be correctly 97 | # initialized from config anyway. 98 | if k in model_buffers: 99 | try: 100 | incompatible.missing_keys.remove(k) 101 | except ValueError: 102 | pass 103 | for k in incompatible.unexpected_keys[:]: 104 | # Ignore unexpected keys about cell anchors. They exist in old checkpoints 105 | # but now they are non-persistent buffers and will not be in new checkpoints. 106 | if "anchor_generator.cell_anchors" in k: 107 | incompatible.unexpected_keys.remove(k) 108 | return incompatible 109 | 110 | def _load_student_model(self, checkpoint) -> _IncompatibleKeys: # pyre-ignore 111 | checkpoint_state_dict = checkpoint.pop("model") 112 | self._convert_ndarray_to_tensor(checkpoint_state_dict) 113 | 114 | # if the state_dict comes from a model that was wrapped in a 115 | # DataParallel or DistributedDataParallel during serialization, 116 | # remove the "module" prefix before performing the matching.
117 | _strip_prefix_if_present(checkpoint_state_dict, "module.") 118 | 119 | # work around https://github.com/pytorch/pytorch/issues/24139 120 | model_state_dict = self.model.student.state_dict() 121 | incorrect_shapes = [] 122 | for k in list(checkpoint_state_dict.keys()): 123 | if k in model_state_dict: 124 | shape_model = tuple(model_state_dict[k].shape) 125 | shape_checkpoint = tuple(checkpoint_state_dict[k].shape) 126 | if shape_model != shape_checkpoint: 127 | incorrect_shapes.append((k, shape_checkpoint, shape_model)) 128 | checkpoint_state_dict.pop(k) 129 | # pyre-ignore 130 | incompatible = self.model.student.load_state_dict( 131 | checkpoint_state_dict, strict=False 132 | ) 133 | return _IncompatibleKeys( 134 | missing_keys=incompatible.missing_keys, 135 | unexpected_keys=incompatible.unexpected_keys, 136 | incorrect_shapes=incorrect_shapes, 137 | ) -------------------------------------------------------------------------------- /datasets/ic15/train_37voc_0.5_labeled.json: -------------------------------------------------------------------------------- 1 | {"images": [{"coco_url": "", "date_captured": "", "file_name": "img_988.jpg", "flickr_url": "", "id": 988, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_196.jpg", "flickr_url": "", "id": 196, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_833.jpg", "flickr_url": "", "id": 833, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_790.jpg", "flickr_url": "", "id": 790, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_836.jpg", "flickr_url": "", "id": 836, "license": 0, "width": 1280, "height": 720}], "annotations": [{"area": 6230.0, "bbox": [815.0, 71.0, 70.0, 89.0], "category_id": 1, "id": 379, "image_id": 988, "iscrowd": 0, "bezier_pts": [815, 122, 836, 105, 858, 88, 880, 71, 884, 115, 863, 129, 843, 144, 823, 159], "rec": [31, 24, 4, 0, 17, 18, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 10465.0, "bbox": [874.0, 2.0, 91.0, 115.0], "category_id": 1, "id": 380, "image_id": 988, "iscrowd": 0, "bezier_pts": [874, 62, 901, 42, 929, 22, 957, 2, 964, 57, 937, 76, 910, 96, 883, 116], "rec": [17, 0, 12, 4, 13, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 12852.0, "bbox": [1095.0, 84.0, 108.0, 119.0], "category_id": 1, "id": 381, "image_id": 988, "iscrowd": 0, "bezier_pts": [1095, 120, 1128, 108, 1161, 96, 1194, 84, 1202, 169, 1167, 180, 1132, 191, 1097, 202], "rec": [19, 14, 13, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1240.0, "bbox": [614.0, 98.0, 62.0, 20.0], "category_id": 1, "id": 4267, "image_id": 196, "iscrowd": 0, "bezier_pts": [615, 98, 635, 99, 655, 100, 675, 102, 674, 117, 654, 116, 634, 115, 614, 114], "rec": [3, 14, 14, 17, 18, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 2222.0, "bbox": [101.0, 267.0, 101.0, 22.0], "category_id": 1, "id": 4268, "image_id": 196, "iscrowd": 0, "bezier_pts": [101, 267, 134, 267, 167, 267, 200, 267, 201, 288, 168, 288, 135, 288, 102, 288], "rec": [17, 4, 18, 4, 17, 21, 4, 3, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1120.0, "bbox": [462.0, 93.0, 56.0, 20.0], "category_id": 1, "id": 4269, "image_id": 196, "iscrowd": 0, "bezier_pts": [462, 93, 480, 93, 498,
94, 516, 95, 517, 112, 499, 111, 481, 110, 463, 110], "rec": [10, 4, 4, 15, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1064.0, "bbox": [522.0, 96.0, 56.0, 19.0], "category_id": 1, "id": 4270, "image_id": 196, "iscrowd": 0, "bezier_pts": [523, 96, 541, 96, 559, 96, 577, 97, 576, 114, 558, 113, 540, 113, 522, 113], "rec": [2, 11, 4, 0, 17, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1638.0, "bbox": [198.0, 268.0, 78.0, 21.0], "category_id": 1, "id": 4271, "image_id": 196, "iscrowd": 0, "bezier_pts": [198, 270, 223, 269, 248, 268, 274, 268, 275, 287, 250, 287, 225, 287, 200, 288], "rec": [18, 4, 0, 19, 8, 13, 6, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1340.0, "bbox": [897.0, 104.0, 67.0, 20.0], "category_id": 1, "id": 4272, "image_id": 196, "iscrowd": 0, "bezier_pts": [901, 104, 921, 104, 942, 105, 963, 106, 960, 123, 939, 122, 918, 121, 897, 121], "rec": [9, 0, 13, 6, 0, 13, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1026.0, "bbox": [958.0, 108.0, 57.0, 18.0], "category_id": 1, "id": 4273, "image_id": 196, "iscrowd": 0, "bezier_pts": [961, 108, 978, 109, 996, 110, 1014, 111, 1012, 125, 994, 124, 976, 123, 958, 123], "rec": [1, 4, 17, 3, 8, 17, 8, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 736.0, "bbox": [1010.0, 110.0, 46.0, 16.0], "category_id": 1, "id": 4274, "image_id": 196, "iscrowd": 0, "bezier_pts": [1011, 110, 1025, 110, 1040, 111, 1055, 112, 1054, 125, 1039, 124, 1024, 123, 1010, 123], "rec": [3, 4, 10, 0, 19, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 731.0, "bbox": [1056.0, 112.0, 43.0, 17.0], "category_id": 1, "id": 4275, "image_id": 196, "iscrowd": 0, "bezier_pts": [1057, 112, 1070, 112, 1084, 113, 1098, 114, 1097, 128, 1083, 127, 1069, 126, 1056, 126], "rec": [15, 8, 13, 19, 20, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1504.0, "bbox": [590.0, 103.0, 47.0, 32.0], "category_id": 1, "id": 800, "image_id": 833, "iscrowd": 0, "bezier_pts": [592, 103, 606, 105, 621, 107, 636, 110, 633, 134, 618, 131, 604, 129, 590, 127], "rec": [4, 23, 8, 19, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1431.0, "bbox": [537.0, 220.0, 53.0, 27.0], "category_id": 1, "id": 801, "image_id": 833, "iscrowd": 0, "bezier_pts": [539, 220, 555, 221, 572, 223, 589, 225, 587, 246, 570, 244, 553, 242, 537, 241], "rec": [0, 11, 8, 2, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1740.0, "bbox": [355.0, 323.0, 87.0, 20.0], "category_id": 1, "id": 802, "image_id": 833, "iscrowd": 0, "bezier_pts": [356, 324, 384, 323, 412, 323, 441, 323, 440, 341, 411, 341, 383, 341, 355, 342], "rec": [1, 11, 4, 13, 7, 4, 8, 12, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1440.0, "bbox": [350.0, 342.0, 72.0, 20.0], "category_id": 1, "id": 803, "image_id": 833, "iscrowd": 0, "bezier_pts": [351, 345, 374, 344, 397, 343, 421, 342, 420, 357, 396, 358, 373, 359, 350, 361], "rec": [0, 21, 4, 13, 20, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 78500.0, "bbox": [281.0, 127.0, 628.0, 125.0], "category_id": 1, "id": 1560, "image_id": 790, "iscrowd": 0, "bezier_pts": [281, 175, 488, 159, 695, 143, 902, 127, 908, 203, 701, 219, 494, 235, 287, 251], 
"rec": [22, 0, 17, 4, 7, 14, 20, 18, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 3159.0, "bbox": [270.0, 444.0, 81.0, 39.0], "category_id": 1, "id": 1561, "image_id": 790, "iscrowd": 0, "bezier_pts": [273, 444, 298, 445, 324, 446, 350, 448, 347, 482, 321, 480, 295, 478, 270, 477], "rec": [18, 0, 11, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 902.0, "bbox": [138.0, 367.0, 41.0, 22.0], "category_id": 1, "id": 1562, "image_id": 790, "iscrowd": 0, "bezier_pts": [140, 368, 152, 367, 165, 367, 178, 367, 176, 387, 163, 387, 150, 387, 138, 388], "rec": [18, 0, 11, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 4588.0, "bbox": [491.0, 197.0, 148.0, 31.0], "category_id": 1, "id": 4322, "image_id": 836, "iscrowd": 0, "bezier_pts": [491, 197, 539, 197, 588, 197, 637, 198, 638, 227, 589, 226, 540, 226, 492, 226], "rec": [18, 8, 13, 6, 0, 15, 14, 17, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 3570.0, "bbox": [503.0, 226.0, 119.0, 30.0], "category_id": 1, "id": 4323, "image_id": 836, "iscrowd": 0, "bezier_pts": [507, 226, 545, 226, 583, 227, 621, 228, 618, 255, 579, 254, 541, 253, 503, 253], "rec": [0, 8, 17, 11, 8, 13, 4, 18, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}], "categories": [{"id": 1, "name": "text", "supercategory": "beverage", "keypoints": ["mean", "xmin", "x2", "x3", "xmax", "ymin", "y2", "y3", "ymax", "cross"]}]} -------------------------------------------------------------------------------- /adet/data/detection_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | from detectron2.structures import Instances 6 | from detectron2.data import transforms as T 7 | from detectron2.data.detection_utils import \ 8 | annotations_to_instances as d2_anno_to_inst 9 | from detectron2.data.detection_utils import \ 10 | transform_instance_annotations as d2_transform_inst_anno 11 | from .augmentation import Pad, ResizeShortestEdgeWithRecord 12 | from .augmentation import OneOf, RandomSharpness, RandomEqualize 13 | import random 14 | 15 | 16 | def transform_instance_annotations( 17 | annotation, transforms, image_size, *, keypoint_hflip_indices=None 18 | ): 19 | 20 | annotation = d2_transform_inst_anno( 21 | annotation, 22 | transforms, 23 | image_size, 24 | keypoint_hflip_indices=keypoint_hflip_indices, 25 | ) 26 | 27 | if "beziers" in annotation: 28 | beziers = transform_ctrl_pnts_annotations(annotation["beziers"], transforms) 29 | annotation["beziers"] = beziers 30 | 31 | if "polyline" in annotation: 32 | polys = transform_ctrl_pnts_annotations(annotation["polyline"], transforms) 33 | annotation["polyline"] = polys 34 | 35 | if "boundary" in annotation: 36 | boundary = transform_ctrl_pnts_annotations(annotation["boundary"], transforms) 37 | annotation["boundary"] = boundary 38 | 39 | return annotation 40 | 41 | 42 | def transform_ctrl_pnts_annotations(pnts, transforms): 43 | """ 44 | Transform keypoint annotations of an image. 45 | 46 | Args: 47 | beziers (list[float]): Nx16 float in Detectron2 Dataset format. 
48 |         transforms (TransformList): 49 |     """ 50 |     # (N*2,) -> (N, 2) 51 |     pnts = np.asarray(pnts, dtype="float64").reshape(-1, 2) 52 |     pnts = transforms.apply_coords(pnts).reshape(-1) 53 | 54 |     # This assumes that HorizFlipTransform is the only one that does flip 55 |     do_hflip = ( 56 |         sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1 57 |     ) 58 |     if do_hflip: 59 |         raise ValueError("Flipping text data is not supported (and also discouraged).") 60 | 61 |     return pnts 62 | 63 | 64 | def annotations_to_instances(annos, image_size, mask_format="polygon"): 65 |     """for line only annotations""" 66 |     # instance = Instances(image_size) 67 |     # 68 |     # classes = [int(obj["category_id"]) for obj in annos] 69 |     # classes = torch.tensor(classes, dtype=torch.int64) 70 |     # instance.gt_classes = classes 71 | 72 |     instance = d2_anno_to_inst(annos, image_size, mask_format) 73 | 74 |     if not annos: 75 |         return instance 76 | 77 |     # add attributes 78 |     if "beziers" in annos[0]: 79 |         beziers = [obj.get("beziers", []) for obj in annos] 80 |         instance.beziers = torch.as_tensor(np.array(beziers), dtype=torch.float32) 81 | 82 |     if "polyline" in annos[0]: 83 |         polys = [obj.get("polyline", []) for obj in annos] 84 |         instance.polyline = torch.as_tensor(np.array(polys), dtype=torch.float32) 85 | 86 |     if "boundary" in annos[0]: 87 |         boundary = [obj.get("boundary", []) for obj in annos] 88 |         instance.boundary = torch.as_tensor(np.array(boundary), dtype=torch.float32) 89 | 90 |     if "text" in annos[0]: 91 |         texts = [obj.get("text", []) for obj in annos] 92 |         instance.texts = torch.as_tensor(np.array(texts), dtype=torch.int32) 93 | 94 |     return instance 95 | 96 | 97 | def build_augmentation(cfg, is_train): 98 |     """ 99 |     With an option to disable hflip. 100 | 101 |     Returns: 102 |         list[Augmentation] 103 |     """ 104 |     if is_train: 105 |         min_size = cfg.INPUT.MIN_SIZE_TRAIN 106 |         max_size = cfg.INPUT.MAX_SIZE_TRAIN 107 |         sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 108 |     else: 109 |         min_size = cfg.INPUT.MIN_SIZE_TEST 110 |         max_size = cfg.INPUT.MAX_SIZE_TEST 111 |         sample_style = "choice" 112 |     if sample_style == "range": 113 |         assert ( 114 |             len(min_size) == 2 115 |         ), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) 116 | 117 |     logger = logging.getLogger(__name__) 118 | 119 |     augmentation = [] 120 |     augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) 121 | 122 |     if is_train: 123 |         augmentation.append(T.RandomContrast(0.3, 1.7)) 124 |         augmentation.append(T.RandomBrightness(0.3, 1.7)) 125 |         augmentation.append(T.RandomLighting(random.random() + 0.5)) 126 |         augmentation.append(T.RandomSaturation(0.3, 1.7)) 127 |         logger.info("Augmentations used in training: " + str(augmentation)) 128 |     if cfg.MODEL.BACKBONE.NAME == "build_vitaev2_backbone": 129 |         augmentation.append(Pad(divisible_size=32)) 130 |     return augmentation 131 | 132 | 133 | def build_augmentation_strong(cfg, is_train): 134 |     if is_train: 135 |         min_size = cfg.INPUT.MIN_SIZE_TRAIN 136 |         max_size = cfg.INPUT.MAX_SIZE_TRAIN 137 |         sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 138 |     else: 139 |         min_size = cfg.INPUT.MIN_SIZE_TEST 140 |         max_size = cfg.INPUT.MAX_SIZE_TEST 141 |         sample_style = "choice" 142 |     if sample_style == "range": 143 |         assert ( 144 |             len(min_size) == 2 145 |         ), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) 146 | 147 |     logger = logging.getLogger(__name__) 148 | 149 |     augmentation = [] 150 |     augmentation.append(ResizeShortestEdgeWithRecord(min_size, max_size,
sample_style)) 151 | 152 |     if is_train: 153 |         augmentation.append(T.RandomContrast(0.3, 1.7)) 154 |         augmentation.append(T.RandomBrightness(0.3, 1.7)) 155 |         augmentation.append(T.RandomLighting(random.random() + 0.5)) 156 |         augmentation.append(T.RandomSaturation(0.3, 1.7)) 157 |         augmentation.append(OneOf([RandomSharpness(0.5, 1.5), RandomEqualize(p=0.3)])) 158 |         # TODO: RandomSharpness, Random 159 |         logger.info("Strong augmentations used in training for student: " + str(augmentation)) 160 |     if cfg.MODEL.BACKBONE.NAME == "build_vitaev2_backbone": 161 |         augmentation.append(Pad(divisible_size=32)) 162 |     return augmentation 163 | 164 | 165 | def build_augmentation_weak(cfg, is_train): 166 |     if is_train: 167 |         min_size = cfg.INPUT.MIN_SIZE_TRAIN 168 |         max_size = cfg.INPUT.MAX_SIZE_TRAIN 169 |         sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 170 |     else: 171 |         min_size = cfg.INPUT.MIN_SIZE_TEST 172 |         max_size = cfg.INPUT.MAX_SIZE_TEST 173 |         sample_style = "choice" 174 |     if sample_style == "range": 175 |         assert ( 176 |             len(min_size) == 2 177 |         ), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) 178 | 179 |     logger = logging.getLogger(__name__) 180 | 181 |     augmentation = [] 182 |     augmentation.append(ResizeShortestEdgeWithRecord(min_size, max_size, sample_style)) 183 | 184 |     if is_train: 185 |         logger.info("Weak augmentations used in training for teacher: " + str(augmentation)) 186 |     if cfg.MODEL.BACKBONE.NAME == "build_vitaev2_backbone": 187 |         augmentation.append(Pad(divisible_size=32)) 188 |     return augmentation 189 | 190 | 191 | build_transform_gen = build_augmentation 192 | """ 193 | Alias for backward-compatibility. 194 | """ 195 | -------------------------------------------------------------------------------- /adet/layers/csrc/DeformAttn/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | #include "ms_deform_im2col_cuda.cuh" 13 | 14 | #include <ATen/ATen.h> 15 | #include <ATen/cuda/CUDAContext.h> 16 | #include <cuda.h> 17 | #include <cuda_runtime.h> 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 |     const at::Tensor &value, 22 |     const at::Tensor &spatial_shapes, 23 |     const at::Tensor &level_start_index, 24 |     const at::Tensor &sampling_loc, 25 |     const at::Tensor &attn_weight, 26 |     const int im2col_step) 27 | { 28 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 |     const int batch = value.size(0); 41 |     const int spatial_size = value.size(1); 42 |     const int num_heads = value.size(2); 43 |     const int channels = value.size(3); 44 | 45 |     const int num_levels = spatial_shapes.size(0); 46 | 47 |     const int num_query = sampling_loc.size(1); 48 |     const int num_point = sampling_loc.size(4); 49 | 50 |     const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 |     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 |     const int batch_n = im2col_step_; 57 |     auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 |     auto per_value_size = spatial_size * num_heads * channels; 59 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 |     for (int n = 0; n < batch/im2col_step_; ++n) 62 |     { 63 |         auto columns = output_n.select(0, n); 64 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 |             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size, 67 |                 spatial_shapes.data<int64_t>(), 68 |                 level_start_index.data<int64_t>(), 69 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 70 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size, 71 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 |                 columns.data<scalar_t>()); 73 | 74 |         })); 75 |     } 76 | 77 |     output = output.view({batch, num_query, num_heads*channels}); 78 | 79 |     return output; 80 | } 81 | 82 | 83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 84 |     const at::Tensor &value, 85 |     const at::Tensor &spatial_shapes, 86 |     const at::Tensor &level_start_index, 87 |     const
at::Tensor &sampling_loc, 88 |     const at::Tensor &attn_weight, 89 |     const at::Tensor &grad_output, 90 |     const int im2col_step) 91 | { 92 | 93 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 |     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 |     AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 |     const int batch = value.size(0); 108 |     const int spatial_size = value.size(1); 109 |     const int num_heads = value.size(2); 110 |     const int channels = value.size(3); 111 | 112 |     const int num_levels = spatial_shapes.size(0); 113 | 114 |     const int num_query = sampling_loc.size(1); 115 |     const int num_point = sampling_loc.size(4); 116 | 117 |     const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 |     auto grad_value = at::zeros_like(value); 122 |     auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 |     auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 |     const int batch_n = im2col_step_; 126 |     auto per_value_size = spatial_size * num_heads * channels; 127 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 |     auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 |     for (int n = 0; n < batch/im2col_step_; ++n) 132 |     { 133 |         auto grad_output_g = grad_output_n.select(0, n); 134 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 |             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 |                 grad_output_g.data<scalar_t>(), 137 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size, 138 |                 spatial_shapes.data<int64_t>(), 139 |                 level_start_index.data<int64_t>(), 140 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 141 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size, 142 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 |                 grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size, 144 |                 grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 145 |                 grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size); 146 | 147 |         })); 148 |     } 149 | 150 |     return { 151 |         grad_value, grad_sampling_loc, grad_attn_weight 152 |     }; 153 | } -------------------------------------------------------------------------------- /adet/utils/curve_utils.py: -------------------------------------------------------------------------------- 1 | # borrow from https://github.com/voldemortX/pytorch-auto-drive/blob/master/utils/curve_utils.py 2 | 3 | import
torch 4 | import numpy as np 5 | from scipy.interpolate import splprep, splev 6 | from scipy.special import comb as n_over_k 7 | 8 | 9 | def upcast(t): 10 | # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type 11 | # https://github.com/pytorch/vision/pull/3383 12 | if t.is_floating_point(): 13 | return t if t.dtype in (torch.float32, torch.float64) else t.float() 14 | else: 15 | return t if t.dtype in (torch.int32, torch.int64) else t.int() 16 | 17 | 18 | class BezierCurve(object): 19 | # Define Bezier curves for curve fitting 20 | def __init__(self, order, num_sample_points=50): 21 | self.num_point = order + 1 22 | self.control_points = [] 23 | self.bezier_coeff = self.get_bezier_coefficient() 24 | self.num_sample_points = num_sample_points 25 | self.c_matrix = self.get_bernstein_matrix() 26 | 27 | def get_bezier_coefficient(self): 28 | Mtk = lambda n, t, k: t ** k * (1 - t) ** (n - k) * n_over_k(n, k) 29 | BezierCoeff = lambda ts: [[Mtk(self.num_point - 1, t, k) for k in range(self.num_point)] for t in ts] 30 | 31 | return BezierCoeff 32 | 33 | def interpolate_lane(self, x, y, n=50): 34 | # Spline interpolation of a lane. Used on the predictions 35 | assert len(x) == len(y) 36 | 37 | tck, _ = splprep([x, y], s=0, t=n, k=min(3, len(x) - 1)) 38 | 39 | u = np.linspace(0., 1., n) 40 | return np.array(splev(u, tck)).T 41 | 42 | def get_control_points(self, x, y, interpolate=False): 43 | if interpolate: 44 | points = self.interpolate_lane(x, y) 45 | x = np.array([x for x, _ in points]) 46 | y = np.array([y for _, y in points]) 47 | 48 | middle_points = self.get_middle_control_points(x, y) 49 | for idx in range(0, len(middle_points) - 1, 2): 50 | self.control_points.append([middle_points[idx], middle_points[idx + 1]]) 51 | 52 | def get_bernstein_matrix(self): 53 | tokens = np.linspace(0, 1, self.num_sample_points) 54 | c_matrix = self.bezier_coeff(tokens) 55 | return np.array(c_matrix) 56 | 57 | def save_control_points(self): 58 | return self.control_points 59 | 60 | def assign_control_points(self, control_points): 61 | self.control_points = control_points 62 | 63 | def quick_sample_point(self, image_size=None): 64 | control_points_matrix = np.array(self.control_points) 65 | sample_points = self.c_matrix.dot(control_points_matrix) 66 | if image_size is not None: 67 | sample_points[:, 0] = sample_points[:, 0] * image_size[-1] 68 | sample_points[:, -1] = sample_points[:, -1] * image_size[0] 69 | return sample_points 70 | 71 | def get_sample_point(self, n=50, image_size=None): 72 | ''' 73 | :param n: the number of sampled points 74 | :return: a list of sampled points 75 | ''' 76 | t = np.linspace(0, 1, n) 77 | coeff_matrix = np.array(self.bezier_coeff(t)) 78 | control_points_matrix = np.array(self.control_points) 79 | sample_points = coeff_matrix.dot(control_points_matrix) 80 | if image_size is not None: 81 | sample_points[:, 0] = sample_points[:, 0] * image_size[-1] 82 | sample_points[:, -1] = sample_points[:, -1] * image_size[0] 83 | 84 | return sample_points 85 | 86 | def get_middle_control_points(self, x, y): 87 | dy = y[1:] - y[:-1] 88 | dx = x[1:] - x[:-1] 89 | dt = (dx ** 2 + dy ** 2) ** 0.5 90 | t = dt / dt.sum() 91 | t = np.hstack(([0], t)) 92 | t = t.cumsum() 93 | data = np.column_stack((x, y)) 94 | Pseudoinverse = np.linalg.pinv(self.bezier_coeff(t)) # (9,4) -> (4,9) 95 | control_points = Pseudoinverse.dot(data) # (4,9)*(9,2) -> (4,2) 96 | medi_ctp = control_points[:, :].flatten().tolist() 97 | 98 | return medi_ctp 99 | 100 | 101 | class 
BezierSampler(torch.nn.Module): 102 |     # Fast Batch Bezier sampler 103 |     def __init__(self, num_sample_points): 104 |         super().__init__() 105 |         self.num_control_points = 4 106 |         self.num_sample_points = num_sample_points 107 |         self.control_points = [] 108 |         self.bezier_coeff = self.get_bezier_coefficient() 109 |         self.bernstein_matrix = self.get_bernstein_matrix() 110 | 111 |     def get_bezier_coefficient(self): 112 |         Mtk = lambda n, t, k: t ** k * (1 - t) ** (n - k) * n_over_k(n, k) 113 |         BezierCoeff = lambda ts: [[Mtk(3, t, k) for k in range(4)] for t in ts] 114 |         return BezierCoeff 115 | 116 |     def get_bernstein_matrix(self): 117 |         t = torch.linspace(0, 1, self.num_sample_points) 118 |         c_matrix = torch.tensor(self.bezier_coeff(t)) 119 |         return c_matrix  # (num_sample_points, 4) 120 | 121 |     def get_sample_points(self, control_points_matrix): 122 |         if control_points_matrix.numel() == 0: 123 |             return control_points_matrix  # Looks better than a torch.Tensor 124 |         if self.bernstein_matrix.device != control_points_matrix.device: 125 |             self.bernstein_matrix = self.bernstein_matrix.to(control_points_matrix.device) 126 | 127 |         return upcast(self.bernstein_matrix).matmul(upcast(control_points_matrix)) 128 | 129 | 130 | @torch.no_grad() 131 | def get_valid_points(points): 132 |     # ... x 2 133 |     if points.numel() == 0: 134 |         return torch.tensor([1], dtype=torch.bool, device=points.device) 135 |     return (points[..., 0] > 0) * (points[..., 0] < 1) * (points[..., 1] > 0) * (points[..., 1] < 1) 136 | 137 | 138 | @torch.no_grad() 139 | def cubic_bezier_curve_segment(control_points, sample_points): 140 |     # Cut a batch of cubic bezier curves to their in-image segments (assume at least 2 valid sample points per curve). 141 |     # Based on De Casteljau's algorithm, the formula for the cubic bezier curve is derived here: 142 |     # https://stackoverflow.com/a/11704152/15449902 143 |     # control_points: B x 4 x 2 144 |     # sample_points: B x N x 2 145 |     if control_points.numel() == 0 or sample_points.numel() == 0: 146 |         return control_points 147 |     B, N = sample_points.shape[:-1] 148 |     valid_points = get_valid_points(sample_points)  # B x N, bool 149 |     t = torch.linspace(0.0, 1.0, steps=N, dtype=sample_points.dtype, device=sample_points.device) 150 | 151 |     # First & Last valid index (B) 152 |     # Get unique values for deterministic behaviour on cuda: 153 |     # https://pytorch.org/docs/1.6.0/generated/torch.max.html?highlight=max#torch.max 154 |     t0 = t[(valid_points + torch.arange(N, device=valid_points.device).flip([0]) * valid_points).max(dim=-1).indices] 155 |     t1 = t[(valid_points + torch.arange(N, device=valid_points.device) * valid_points).max(dim=-1).indices] 156 | 157 |     # Generate transform matrix (old control points -> new control points = linear transform) 158 |     u0 = 1 - t0  # B 159 |     u1 = 1 - t1  # B 160 |     transform_matrix_c = [torch.stack([u0 ** (3 - i) * u1 ** i for i in range(4)], dim=-1), 161 |                           torch.stack([3 * t0 * u0 ** 2, 162 |                                        2 * t0 * u0 * u1 + u0 ** 2 * t1, 163 |                                        t0 * u1 ** 2 + 2 * u0 * u1 * t1, 164 |                                        3 * t1 * u1 ** 2], dim=-1), 165 |                           torch.stack([3 * t0 ** 2 * u0, 166 |                                        t0 ** 2 * u1 + 2 * t0 * t1 * u0, 167 |                                        2 * t0 * t1 * u1 + t1 ** 2 * u0, 168 |                                        3 * t1 ** 2 * u1], dim=-1), 169 |                           torch.stack([t0 ** (3 - i) * t1 ** i for i in range(4)], dim=-1)] 170 |     transform_matrix = torch.stack(transform_matrix_c, dim=-2).transpose(-2, -1)  # B x 4 x 4, f**k this!
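# The 4x4 transform encodes the linear re-parameterization t -> t0 + t * (t1 - t0):
# it maps the original control points to those of the sub-curve on [t0, t1] (see the
# De Casteljau derivation linked above), so trimming the whole batch of curves to
# their in-image segments reduces to the single batched matmul below.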
171 | transform_matrix = transform_matrix.unsqueeze(1).expand(B, 2, 4, 4) 172 | 173 | # Matrix multiplication 174 | res = transform_matrix.matmul(control_points.permute(0, 2, 1).unsqueeze(-1)) # B x 2 x 4 x 1 175 | 176 | return res.squeeze(-1).permute(0, 2, 1) 177 | 178 | -------------------------------------------------------------------------------- /adet/config/semi_defaults.py: -------------------------------------------------------------------------------- 1 | from detectron2.config.defaults import _C 2 | from detectron2.config import CfgNode as CN 3 | 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # Additional Configs 7 | # ---------------------------------------------------------------------------- # 8 | _C.MODEL.MOBILENET = False 9 | _C.MODEL.BACKBONE.ANTI_ALIAS = False 10 | _C.MODEL.RESNETS.DEFORM_INTERVAL = 1 11 | _C.INPUT.HFLIP_TRAIN = False 12 | _C.INPUT.CROP.CROP_INSTANCE = True 13 | _C.INPUT.ROTATE = True 14 | 15 | _C.MODEL.BASIS_MODULE = CN() 16 | _C.MODEL.BASIS_MODULE.NAME = "ProtoNet" 17 | _C.MODEL.BASIS_MODULE.NUM_BASES = 4 18 | _C.MODEL.BASIS_MODULE.LOSS_ON = False 19 | _C.MODEL.BASIS_MODULE.ANN_SET = "coco" 20 | _C.MODEL.BASIS_MODULE.CONVS_DIM = 128 21 | _C.MODEL.BASIS_MODULE.IN_FEATURES = ["p3", "p4", "p5"] 22 | _C.MODEL.BASIS_MODULE.NORM = "SyncBN" 23 | _C.MODEL.BASIS_MODULE.NUM_CONVS = 3 24 | _C.MODEL.BASIS_MODULE.COMMON_STRIDE = 8 25 | _C.MODEL.BASIS_MODULE.NUM_CLASSES = 80 26 | _C.MODEL.BASIS_MODULE.LOSS_WEIGHT = 0.3 27 | 28 | _C.MODEL.TOP_MODULE = CN() 29 | _C.MODEL.TOP_MODULE.NAME = "conv" 30 | _C.MODEL.TOP_MODULE.DIM = 16 31 | 32 | 33 | # ---------------------------------------------------------------------------- # 34 | # BAText Options 35 | # ---------------------------------------------------------------------------- # 36 | _C.MODEL.BATEXT = CN() 37 | _C.MODEL.BATEXT.VOC_SIZE = 96 38 | _C.MODEL.BATEXT.NUM_CHARS = 25 39 | _C.MODEL.BATEXT.POOLER_RESOLUTION = (8, 32) 40 | _C.MODEL.BATEXT.IN_FEATURES = ["p2", "p3", "p4"] 41 | _C.MODEL.BATEXT.POOLER_SCALES = (0.25, 0.125, 0.0625) 42 | _C.MODEL.BATEXT.SAMPLING_RATIO = 1 43 | _C.MODEL.BATEXT.CONV_DIM = 256 44 | _C.MODEL.BATEXT.NUM_CONV = 2 45 | _C.MODEL.BATEXT.RECOGNITION_LOSS = "ctc" 46 | _C.MODEL.BATEXT.RECOGNIZER = "attn" 47 | _C.MODEL.BATEXT.CANONICAL_SIZE = 96 # largest min_size for level 3 (stride=8) 48 | _C.MODEL.BATEXT.USE_COORDCONV = False 49 | _C.MODEL.BATEXT.USE_AET = False 50 | _C.MODEL.BATEXT.CUSTOM_DICT = "" # Path to the class file. 
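# A minimal usage sketch (assuming the standard detectron2/yacs config flow): the
# defaults in this file are materialized by adet.config.get_cfg_semi() and then
# overridden by the experiment YAMLs under configs/R_50/, e.g.
#     from adet.config import get_cfg_semi
#     cfg = get_cfg_semi()
#     cfg.merge_from_file("configs/R_50/TotalText/SemiETS/SemiETS_1s.yaml")
#     cfg.MODEL.BATEXT.VOC_SIZE  # -> 96 unless the YAML overrides it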
51 | 52 | 53 | # ---------------------------------------------------------------------------- # 54 | # SwinTransformer Options 55 | # ---------------------------------------------------------------------------- # 56 | _C.MODEL.SWIN = CN() 57 | _C.MODEL.SWIN.TYPE = 'tiny' 58 | _C.MODEL.SWIN.DROP_PATH_RATE = 0.2 59 | 60 | # ---------------------------------------------------------------------------- # 61 | # ViTAE-v2 Options 62 | # ---------------------------------------------------------------------------- # 63 | _C.MODEL.ViTAEv2 = CN() 64 | _C.MODEL.ViTAEv2.TYPE = 'vitaev2_s' 65 | _C.MODEL.ViTAEv2.DROP_PATH_RATE = 0.2 66 | 67 | # ---------------------------------------------------------------------------- # 68 | # (Deformable) Transformer Options 69 | # ---------------------------------------------------------------------------- # 70 | _C.MODEL.TRANSFORMER = CN() 71 | _C.MODEL.TRANSFORMER.ENABLED = False 72 | _C.MODEL.TRANSFORMER.INFERENCE_TH_TEST = 0.4 73 | _C.MODEL.TRANSFORMER.AUX_LOSS = True 74 | _C.MODEL.TRANSFORMER.ENC_LAYERS = 6 75 | _C.MODEL.TRANSFORMER.DEC_LAYERS = 6 76 | _C.MODEL.TRANSFORMER.EMB_LAYERS = 3 77 | _C.MODEL.TRANSFORMER.DIM_FEEDFORWARD = 1024 78 | _C.MODEL.TRANSFORMER.HIDDEN_DIM = 256 79 | _C.MODEL.TRANSFORMER.DROPOUT = 0.0 80 | _C.MODEL.TRANSFORMER.NHEADS = 8 81 | _C.MODEL.TRANSFORMER.NUM_QUERIES = 100 82 | _C.MODEL.TRANSFORMER.ENC_N_POINTS = 4 83 | _C.MODEL.TRANSFORMER.DEC_N_POINTS = 4 84 | _C.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE = 6.283185307179586 # 2 PI 85 | _C.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS = 4 86 | _C.MODEL.TRANSFORMER.VOC_SIZE = 37 # a-z + 0-9 + unknown 87 | _C.MODEL.TRANSFORMER.CUSTOM_DICT = "" # Path to the character class file. 88 | _C.MODEL.TRANSFORMER.NUM_POINTS = 25 # the number of point queries for each instance 89 | _C.MODEL.TRANSFORMER.TEMPERATURE = 10000 90 | _C.MODEL.TRANSFORMER.BOUNDARY_HEAD = True # True: with boundary predictions, False: only with center lines 91 | _C.MODEL.TRANSFORMER.SFEM = False 92 | _C.MODEL.TRANSFORMER.SFEM_DECODER_SA_TYPE = 'sa' 93 | _C.MODEL.TRANSFORMER.SFEM_DECODER_MODULE_SEQ = ['sa', 'ca', 'ffn'] 94 | 95 | _C.MODEL.TRANSFORMER.LOSS = CN() 96 | _C.MODEL.TRANSFORMER.LOSS.AUX_LOSS = True 97 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_ALPHA = 0.25 98 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_GAMMA = 2.0 99 | # bezier proposal loss 100 | _C.MODEL.TRANSFORMER.LOSS.BEZIER_CLASS_WEIGHT = 1.0 101 | _C.MODEL.TRANSFORMER.LOSS.BEZIER_COORD_WEIGHT = 1.0 102 | _C.MODEL.TRANSFORMER.LOSS.BEZIER_SAMPLE_POINTS = 25 103 | # supervise the sampled on-curve points but not 4 Bezier control points 104 | 105 | # target loss 106 | _C.MODEL.TRANSFORMER.LOSS.POINT_CLASS_WEIGHT = 1.0 107 | _C.MODEL.TRANSFORMER.LOSS.POINT_COORD_WEIGHT = 1.0 108 | _C.MODEL.TRANSFORMER.LOSS.POINT_TEXT_WEIGHT = 0.5 109 | _C.MODEL.TRANSFORMER.LOSS.BOUNDARY_WEIGHT = 0.5 110 | _C.MODEL.TRANSFORMER.LOSS.ADP_POINT_COORD_WEIGHT = 1.0 111 | _C.MODEL.TRANSFORMER.LOSS.ADP_POINT_TEXT_WEIGHT = 0.5 112 | _C.MODEL.TRANSFORMER.LOSS.ADP_BOUNDARY_WEIGHT = 0.5 113 | _C.MODEL.TRANSFORMER.LOSS.USE_DYNAMIC_K=False 114 | _C.MODEL.TRANSFORMER.LOSS.DET_ADAPTIVE_TYPE= 'edit_distance' 115 | _C.MODEL.TRANSFORMER.LOSS.REC_ADAPTIVE_TYPE= 'polygon_diou' 116 | _C.MODEL.TRANSFORMER.LOSS.PRECISE_TEACHER = True # CTC compare for precise rec label both in DCE & RHEM 117 | #instance confidence loss 118 | _C.MODEL.TRANSFORMER.LOSS.INSTANCE_CLASS_WEIGHT = 1.0 119 | _C.MODEL.TRANSFORMER.LOSS.INSTANCE_REC_WEIGHT = 1.0 120 | #Hard Sample Mining loss 121 | _C.MODEL.TRANSFORMER.LOSS.LEVEN_ALPHA = 20 122 | 
_C.MODEL.TRANSFORMER.LOSS.COST_ALPHA = 20 123 | # O2M match num 124 | _C.MODEL.TRANSFORMER.LOSS.O2M_MATCH_NUM = 5 125 | _C.MODEL.TRANSFORMER.LOSS.O2M_ENC_MATCH_NUM = 13 126 | _C.MODEL.TRANSFORMER.LOSS.USE_SUP_O2M = False 127 | 128 | # ---------------------------------------------------------------------------- # 129 | # Semi Supervised Learning Options 130 | # ---------------------------------------------------------------------------- # 131 | _C.SSL = CN() 132 | _C.SSL.MODE = "mean-teacher" 133 | _C.SSL.SEMI_WRAPPER = "MultiStreamSpotter" 134 | _C.SSL.PSEUDO_LABEL_INITIAL_SCORE_THR = 0.4 135 | _C.SSL.PSEUDO_LABEL_FINAL_SCORE_THR = 0.7 136 | _C.SSL.PSEUDO_LABEL_SIM_THR = 0.7 137 | _C.SSL.USE_COMBINED_THR = False 138 | _C.SSL.O2M_TEXT_O2O = False 139 | _C.SSL.USE_O2M_ENC = False 140 | _C.SSL.USE_SEPERATE_MATCHER = False 141 | _C.SSL.USE_SPOTTING_NMS = True 142 | _C.SSL.DYNAMIC_RATIO = 1.0 143 | _C.SSL.GMM = None 144 | _C.SSL.SCORE_BUFFER = 100 145 | _C.SSL.SCORE_LOGGER = 1000 146 | _C.SSL.USE_GMM_FILTER = False 147 | _C.SSL.MIN_PSEDO_BOX_SIZE = 0 148 | _C.SSL.UNSUP_WEIGHT = 4.0 149 | _C.SSL.CONSISTENCY_WEIGHT = 1.0 150 | _C.SSL.AUG_QUERY = False 151 | _C.SSL.INFERENCE_ON = "teacher" 152 | _C.SSL.STEP_HOOK = False 153 | _C.SSL.WARM_UP = 0 154 | _C.SSL.USE_EMA = True 155 | _C.SSL.STAGE_WARM_UP = 100 156 | _C.SSL.EMA = CN() 157 | _C.SSL.EMA.MOMENTUM = 0.999 158 | _C.SSL.EMA.INTERVAL = 1 159 | _C.SSL.EMA.WARM_UP = 0 160 | _C.SSL.PL = CN() 161 | _C.SSL.PL.STAC = False 162 | _C.SSL.USE_CONSISTENCY = False 163 | _C.SSL.DECODER_ONLY = False 164 | _C.SSL.EXTRA_STUDENT_INFO = False 165 | _C.SSL.DECODER_LOSS = ["labels", "texts", "ctrl_points", "bd_points"] 166 | _C.SSL.O2O_DECODER_LOSS = ["labels", "texts", "ctrl_points", "bd_points"] 167 | _C.SSL.ASPECT_RATIO_GROUPING = True 168 | 169 | 170 | # Configs for cross domain semi data 171 | _C.INPUT.REPLACE = False 172 | _C.INPUT_UNLABEL = CN() 173 | _C.INPUT_UNLABEL.CROP = CN() 174 | _C.INPUT_UNLABEL.ROTATE = True 175 | _C.INPUT_UNLABEL.MIN_SIZE_TRAIN = None 176 | _C.INPUT_UNLABEL.MAX_SIZE_TRAIN = None 177 | _C.INPUT_UNLABEL.MIN_SIZE_TEST = None 178 | _C.INPUT_UNLABEL.MAX_SIZE_TEST = None 179 | 180 | # ---------------------------------------------------------------------------- # 181 | 182 | 183 | _C.SOLVER.OPTIMIZER = "ADAMW" 184 | _C.SOLVER.LR_BACKBONE = 1e-5 185 | _C.SOLVER.LR_BACKBONE_NAMES = [] 186 | _C.SOLVER.LR_LINEAR_PROJ_NAMES = [] 187 | _C.SOLVER.LR_LINEAR_PROJ_MULT = 0.1 188 | _C.SOLVER.SOURCE_RATIO = [1, 1] 189 | _C.SOLVER.FIND_UNUSED_PARAMETERS = False 190 | 191 | # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015) 192 | # 1 - Full lexicon (for totaltext) 193 | _C.TEST.LEXICON_TYPE = 1 194 | -------------------------------------------------------------------------------- /adet/utils/polygon_utils.py: -------------------------------------------------------------------------------- 1 | import shapely 2 | from shapely.geometry import LinearRing, Polygon, MultiPolygon, MultiPoint, mapping 3 | import numpy as np 4 | import cv2 5 | from scipy.spatial import ConvexHull 6 | from shapely.validation import make_valid 7 | import torch 8 | from rapidfuzz import string_metric 9 | import matplotlib.pyplot as plt 10 | 11 | def
_ctc_decode_recognition_pred_logits(voc_size, CTLABELS, rec): 12 |     last_char = '###' 13 |     s = '' 14 |     for c in rec: 15 |         c = int(c) 16 |         if c < voc_size - 1: 17 |             if last_char != c: 18 |                 if voc_size == 37 or voc_size == 96: 19 |                     s += CTLABELS[c] 20 |                     last_char = c 21 |                 else: 22 |                     s += str(chr(CTLABELS[c])) 23 |                     last_char = c 24 |         else: 25 |             last_char = '###' 26 |     return s 27 | 28 | def _ctc_decode_recognition_pred(voc_size, CTLABELS, rec): 29 |     s = '' 30 |     for c in rec: 31 |         c = int(c) 32 |         if c < voc_size - 1: 33 |             if voc_size == 37 or voc_size == 96: 34 |                 s += CTLABELS[c] 35 | 36 |             else: 37 |                 s += str(chr(CTLABELS[c])) 38 | 39 |     return s 40 | 41 | 42 | def compare_recs_unequal(selected_recs, target_recs, voc_size): 43 | 44 |     s_rec = selected_recs[selected_recs != voc_size] 45 |     t_rec = target_recs[target_recs != voc_size] 46 | 47 |     return not torch.equal(s_rec, t_rec) 48 | 49 | 50 | def plot_polygons_and_iou(poly1, poly2): 51 |     """ 52 |     Plots two polygons with different colors and shows their intersection and IoU. 53 | 54 |     Parameters: 55 |     - poly1: The first polygon (shapely.geometry.Polygon) 56 |     - poly2: The second polygon (shapely.geometry.Polygon) 57 |     """ 58 |     # Calculate the intersection and IoU 59 |     intersection = poly1.intersection(poly2) 60 |     iou = intersection.area / poly1.area if poly1.area > 0 else 0.0 61 | 62 |     # Extract x and y coordinates for plotting 63 |     x1, y1 = poly1.exterior.xy 64 |     x2, y2 = poly2.exterior.xy 65 |     xi, yi = intersection.exterior.xy if intersection.is_valid else ([], []) 66 | 67 |     # Create a figure and axis 68 |     fig, ax = plt.subplots(figsize=(8, 8)) 69 | 70 |     # Plot the first polygon in red 71 |     ax.fill(x1, y1, color='red', alpha=0.5, label="Polygon 1") 72 |     # Plot the second polygon in blue 73 |     ax.fill(x2, y2, color='blue', alpha=0.5, label="Polygon 2") 74 | 75 |     # Plot the intersection in green 76 |     if intersection.is_valid: 77 |         ax.fill(xi, yi, color='green', alpha=0.7, label="Intersection") 78 | 79 |     # Set labels and title 80 |     ax.set_xlabel("X") 81 |     ax.set_ylabel("Y") 82 |     ax.set_title(f"Polygon Intersection\nIoU = {iou:.4f}") 83 | 84 |     # Show the legend 85 |     ax.legend() 86 | 87 |     # Show the plot 88 |     plt.show() 89 | def calculate_iou_from_bds(selected_bds, target_bds): 90 |     selected_target_poly = make_valid_poly(pnt_to_Polygon(selected_bds))  # 50 91 |     bd_points_poly = make_valid_poly(pnt_to_Polygon(target_bds)) 92 |     iou = get_intersection_over_union(selected_target_poly, bd_points_poly) 93 |     return iou 94 | 95 | 96 | 97 | def SPOTTING_NMS(bds, scs, ctcs, recs, voc_size, iou_threshold=None): 98 |     bds = bds.cpu().numpy() 99 |     # NMS over the instances of one image 100 |     if iou_threshold is None: 101 |         iou_threshold = 0.7 102 |     sorted_indices = sorted(range(len(scs)), key=lambda i: scs[i]+ctcs[i], reverse=True) 103 | 104 |     selected_indices = [] 105 |     while sorted_indices: 106 |         best_index = sorted_indices[0] 107 |         selected_indices.append(best_index) 108 |         remaining_indices = [] 109 |         for i in range(1, len(sorted_indices)): 110 |             id = sorted_indices[i] 111 |             matched_iou = calculate_iou_from_bds(bds[best_index], bds[id]) 112 | 113 |             # drop samples whose localization duplicates a higher-scoring one 114 |             valid_loc = matched_iou <= iou_threshold 115 |             # for duplicated transcripts, drop repeats of the same object but keep distinct objects 116 |             valid_trans = compare_recs_unequal(recs[best_index], recs[id], voc_size=voc_size) or matched_iou == 0 117 | 118 |             if valid_loc and valid_trans: 119 |                 remaining_indices.append(sorted_indices[i]) 120 | 121 |         sorted_indices = remaining_indices 122 | 123 | 124 |     return selected_indices 125 | 126 | def get_intersection(poly1, poly2): 127 |
try: 128 |         inter_area = poly1.intersection(poly2).area  # intersection area 129 |         return inter_area 130 |     except shapely.geos.TopologicalError: 131 |         return 0 132 | 133 | 134 | def plot_polygons_and_iou(poly1, poly2, iou): 135 |     """ 136 |     Visualize two polygons with different colors, show intersection, and plot IoU. 137 | 138 |     Parameters: 139 |     - poly1: Polygon (shapely.geometry.Polygon) 140 |     - poly2: Polygon (shapely.geometry.Polygon) 141 |     - iou: float, Intersection over Union value 142 |     """ 143 |     fig, ax = plt.subplots() 144 | 145 |     # Plot the first polygon with color 'blue' and alpha transparency 146 |     x1, y1 = poly1.exterior.xy 147 |     ax.fill(x1, y1, color='blue', alpha=0.5, label="Polygon 1") 148 |     ax.plot(x1, y1, color='blue', lw=2) 149 | 150 |     # Plot the second polygon with color 'red' and alpha transparency 151 |     x2, y2 = poly2.exterior.xy 152 |     ax.fill(x2, y2, color='red', alpha=0.5, label="Polygon 2") 153 |     ax.plot(x2, y2, color='red', lw=2) 154 | 155 |     # If the polygons intersect, plot the intersection area 156 |     if poly1.intersects(poly2): 157 |         inter = poly1.intersection(poly2) 158 |         if isinstance(inter, MultiPolygon): 159 |             for p in inter.geoms: 160 |                 x, y = p.exterior.xy 161 |                 ax.fill(x, y, color='green', alpha=0.3, label="Intersection") 162 |                 ax.plot(x, y, color='green', lw=2) 163 |         else: 164 |             x, y = inter.exterior.xy 165 |             ax.fill(x, y, color='green', alpha=0.3, label="Intersection") 166 |             ax.plot(x, y, color='green', lw=2) 167 | 168 |     # Plot the points of each polygon with their index number 169 |     # Polygon 1 points 170 |     for i, (x, y) in enumerate(zip(x1, y1)): 171 |         ax.text(x, y, f'{i}', fontsize=12, color='blue', ha='right', va='bottom') 172 | 173 |     # Polygon 2 points 174 |     for i, (x, y) in enumerate(zip(x2, y2)): 175 |         ax.text(x, y, f'{i}', fontsize=12, color='red', ha='right', va='top') 176 | 177 |     # Title with IoU 178 |     ax.set_title(f'IoU: {iou:.2f}') 179 | 180 |     # Set the aspect ratio of the plot to be equal 181 |     ax.set_aspect('equal', adjustable='box') 182 | 183 |     # Add grid and legend 184 |     ax.grid(True) 185 |     ax.legend() 186 | 187 |     # Show the plot 188 |     plt.show() 189 | def get_intersection_over_union_new(poly1, poly2): 190 |     # poly1 = Polygon(poly1).convex_hull 191 |     # poly2 = Polygon(poly2).convex_hull 192 | 193 |     if not poly1.intersects(poly2): 194 |         iou = 0 195 |     else: 196 |         try: 197 |             inter_area = poly1.intersection(poly2).area  # intersection area 198 |             iou = float(inter_area) / (poly1.area + poly2.area - inter_area) 199 |         except shapely.geos.TopologicalError: 200 |             print('shapely.geos.TopologicalError occurred, iou set to 0') 201 |             iou = 0 202 |     # plot_polygons_and_iou(poly1, poly2, iou) 203 |     return iou 204 | def get_intersection_over_union(poly1, poly2): 205 |     # poly1 = Polygon(poly1).convex_hull 206 |     # poly2 = Polygon(poly2).convex_hull 207 |     if not poly1.intersects(poly2): 208 |         iou = 0 209 |     else: 210 |         try: 211 |             inter_area = poly1.intersection(poly2).area  # intersection area 212 |             iou = float(inter_area) / (poly1.area + poly2.area - inter_area) 213 |         except shapely.geos.TopologicalError: 214 |             print('shapely.geos.TopologicalError occurred, iou set to 0') 215 |             iou = 0 216 |     return iou 217 | 218 | def get_intersection_over_union_from_pnts(pnts1, pnts2): 219 |     # pnts shape: (50, 2) 220 | 221 |     poly1 = Polygon(pnts1).convex_hull  # POLYGON ((0 0, 0 2, 2 2, 2 0, 0 0)) 222 |     poly2 = Polygon(pnts2).convex_hull 223 | 224 |     if not poly1.intersects(poly2): 225 |         iou = 0 226 |     else: 227 |         try: 228 |             inter_area = poly1.intersection(poly2).area  # intersection area 229 |             iou = float(inter_area) / (poly1.area +
poly2.area - inter_area) 230 |         except shapely.geos.TopologicalError: 231 |             print('shapely.geos.TopologicalError occurred, iou set to 0') 232 |             iou = 0 233 |     return iou 234 | 235 | 236 | 237 | def build_clockwise_polygon(points): 238 |     points_array = np.array(points) 239 | 240 |     hull = ConvexHull(points_array)  # scipy.spatial.ConvexHull (imported above) exposes .vertices 241 |     sorted_hull_points = points_array[hull.vertices] 242 | 243 |     clockwise_polygon = sorted_hull_points.tolist() 244 | 245 |     return clockwise_polygon 246 | 247 | 248 | def pnt_to_Polygon(bd_pnt): 249 |     bd_pnt = np.hsplit(bd_pnt, 2) 250 |     bd_pnt = np.vstack([bd_pnt[0], bd_pnt[1][::-1]]) 251 |     return bd_pnt.tolist() 252 | 253 | def simplify_polygon(polygon_points, eps=1e-3, mode=1): 254 |     polygon = Polygon(polygon_points) 255 | 256 |     if mode == 1: 257 |         polygon_new = polygon.buffer(0) 258 |     elif mode == 2: 259 |         polygon_new = shapely.simplify(polygon, eps) 260 |     elif mode == 3: 261 |         polygon_new = polygon.buffer(eps).buffer(-eps) 262 |     elif mode == 4: 263 |         polygon_new = shapely.validation.make_valid(polygon) 264 |         polygon_new = list(polygon_new.geoms)[0] 265 | 266 |     if isinstance(polygon_new, MultiPolygon): 267 |         polygons = sorted([p for p in polygon_new.geoms], key=lambda polygon: len(polygon.exterior.coords), reverse=True) 268 |         return np.array(polygons[0].exterior.coords) 269 |     else: 270 |         return np.array(polygon_new.exterior.coords) 271 | 272 | def make_valid_poly(pts): 273 |     # pts -> valid Polygon 274 |     # 1. check that the polygon is valid 275 |     pgt = Polygon(pts) 276 |     if not pgt.is_valid: 277 |         pts = simplify_polygon(pts, mode=1) 278 |         pgt = Polygon(pts) 279 | 280 |     if not pgt.is_valid: 281 |         pgt = Polygon(pts).convex_hull  # otherwise fall back to the convex hull, which encloses the original polygon with fewer points 282 |         pts = list(mapping(pgt)['coordinates'][0]) 283 | 284 |     # 2. make sure the pts are clockwise. 285 |     pRing = LinearRing(pts) 286 |     if pRing.is_ccw: 287 |         pts = pts[::-1] 288 |         pRing = LinearRing(pts) 289 |     assert not pRing.is_ccw, "Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards."
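    # Note: shapely derives is_ccw purely from the raw coordinates (mathematical
    # y-up convention), so the reversal above simply enforces one consistent winding
    # order; downstream IoU/area code and the evaluators then see every polygon with
    # the same orientation.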
290 | pgt = Polygon(pts) 291 | 292 | return pgt -------------------------------------------------------------------------------- /adet/data/builtin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from detectron2.data.datasets.register_coco import register_coco_instances 4 | from detectron2.data.datasets.builtin_meta import _get_builtin_metadata 5 | from .datasets.text import register_text_instances, register_text_instances_ssl 6 | from adet.config import get_cfg, get_cfg_semi 7 | from detectron2.engine import default_argument_parser 8 | 9 | _PREDEFINED_SPLITS_PIC = { 10 | "pic_person_train": ("pic/image/train", "pic/annotations/train_person.json"), 11 | "pic_person_val": ("pic/image/val", "pic/annotations/val_person.json"), 12 | } 13 | 14 | metadata_pic = { 15 | "thing_classes": ["person"] 16 | } 17 | 18 | _PREDEFINED_SPLITS_TEXT = { 19 | # 37 voc_size 20 | "syntext1": ("syntext1/train_images", "syntext1/annotations/train_37voc.json"), 21 | "syntext2": ("syntext2/train_images", "syntext2/annotations/train_37voc.json"), 22 | "mlt": ("mlt2017/train_images", "mlt2017/train_37voc.json"), 23 | "totaltext_train": ("totaltext/train_images", "totaltext/train_37voc.json"), 24 | 25 | # semi-supervised settings 26 | "totaltext_train_0.5_label": ("totaltext/train_images", "totaltext/train_37voc_0.5_labeled.json"), 27 | "totaltext_train_0.5_unlabel": ("totaltext/train_images", "totaltext/train_37voc_0.5_unlabeled.json"), 28 | "totaltext_train_1_label": ("totaltext/train_images", "totaltext/train_37voc_1_labeled.json"), 29 | "totaltext_train_1_unlabel": ("totaltext/train_images", "totaltext/train_37voc_1_unlabeled.json"), 30 | "totaltext_train_2_label": ("totaltext/train_images", "totaltext/train_37voc_2_labeled.json"), 31 | "totaltext_train_2_unlabel": ("totaltext/train_images", "totaltext/train_37voc_2_unlabeled.json"), 32 | "totaltext_train_5_label": ("totaltext/train_images", "totaltext/train_37voc_5_labeled.json"), 33 | "totaltext_train_5_unlabel": ("totaltext/train_images", "totaltext/train_37voc_5_unlabeled.json"), 34 | "totaltext_train_10_label": ("totaltext/train_images", "totaltext/train_37voc_10_labeled.json"), 35 | "totaltext_train_10_unlabel": ("totaltext/train_images", "totaltext/train_37voc_10_unlabeled.json"), 36 | 37 | "totaltext_train_full_label": ("totaltext/train_images", "totaltext/train_37voc_full_labeled.json"), 38 | "totaltext_train_full_unlabel": ("totaltext/train_images", "totaltext/train_37voc_full_unlabeled.json"), 39 | 40 | "textocr1_unlabel": ("textocr/train_images", "textocr/train_37voc_1_unlabeled.json"), 41 | "textocr2_unlabel": ("textocr/train_images", "textocr/train_37voc_2_unlabeled.json"), 42 | "textocr_val": ("textocr/train_images", "textocr/train_37voc_2.json"), 43 | # "textocr_test": ("textocr/test_images", "textocr/train_37voc_2.json"), 44 | 45 | "textocr_train_100_label": ("textocr/train_images", "textocr/train_37voc_1_100_labeled.json"), 46 | "textocr_train_1000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_1000_unlabeled.json"), 47 | "textocr_train_2000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_2000_unlabeled.json"), 48 | "textocr_train_3000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_3000_unlabeled.json"), 49 | "textocr_train_4000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_4000_unlabeled.json"), 50 | "textocr_train_5000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_5000_unlabeled.json"), 51 | 
"textocr_train_10000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_10000_unlabeled.json"), 52 | "textocr_train_15000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_15000_unlabeled.json"), 53 | "textocr_train_20000_unlabel": ("textocr/train_images", "textocr/train_37voc_1_20000_unlabeled.json"), 54 | "textocr_train_5000_sim_unlabel": ("textocr/train_images", "COCO-Text/ocr_train_37voc_1_sim_5k_unlabeled.json"), 55 | 56 | 57 | "ic15_train_0.5_label": ("ic15/train_images", "ic15/train_37voc_0.5_labeled.json"), 58 | "ic15_train_0.5_unlabel": ("ic15/train_images", "ic15/train_37voc_0.5_unlabeled.json"), 59 | "ic15_train_1_label": ("ic15/train_images", "ic15/train_37voc_1_labeled.json"), 60 | "ic15_train_1_unlabel": ("ic15/train_images", "ic15/train_37voc_1_unlabeled.json"), 61 | "ic15_train_2_label": ("ic15/train_images", "ic15/train_37voc_2_labeled.json"), 62 | "ic15_train_2_unlabel": ("ic15/train_images", "ic15/train_37voc_2_unlabeled.json"), 63 | "ic15_train_5_label": ("ic15/train_images", "ic15/train_37voc_5_labeled.json"), 64 | "ic15_train_5_unlabel": ("ic15/train_images", "ic15/train_37voc_5_unlabeled.json"), 65 | "ic15_train_10_label": ("ic15/train_images", "ic15/train_37voc_10_labeled.json"), 66 | "ic15_train_10_unlabel": ("ic15/train_images", "ic15/train_37voc_10_unlabeled.json"), 67 | 68 | "ic15_train_full_label": ("ic15/train_images", "ic15/train_37voc_full_labeled.json"), 69 | "ic15_train_full_unlabel": ("ic15/train_images", "ic15/train_37voc_full_unlabeled.json"), 70 | 71 | "ctw1500_train_0.5_label": ("ctw1500/train_images", "ctw1500/train_96voc_0.5_labeled.json"), 72 | "ctw1500_train_0.5_unlabel": ("ctw1500/train_images", "ctw1500/train_96voc_0.5_unlabeled.json"), 73 | "ctw1500_train_1_label": ("ctw1500/train_images", "ctw1500/train_96voc_1_labeled.json"), 74 | "ctw1500_train_1_unlabel": ("ctw1500/train_images", "ctw1500/train_96voc_1_unlabeled.json"), 75 | "ctw1500_train_2_label": ("ctw1500/train_images", "ctw1500/train_96voc_2_labeled.json"), 76 | "ctw1500_train_2_unlabel": ("ctw1500/train_images", "ctw1500/train_96voc_2_unlabeled.json"), 77 | "ctw1500_train_5_label": ("ctw1500/train_images", "ctw1500/train_96voc_5_labeled.json"), 78 | "ctw1500_train_5_unlabel": ("ctw1500/train_images", "ctw1500/train_96voc_5_unlabeled.json"), 79 | "ctw1500_train_10_label": ("ctw1500/train_images", "ctw1500/train_96voc_10_labeled.json"), 80 | "ctw1500_train_10_unlabel": ("ctw1500/train_images", "ctw1500/train_96voc_10_unlabeled.json"), 81 | 82 | 83 | "ic13_train": ("ic13/train_images", "ic13/train_37voc.json"), 84 | "ic15_train": ("ic15/train_images", "ic15/train_37voc.json"), 85 | "textocr1": ("textocr/train_images", "textocr/train_37voc_1.json"), 86 | "textocr2": ("textocr/train_images", "textocr/train_37voc_2.json"), 87 | 88 | 89 | # 96 voc_size 90 | "syntext1_96voc": ("syntext1/train_images", "syntext1/annotations/train_96voc.json"), 91 | "syntext2_96voc": ("syntext2/train_images", "syntext2/annotations/train_96voc.json"), 92 | "mlt_96voc": ("mlt2017/train_images", "mlt2017/train_96voc.json"), 93 | "totaltext_train_96voc": ("totaltext/train_images", "totaltext/train_96voc.json"), 94 | "ic13_train_96voc": ("ic13/train_images", "ic13/train_96voc.json"), 95 | "ic15_train_96voc": ("ic15/train_images", "ic15/train_96voc.json"), 96 | "ctw1500_train_96voc": ("ctw1500/train_images", "ctw1500/train_96voc.json"), 97 | 98 | # Chinese 99 | "chnsyn_train": ("chnsyntext/syn_130k_images", "chnsyntext/chn_syntext.json"), 100 | "rects_train": ("ReCTS/ReCTS_train_images", 
"ReCTS/rects_train.json"), 101 | "rects_val": ("ReCTS/ReCTS_val_images", "ReCTS/rects_val.json"), 102 | "lsvt_train": ("LSVT/rename_lsvtimg_train", "LSVT/lsvt_train.json"), 103 | "art_train": ("ArT/rename_artimg_train", "ArT/art_train.json"), 104 | 105 | # evaluation, just for reading images, annotations may be empty 106 | "totaltext_test": ("totaltext/test_images", "totaltext/test.json"), 107 | "ic15_test": ("ic15/test_images", "ic15/test.json"), 108 | "ctw1500_test": ("ctw1500/test_images", "ctw1500/test.json"), 109 | # "inversetext_test": ("inversetext/test_images", "inversetext/test.json"), 110 | # "rects_test": ("ReCTS/ReCTS_test_images", "ReCTS/rects_test.json"), 111 | # "textocr_test": ("textocr/test_images", "textocr/test.json"), 112 | } 113 | 114 | metadata_text = { 115 | "thing_classes": ["text"] 116 | } 117 | 118 | 119 | def register_all_coco(root="datasets", voc_size_cfg=37, num_pts_cfg=25): 120 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_PIC.items(): 121 | # Assume pre-defined datasets live in `./datasets`. 122 | register_coco_instances( 123 | key, 124 | metadata_pic, 125 | os.path.join(root, json_file) if "://" not in json_file else json_file, 126 | os.path.join(root, image_root), 127 | ) 128 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_TEXT.items(): 129 | # Assume pre-defined datasets live in `./datasets`. 130 | register_text_instances( 131 | key, 132 | metadata_text, 133 | os.path.join(root, json_file) if "://" not in json_file else json_file, 134 | os.path.join(root, image_root), 135 | voc_size_cfg, 136 | num_pts_cfg 137 | ) 138 | 139 | def register_all_coco_semi(root="datasets", voc_size_cfg=37, num_pts_cfg=25): 140 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_PIC.items(): 141 | # Assume pre-defined datasets live in `./datasets`. 142 | register_coco_instances( 143 | key, 144 | metadata_pic, 145 | os.path.join(root, json_file) if "://" not in json_file else json_file, 146 | os.path.join(root, image_root), 147 | ) 148 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_TEXT.items(): 149 | # Assume pre-defined datasets live in `./datasets`. 150 | register_text_instances_ssl( 151 | key, 152 | metadata_text, 153 | os.path.join(root, json_file) if "://" not in json_file else json_file, 154 | os.path.join(root, image_root), 155 | voc_size_cfg, 156 | num_pts_cfg 157 | ) 158 | 159 | 160 | # get the vocabulary size and number of point queries in each instance 161 | # to eliminate blank text and sample gt according to Bezier control points 162 | parser = default_argument_parser() 163 | # add the following argument to avoid some errors while running demo/demo.py 164 | parser.add_argument("--input", nargs="+", help="A list of space separated input images") 165 | parser.add_argument( 166 | "--output", 167 | help="A file or directory to save output visualizations. 
" 168 | "If not given, will show output in an OpenCV window.", 169 | ) 170 | parser.add_argument( 171 | "--opts", 172 | help="Modify config options using the command-line 'KEY VALUE' pairs", 173 | default=[], 174 | nargs=argparse.REMAINDER, 175 | ) 176 | parser.add_argument("--refer", action="store_true", help="whether use the anno in builtin dataset to better visualize") 177 | parser.add_argument("--TSA", action="store_true", help="whether use TSA PL strategy") 178 | args = parser.parse_args() 179 | # cfg = get_cfg() 180 | cfg = get_cfg_semi() 181 | cfg.merge_from_file(args.config_file) 182 | # register_all_coco(voc_size_cfg=cfg.MODEL.TRANSFORMER.VOC_SIZE, num_pts_cfg=cfg.MODEL.TRANSFORMER.NUM_POINTS) 183 | 184 | register_all_coco_semi(voc_size_cfg=cfg.MODEL.TRANSFORMER.VOC_SIZE, num_pts_cfg=cfg.MODEL.TRANSFORMER.NUM_POINTS) -------------------------------------------------------------------------------- /adet/utils/hooks.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import itertools 3 | import logging 4 | import math 5 | import operator 6 | import os 7 | import tempfile 8 | import time 9 | import warnings 10 | from collections import Counter 11 | import torch 12 | from fvcore.common.checkpoint import Checkpointer 13 | from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer 14 | from fvcore.common.param_scheduler import ParamScheduler 15 | from fvcore.common.timer import Timer 16 | from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats 17 | import os.path as osp 18 | import detectron2.utils.comm as comm 19 | from detectron2.evaluation.testing import flatten_results_dict 20 | from detectron2.solver import LRMultiplier 21 | from detectron2.utils.events import EventStorage, EventWriter 22 | from detectron2.utils.file_io import PathManager 23 | import torch 24 | from torch.nn.parallel import DataParallel, DistributedDataParallel 25 | 26 | from detectron2.engine.train_loop import HookBase 27 | from bisect import bisect_right 28 | 29 | __all__ = ['StepRecord', 'MeanTeacher',] 30 | 31 | """ 32 | Customized hooks for the SSL. 
33 | """ 34 | 35 | 36 | class PseudoLabelHook(HookBase): 37 | #implement of Pseudolabel generate hook to generate PL to be saved as an attr of model 38 | #support for Teacher-student framework without EMA,while teacher's params can be updated by epoch/iter or freezed 39 | #the format of PL is different from the original groundtruth for easier utilization , default to be the orignal predict of Model 40 | #in some cases only the top-k PL would be used ,which can be implemented by add semi_5s function in Model,while called from here to generate PL 41 | def __init__( 42 | self, 43 | strategy:'freeze', 44 | iter=None, 45 | epoch=None, 46 | save_dir=None, 47 | save=False, 48 | momentum=0.999, 49 | interval=1, 50 | warm_up=100, 51 | decay_intervals=None, 52 | decay_factor=0.1, 53 | clone_teacher=False 54 | ): 55 | self.strategy = strategy 56 | if isinstance(strategy, str): 57 | strategy = [strategy] 58 | allowed_strategies = ['freeze', 'iter', 'epoch', 'EMA'] 59 | if not set(strategy).issubset(set(allowed_strategies)): 60 | raise KeyError(f'metrics {strategy} is not supported') 61 | logger = logging.getLogger(__name__) 62 | logger.info(f"use {self.strategy} strategy for PseudoLabel generating") 63 | 64 | self.pred_allow = True 65 | if not osp.exists(save_dir): 66 | os.makedirs(save_dir) 67 | elif os.listdir(save_dir): 68 | self.pred_allow = False #if PL already exist, then forbid pred at the first round 69 | self.pseudo_init(save_dir) 70 | self.save_pl = save 71 | 72 | self.epoch_base=False 73 | self.iter_base=False 74 | self.EMA= self.strategy == 'EMA' 75 | 76 | self.iter = iter 77 | if self.iter is not None: 78 | self.iter_base = True 79 | 80 | self.epoch = epoch 81 | if self.epoch is not None: 82 | self.epoch_base =True 83 | 84 | if self.EMA : 85 | assert momentum >= 0 and momentum <= 1 86 | self.momentum = momentum 87 | assert isinstance(interval, int) and interval > 0 88 | self.warm_up = warm_up 89 | self.interval = interval 90 | assert isinstance(decay_intervals, list) or decay_intervals is None 91 | self.decay_intervals = decay_intervals 92 | self.decay_factor = decay_factor 93 | self.clone_teacher = clone_teacher 94 | 95 | def before_train(self): 96 | model = self.trainer.model 97 | if is_module_wrapper(model): 98 | model = model.module 99 | assert hasattr(model, "teacher") 100 | assert hasattr(model, "student") 101 | # only do it at initial stage 102 | if self.clone_teacher: 103 | if self.trainer.iter == 0: 104 | self.model_clone(model) 105 | 106 | if self.pred_allow: 107 | self.generate_PL(model) 108 | 109 | 110 | def before_step(self): 111 | """whether to Update parameter every interval.""" 112 | if self.strategy == 'freeze': 113 | pass 114 | elif self.iter_base : 115 | curr_step = self.trainer.iter 116 | if curr_step % self.iter != 0: 117 | return 118 | self.model_clone(self.trainer.model) 119 | 120 | elif self.epoch_base: 121 | curr_epoch = self.trainer.epoch 122 | if curr_epoch % self.epoch != 0: 123 | return 124 | self.model_clone(self.trainer.model) 125 | elif self.EMA: 126 | self.EMA_before_step() 127 | def EMA_before_step(self): 128 | """Update ema parameter every self.interval iterations.""" 129 | curr_step = self.trainer.iter 130 | if curr_step % self.interval != 0: 131 | return 132 | model = self.trainer.model 133 | if is_module_wrapper(model): 134 | model = model.module 135 | # We warm up the momentum considering the instability at beginning 136 | momentum = min( 137 | self.momentum, 1 - (1 + self.warm_up) / (curr_step + 1 + self.warm_up) 138 | ) 139 | if momentum < 
127 | def EMA_before_step(self): 128 | """Update the EMA parameters every `self.interval` iterations.""" 129 | curr_step = self.trainer.iter 130 | if curr_step % self.interval != 0: 131 | return 132 | model = self.trainer.model 133 | if is_module_wrapper(model): 134 | model = model.module 135 | # Warm up the momentum to account for instability at the beginning of training 136 | momentum = min( 137 | self.momentum, 1 - (1 + self.warm_up) / (curr_step + 1 + self.warm_up) 138 | ) 139 | if momentum < self.momentum: 140 | logger = logging.getLogger(__name__) 141 | logger.info( 142 | f"warming up momentum toward {self.momentum}; current value is {momentum} at step {curr_step}." 143 | ) 144 | self.momentum_update(model, momentum) 145 | 146 | def after_step(self): 147 | if self.strategy == 'freeze': 148 | pass 149 | elif self.iter_base: 150 | pass 151 | elif self.epoch_base: 152 | pass 153 | elif self.EMA: 154 | self.EMA_after_step() 155 | def EMA_after_step(self): 156 | curr_step = self.trainer.iter 157 | if self.decay_intervals is None: 158 | return 159 | self.momentum = 1 - (1 - self.momentum) / self.decay_factor ** bisect_right( 160 | self.decay_intervals, curr_step 161 | ) 162 | 163 | def momentum_update(self, model, momentum): 164 | for (src_name, src_parm), (tgt_name, tgt_parm) in zip( 165 | model.student.named_parameters(), model.teacher.named_parameters() 166 | ): 167 | tgt_parm.data.mul_(momentum).add_(src_parm.data, alpha=1 - momentum) 168 | def model_clone(self, model): 169 | logger = logging.getLogger(__name__) 170 | logger.info("Cloning all parameters of the student to the teacher...") 171 | 172 | self.momentum_update(model, 0) 173 | 174 | 175 | def generate_PL(self, model):  # refer to the eval hook 176 | model.Pseudo = 'to be implemented'  # placeholder: PL generation is not implemented yet 177 | if self.save_pl: 178 | pass 179 | # TODO: save the PLs to save_dir 180 | def pseudo_init(self, save_dir): 181 | # TODO: load existing PLs from save_dir and expose them on the model; 182 | # note that self.trainer is not attached yet when this is called from __init__ 183 | pass 184 | 185 | 186 | 187 | class StepRecord(HookBase): 188 | def __init__( 189 | self, 190 | normalize=False, 191 | name="curr_step" 192 | ): 193 | self.normalize = normalize 194 | self.name = name 195 | 196 | def after_step(self): 197 | """ 198 | Called after each iteration. 199 | """ 200 | curr_step = self.trainer.iter 201 | model = self.trainer.model 202 | if is_module_wrapper(model): 203 | model = model.module 204 | 205 | assert hasattr(model, self.name) 206 | setattr(model, self.name, curr_step / 10000 if self.normalize else curr_step) 207 | 208 |
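# The update rule shared by PseudoLabelHook and MeanTeacher (momentum_update) is a
# plain exponential moving average applied parameter-wise:
#     teacher <- momentum * teacher + (1 - momentum) * student
# so momentum_update(model, 0) copies the student into the teacher, which is
# exactly how model_clone initializes the teacher.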
209 | class MeanTeacher(HookBase): 210 | def __init__( 211 | self, 212 | momentum=0.999, 213 | interval=1, 214 | warm_up=100, 215 | decay_intervals=None, 216 | decay_factor=0.1, 217 | clone_teacher=False 218 | ): 219 | assert 0 <= momentum <= 1 220 | self.momentum = momentum 221 | assert isinstance(interval, int) and interval > 0 222 | self.warm_up = warm_up 223 | self.interval = interval 224 | assert isinstance(decay_intervals, list) or decay_intervals is None 225 | self.decay_intervals = decay_intervals 226 | self.decay_factor = decay_factor 227 | self.clone_teacher = clone_teacher 228 | 229 | def before_train(self): 230 | model = self.trainer.model 231 | if is_module_wrapper(model): 232 | model = model.module 233 | assert hasattr(model, "teacher") 234 | assert hasattr(model, "student") 235 | # only do it at the initial stage 236 | if self.clone_teacher: 237 | if self.trainer.iter == 0: 238 | logger = logging.getLogger(__name__) 239 | logger.info("Cloning all parameters of the student to the teacher...") 240 | 241 | self.momentum_update(model, 0) 242 | 243 | def before_step(self): 244 | """Update the EMA parameters every `self.interval` iterations.""" 245 | curr_step = self.trainer.iter 246 | if curr_step % self.interval != 0: 247 | return 248 | model = self.trainer.model 249 | if is_module_wrapper(model): 250 | model = model.module 251 | # Warm up the momentum to account for instability at the beginning of training 252 | momentum = min( 253 | self.momentum, 1 - (1 + self.warm_up) / (curr_step + 1 + self.warm_up) 254 | ) 255 | if momentum < self.momentum: 256 | logger = logging.getLogger(__name__) 257 | logger.info( 258 | f"warming up momentum toward {self.momentum}; current value is {momentum} at step {curr_step}." 259 | ) 260 | self.momentum_update(model, momentum) 261 | 262 | def after_step(self): 263 | curr_step = self.trainer.iter 264 | if self.decay_intervals is None: 265 | return 266 | self.momentum = 1 - (1 - self.momentum) / self.decay_factor ** bisect_right( 267 | self.decay_intervals, curr_step 268 | ) 269 | 270 | def momentum_update(self, model, momentum): 271 | for (src_name, src_parm), (tgt_name, tgt_parm) in zip( 272 | model.student.named_parameters(), model.teacher.named_parameters() 273 | ): 274 | tgt_parm.data.mul_(momentum).add_(src_parm.data, alpha=1 - momentum) 275 | 276 | 277 | def is_module_wrapper(module: torch.nn.Module) -> bool: 278 | """Check if a module is a module wrapper. 279 | 280 | Following MMCV's convention, DataParallel and DistributedDataParallel 281 | (and their subclasses) are regarded as module wrappers; the wrapped 282 | model is reached through their `.module` attribute. 283 | 284 | Args: 285 | module (nn.Module): The module to be checked. 286 | 287 | Returns: 288 | bool: True if the input module is a module wrapper. 289 | """ 290 | return isinstance(module, (DataParallel, DistributedDataParallel)) 291 | -------------------------------------------------------------------------------- /datasets/ic15/train_37voc_1_labeled.json: -------------------------------------------------------------------------------- 1 | {"images": [{"coco_url": "", "date_captured": "", "file_name": "img_790.jpg", "flickr_url": "", "id": 790, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_471.jpg", "flickr_url": "", "id": 471, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_416.jpg", "flickr_url": "", "id": 416, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_341.jpg", "flickr_url": "", "id": 341, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_836.jpg", "flickr_url": "", "id": 836, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_326.jpg", "flickr_url": "", "id": 326, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_347.jpg", "flickr_url": "", "id": 347, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_988.jpg", "flickr_url": "", "id": 988, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_833.jpg", "flickr_url": "", "id": 833, "license": 0, "width": 1280, "height": 720}, {"coco_url": "", "date_captured": "", "file_name": "img_196.jpg", "flickr_url": "", "id": 196, "license": 0, "width": 1280, "height": 720}], "annotations": [{"area": 78500.0, "bbox": [281.0, 127.0,
628.0, 125.0], "category_id": 1, "id": 1560, "image_id": 790, "iscrowd": 0, "bezier_pts": [281, 175, 488, 159, 695, 143, 902, 127, 908, 203, 701, 219, 494, 235, 287, 251], "rec": [22, 0, 17, 4, 7, 14, 20, 18, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 3159.0, "bbox": [270.0, 444.0, 81.0, 39.0], "category_id": 1, "id": 1561, "image_id": 790, "iscrowd": 0, "bezier_pts": [273, 444, 298, 445, 324, 446, 350, 448, 347, 482, 321, 480, 295, 478, 270, 477], "rec": [18, 0, 11, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 902.0, "bbox": [138.0, 367.0, 41.0, 22.0], "category_id": 1, "id": 1562, "image_id": 790, "iscrowd": 0, "bezier_pts": [140, 368, 152, 367, 165, 367, 178, 367, 176, 387, 163, 387, 150, 387, 138, 388], "rec": [18, 0, 11, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1701.0, "bbox": [520.0, 250.0, 81.0, 21.0], "category_id": 1, "id": 3627, "image_id": 471, "iscrowd": 0, "bezier_pts": [520, 254, 545, 252, 571, 251, 597, 250, 600, 266, 574, 267, 548, 268, 522, 270], "rec": [9, 0, 21, 4, 13, 20, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 2184.0, "bbox": [602.0, 243.0, 91.0, 24.0], "category_id": 1, "id": 3628, "image_id": 471, "iscrowd": 0, "bezier_pts": [602, 248, 631, 246, 660, 244, 690, 243, 692, 261, 663, 262, 634, 264, 605, 266], "rec": [5, 11, 8, 15, 1, 14, 14, 19, 7, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 468.0, "bbox": [551.0, 146.0, 36.0, 13.0], "category_id": 1, "id": 3962, "image_id": 416, "iscrowd": 0, "bezier_pts": [551, 146, 562, 146, 574, 146, 586, 146, 586, 158, 574, 157, 562, 157, 551, 157], "rec": [18, 0, 21, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 600.0, "bbox": [781.0, 139.0, 40.0, 15.0], "category_id": 1, "id": 3963, "image_id": 416, "iscrowd": 0, "bezier_pts": [781, 140, 793, 139, 805, 139, 818, 139, 820, 153, 807, 153, 794, 153, 782, 153], "rec": [18, 0, 21, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 612.0, "bbox": [586.0, 142.0, 36.0, 17.0], "category_id": 1, "id": 3964, "image_id": 416, "iscrowd": 0, "bezier_pts": [586, 143, 597, 142, 609, 142, 621, 142, 621, 157, 609, 157, 597, 157, 586, 158], "rec": [13, 14, 22, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1056.0, "bbox": [188.0, 289.0, 48.0, 22.0], "category_id": 1, "id": 226, "image_id": 341, "iscrowd": 0, "bezier_pts": [188, 290, 203, 289, 218, 289, 234, 289, 235, 309, 220, 309, 205, 309, 190, 310], "rec": [31, 26, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1008.0, "bbox": [191.0, 308.0, 48.0, 21.0], "category_id": 1, "id": 227, "image_id": 341, "iscrowd": 0, "bezier_pts": [191, 308, 206, 308, 222, 308, 238, 308, 238, 328, 222, 328, 206, 328, 191, 328], "rec": [14, 5, 5, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 11844.0, "bbox": [783.0, 46.0, 126.0, 94.0], "category_id": 1, "id": 228, "image_id": 341, "iscrowd": 0, "bezier_pts": [783, 89, 824, 74, 865, 60, 907, 46, 908, 101, 866, 113, 825, 126, 784, 139], "rec": [15, 0, 13, 3, 14, 17, 0, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 4588.0, "bbox": [491.0, 197.0, 148.0, 31.0], "category_id": 1, "id": 4322, "image_id": 836, 
"iscrowd": 0, "bezier_pts": [491, 197, 539, 197, 588, 197, 637, 198, 638, 227, 589, 226, 540, 226, 492, 226], "rec": [18, 8, 13, 6, 0, 15, 14, 17, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 3570.0, "bbox": [503.0, 226.0, 119.0, 30.0], "category_id": 1, "id": 4323, "image_id": 836, "iscrowd": 0, "bezier_pts": [507, 226, 545, 226, 583, 227, 621, 228, 618, 255, 579, 254, 541, 253, 503, 253], "rec": [0, 8, 17, 11, 8, 13, 4, 18, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 65224.0, "bbox": [2.0, 328.0, 248.0, 263.0], "category_id": 1, "id": 487, "image_id": 326, "iscrowd": 0, "bezier_pts": [2, 455, 84, 412, 166, 370, 249, 328, 249, 463, 166, 505, 84, 547, 2, 590], "rec": [14, 5, 5, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 3198.0, "bbox": [658.0, 416.0, 39.0, 82.0], "category_id": 1, "id": 488, "image_id": 326, "iscrowd": 0, "bezier_pts": [658, 434, 667, 428, 676, 422, 686, 416, 696, 469, 685, 478, 674, 487, 663, 497], "rec": [14, 5, 5, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 6348.0, "bbox": [494.0, 28.0, 92.0, 69.0], "category_id": 1, "id": 489, "image_id": 326, "iscrowd": 0, "bezier_pts": [494, 57, 517, 47, 541, 37, 565, 28, 585, 64, 559, 74, 534, 85, 509, 96], "rec": [33, 26, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 3465.0, "bbox": [579.0, 6.0, 63.0, 55.0], "category_id": 1, "id": 490, "image_id": 326, "iscrowd": 0, "bezier_pts": [579, 26, 595, 19, 612, 12, 629, 6, 641, 40, 623, 46, 606, 53, 589, 60], "rec": [14, 5, 5, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 722.0, "bbox": [607.0, 516.0, 38.0, 19.0], "category_id": 1, "id": 3405, "image_id": 347, "iscrowd": 0, "bezier_pts": [607, 519, 619, 518, 631, 517, 643, 516, 644, 531, 632, 532, 620, 533, 608, 534], "rec": [9, 0, 12, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 2025.0, "bbox": [587.0, 489.0, 75.0, 27.0], "category_id": 1, "id": 3406, "image_id": 347, "iscrowd": 0, "bezier_pts": [587, 497, 611, 494, 635, 491, 660, 489, 661, 506, 636, 509, 611, 512, 587, 515], "rec": [6, 8, 5, 19, 0, 22, 0, 24, 18, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 4386.0, "bbox": [0.0, 276.0, 86.0, 51.0], "category_id": 1, "id": 3407, "image_id": 347, "iscrowd": 0, "bezier_pts": [2, 276, 29, 276, 57, 277, 85, 278, 81, 326, 54, 325, 27, 324, 0, 324], "rec": [23, 19, 17, 4, 12, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 6230.0, "bbox": [815.0, 71.0, 70.0, 89.0], "category_id": 1, "id": 379, "image_id": 988, "iscrowd": 0, "bezier_pts": [815, 122, 836, 105, 858, 88, 880, 71, 884, 115, 863, 129, 843, 144, 823, 159], "rec": [31, 24, 4, 0, 17, 18, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 10465.0, "bbox": [874.0, 2.0, 91.0, 115.0], "category_id": 1, "id": 380, "image_id": 988, "iscrowd": 0, "bezier_pts": [874, 62, 901, 42, 929, 22, 957, 2, 964, 57, 937, 76, 910, 96, 883, 116], "rec": [17, 0, 12, 4, 13, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 12852.0, "bbox": [1095.0, 84.0, 108.0, 119.0], "category_id": 1, "id": 381, "image_id": 988, "iscrowd": 0, "bezier_pts": [1095, 120, 1128, 108, 1161, 96, 1194, 84, 1202, 169, 1167, 180, 1132, 191, 1097, 
202], "rec": [19, 14, 13, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1504.0, "bbox": [590.0, 103.0, 47.0, 32.0], "category_id": 1, "id": 800, "image_id": 833, "iscrowd": 0, "bezier_pts": [592, 103, 606, 105, 621, 107, 636, 110, 633, 134, 618, 131, 604, 129, 590, 127], "rec": [4, 23, 8, 19, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1431.0, "bbox": [537.0, 220.0, 53.0, 27.0], "category_id": 1, "id": 801, "image_id": 833, "iscrowd": 0, "bezier_pts": [539, 220, 555, 221, 572, 223, 589, 225, 587, 246, 570, 244, 553, 242, 537, 241], "rec": [0, 11, 8, 2, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1740.0, "bbox": [355.0, 323.0, 87.0, 20.0], "category_id": 1, "id": 802, "image_id": 833, "iscrowd": 0, "bezier_pts": [356, 324, 384, 323, 412, 323, 441, 323, 440, 341, 411, 341, 383, 341, 355, 342], "rec": [1, 11, 4, 13, 7, 4, 8, 12, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1440.0, "bbox": [350.0, 342.0, 72.0, 20.0], "category_id": 1, "id": 803, "image_id": 833, "iscrowd": 0, "bezier_pts": [351, 345, 374, 344, 397, 343, 421, 342, 420, 357, 396, 358, 373, 359, 350, 361], "rec": [0, 21, 4, 13, 20, 4, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1240.0, "bbox": [614.0, 98.0, 62.0, 20.0], "category_id": 1, "id": 4267, "image_id": 196, "iscrowd": 0, "bezier_pts": [615, 98, 635, 99, 655, 100, 675, 102, 674, 117, 654, 116, 634, 115, 614, 114], "rec": [3, 14, 14, 17, 18, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 2222.0, "bbox": [101.0, 267.0, 101.0, 22.0], "category_id": 1, "id": 4268, "image_id": 196, "iscrowd": 0, "bezier_pts": [101, 267, 134, 267, 167, 267, 200, 267, 201, 288, 168, 288, 135, 288, 102, 288], "rec": [17, 4, 18, 4, 17, 21, 4, 3, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1120.0, "bbox": [462.0, 93.0, 56.0, 20.0], "category_id": 1, "id": 4269, "image_id": 196, "iscrowd": 0, "bezier_pts": [462, 93, 480, 93, 498, 94, 516, 95, 517, 112, 499, 111, 481, 110, 463, 110], "rec": [10, 4, 4, 15, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1064.0, "bbox": [522.0, 96.0, 56.0, 19.0], "category_id": 1, "id": 4270, "image_id": 196, "iscrowd": 0, "bezier_pts": [523, 96, 541, 96, 559, 96, 577, 97, 576, 114, 558, 113, 540, 113, 522, 113], "rec": [2, 11, 4, 0, 17, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1638.0, "bbox": [198.0, 268.0, 78.0, 21.0], "category_id": 1, "id": 4271, "image_id": 196, "iscrowd": 0, "bezier_pts": [198, 270, 223, 269, 248, 268, 274, 268, 275, 287, 250, 287, 225, 287, 200, 288], "rec": [18, 4, 0, 19, 8, 13, 6, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1340.0, "bbox": [897.0, 104.0, 67.0, 20.0], "category_id": 1, "id": 4272, "image_id": 196, "iscrowd": 0, "bezier_pts": [901, 104, 921, 104, 942, 105, 963, 106, 960, 123, 939, 122, 918, 121, 897, 121], "rec": [9, 0, 13, 6, 0, 13, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 1026.0, "bbox": [958.0, 108.0, 57.0, 18.0], "category_id": 1, "id": 4273, "image_id": 196, "iscrowd": 0, "bezier_pts": [961, 108, 978, 109, 996, 110, 1014, 111, 1012, 125, 994, 124, 976, 123, 958, 123], "rec": [1, 4, 17, 3, 8, 17, 8, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 
37, 37, 37, 37, 37, 37, 37]}, {"area": 736.0, "bbox": [1010.0, 110.0, 46.0, 16.0], "category_id": 1, "id": 4274, "image_id": 196, "iscrowd": 0, "bezier_pts": [1011, 110, 1025, 110, 1040, 111, 1055, 112, 1054, 125, 1039, 124, 1024, 123, 1010, 123], "rec": [3, 4, 10, 0, 19, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}, {"area": 731.0, "bbox": [1056.0, 112.0, 43.0, 17.0], "category_id": 1, "id": 4275, "image_id": 196, "iscrowd": 0, "bezier_pts": [1057, 112, 1070, 112, 1084, 113, 1098, 114, 1097, 128, 1083, 127, 1069, 126, 1056, 126], "rec": [15, 8, 13, 19, 20, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37]}], "categories": [{"id": 1, "name": "text", "supercategory": "beverage", "keypoints": ["mean", "xmin", "x2", "x3", "xmax", "ymin", "y2", "y3", "ymax", "cross"]}]} -------------------------------------------------------------------------------- /tools/train_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Detection Training Script. 4 | 5 | This script reads a given config file and runs the training or evaluation. 6 | It is an entry point that is made to train standard models in detectron2. 7 | 8 | In order to let one script support training of many models, 9 | this script contains logic that is specific to these built-in models and therefore 10 | may not be suitable for your own project. 11 | For example, your research project perhaps only needs a single "evaluator". 12 | 13 | Therefore, we recommend using detectron2 as a library and taking 14 | this file as an example of how to use the library. 15 | You may want to write your own script with your datasets and other customizations. 16 | """ 17 | 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | from typing import Any, Dict, List, Set 22 | import torch 23 | import itertools 24 | from torch.nn.parallel import DistributedDataParallel 25 | 26 | import sys 27 | # add the repo root to sys.path so `adet` can be imported when running this script directly 28 | path = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 29 | sys.path.insert(0, path) 30 | 31 | import detectron2.utils.comm as comm 32 | from detectron2.data import MetadataCatalog, build_detection_train_loader, build_detection_test_loader 33 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch 34 | from detectron2.utils.events import EventStorage 35 | from detectron2.evaluation import ( 36 | COCOEvaluator, 37 | COCOPanopticEvaluator, 38 | DatasetEvaluators, 39 | LVISEvaluator, 40 | PascalVOCDetectionEvaluator, 41 | SemSegEvaluator, 42 | verify_results, 43 | ) 44 | from detectron2.solver.build import maybe_add_gradient_clipping 45 | from detectron2.modeling import GeneralizedRCNNWithTTA 46 | from detectron2.utils.logger import setup_logger 47 | 48 | from adet.data.dataset_mapper import DatasetMapperWithBasis 49 | from adet.config import get_cfg 50 | from adet.checkpoint import AdetCheckpointer 51 | from adet.evaluation import TextEvaluator 52 | from adet.modeling import swin, vitae_v2 53 | 54 | 55 | class Trainer(DefaultTrainer): 56 | """ 57 | This is the same Trainer except that we rewrite the 58 | `build_train_loader`/`resume_or_load` methods. 59 | """ 60 | def build_hooks(self): 61 | """ 62 | Replace `DetectionCheckpointer` with `AdetCheckpointer`. 63 | 64 | Build a list of default hooks, including timing, evaluation, 65 | checkpointing, lr scheduling, precise BN, writing events.
66 | """ 67 | ret = super().build_hooks() 68 | for i in range(len(ret)): 69 | if isinstance(ret[i], hooks.PeriodicCheckpointer): 70 | self.checkpointer = AdetCheckpointer( 71 | self.model, 72 | self.cfg.OUTPUT_DIR, 73 | optimizer=self.optimizer, 74 | scheduler=self.scheduler, 75 | ) 76 | ret[i] = hooks.PeriodicCheckpointer(self.checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD) 77 | return ret 78 | 79 | def resume_or_load(self, resume=True): 80 | checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) 81 | if resume and self.checkpointer.has_checkpoint(): 82 | self.start_iter = checkpoint.get("iteration", -1) + 1 83 | 84 | def train_loop(self, start_iter: int, max_iter: int): 85 | """ 86 | Args: 87 | start_iter, max_iter (int): See docs above 88 | """ 89 | logger = logging.getLogger("adet.trainer") 90 | # param = sum(p.numel() for p in self.model.parameters()) 91 | # logger.info(f"Model Params: {param}") 92 | logger.info("Starting training from iteration {}".format(start_iter)) 93 | 94 | self.iter = self.start_iter = start_iter 95 | self.max_iter = max_iter 96 | 97 | with EventStorage(start_iter) as self.storage: 98 | self.before_train() 99 | for self.iter in range(start_iter, max_iter): 100 | self.before_step() 101 | self.run_step() 102 | self.after_step() 103 | self.after_train() 104 | 105 | def train(self): 106 | """ 107 | Run training. 108 | 109 | Returns: 110 | OrderedDict of results, if evaluation is enabled. Otherwise None. 111 | """ 112 | self.train_loop(self.start_iter, self.max_iter) 113 | if hasattr(self, "_last_eval_results") and comm.is_main_process(): 114 | verify_results(self.cfg, self._last_eval_results) 115 | return self._last_eval_results 116 | 117 | @classmethod 118 | def build_train_loader(cls, cfg): 119 | """ 120 | Returns: 121 | iterable 122 | 123 | It calls :func:`detectron2.data.build_detection_train_loader` with a customized 124 | DatasetMapper, which adds categorical labels as a semantic mask. 125 | """ 126 | mapper = DatasetMapperWithBasis(cfg, True) 127 | return build_detection_train_loader(cfg, mapper=mapper) 128 | 129 | @classmethod 130 | def build_test_loader(cls, cfg, dataset_name): 131 | """ 132 | Returns: 133 | iterable 134 | 135 | It now calls :func:`detectron2.data.build_detection_test_loader`. 136 | Overwrite it if you'd like a different data loader. 137 | """ 138 | mapper = DatasetMapperWithBasis(cfg, False) 139 | return build_detection_test_loader(cfg, dataset_name, mapper=mapper) 140 | 141 | @classmethod 142 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 143 | """ 144 | Create evaluator(s) for a given dataset. 145 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 146 | For your own dataset, you can simply create an evaluator manually in your 147 | script and do not have to worry about the hacky if-else logic here. 
148 | """ 149 | if output_folder is None: 150 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 151 | evaluator_list = [] 152 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 153 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 154 | evaluator_list.append( 155 | SemSegEvaluator( 156 | dataset_name, 157 | distributed=True, 158 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 159 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 160 | output_dir=output_folder, 161 | ) 162 | ) 163 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 164 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) 165 | if evaluator_type == "coco_panoptic_seg": 166 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 167 | if evaluator_type == "pascal_voc": 168 | return PascalVOCDetectionEvaluator(dataset_name) 169 | if evaluator_type == "lvis": 170 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 171 | if evaluator_type == "text": 172 | return TextEvaluator(dataset_name, cfg, True, output_folder) 173 | if len(evaluator_list) == 0: 174 | raise NotImplementedError( 175 | "no Evaluator for the dataset {} with the type {}".format( 176 | dataset_name, evaluator_type 177 | ) 178 | ) 179 | if len(evaluator_list) == 1: 180 | return evaluator_list[0] 181 | return DatasetEvaluators(evaluator_list) 182 | 183 | @classmethod 184 | def test_with_TTA(cls, cfg, model): 185 | logger = logging.getLogger("adet.trainer") 186 | # In the end of training, run an evaluation with TTA 187 | # Only support some R-CNN models. 188 | logger.info("Running inference with test-time augmentation ...") 189 | model = GeneralizedRCNNWithTTA(cfg, model) 190 | evaluators = [ 191 | cls.build_evaluator( 192 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 193 | ) 194 | for name in cfg.DATASETS.TEST 195 | ] 196 | res = cls.test(cfg, model, evaluators) 197 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 198 | return res 199 | 200 | @classmethod 201 | def build_optimizer(cls, cfg, model): 202 | def match_name_keywords(n, name_keywords): 203 | out = False 204 | for b in name_keywords: 205 | if b in n: 206 | out = True 207 | break 208 | return out 209 | 210 | params: List[Dict[str, Any]] = [] 211 | memo: Set[torch.nn.parameter.Parameter] = set() 212 | for key, value in model.named_parameters(recurse=True): 213 | if not value.requires_grad: 214 | continue 215 | # Avoid duplicating parameters 216 | if value in memo: 217 | continue 218 | memo.add(value) 219 | lr = cfg.SOLVER.BASE_LR 220 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 221 | 222 | if match_name_keywords(key, cfg.SOLVER.LR_BACKBONE_NAMES): 223 | lr = cfg.SOLVER.LR_BACKBONE 224 | elif match_name_keywords(key, cfg.SOLVER.LR_LINEAR_PROJ_NAMES): 225 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.LR_LINEAR_PROJ_MULT 226 | 227 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 228 | 229 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 230 | # detectron2 doesn't have full model gradient clipping now 231 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 232 | enable = ( 233 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 234 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 235 | and clip_norm_val > 0.0 236 | ) 237 | 238 | class FullModelGradientClippingOptimizer(optim): 239 | def step(self, closure=None): 240 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 241 | 
245 | 246 | optimizer_type = cfg.SOLVER.OPTIMIZER 247 | if optimizer_type == "SGD": 248 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 249 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM 250 | ) 251 | elif optimizer_type == "ADAMW": 252 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 253 | params, cfg.SOLVER.BASE_LR 254 | ) 255 | else: 256 | raise NotImplementedError(f"unsupported optimizer type {optimizer_type}") 257 | if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model": 258 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 259 | return optimizer 260 | 261 | 262 | def setup(args): 263 | """ 264 | Create configs and perform basic setups. 265 | """ 266 | cfg = get_cfg() 267 | cfg.merge_from_file(args.config_file) 268 | cfg.merge_from_list(args.opts) 269 | cfg.freeze() 270 | default_setup(cfg, args) 271 | 272 | rank = comm.get_rank() 273 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=rank, name="adet") 274 | 275 | return cfg 276 | 277 | 278 | def main(args): 279 | cfg = setup(args) 280 | 281 | if args.eval_only: 282 | model = Trainer.build_model(cfg) 283 | AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 284 | cfg.MODEL.WEIGHTS, resume=args.resume 285 | ) 286 | res = Trainer.test(cfg, model)  # see DefaultTrainer.test in detectron2's engine/defaults.py 287 | if comm.is_main_process(): 288 | verify_results(cfg, res) 289 | if cfg.TEST.AUG.ENABLED: 290 | res.update(Trainer.test_with_TTA(cfg, model)) 291 | return res 292 | 293 | """ 294 | If you'd like to do anything fancier than the standard training logic, 295 | consider writing your own training loop or subclassing the trainer. 296 | """ 297 | trainer = Trainer(cfg) 298 | trainer.resume_or_load(resume=args.resume) 299 | if cfg.TEST.AUG.ENABLED: 300 | trainer.register_hooks( 301 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 302 | ) 303 | return trainer.train() 304 | 305 | 306 | if __name__ == "__main__": 307 | args = default_argument_parser().parse_args() 308 | print("Command Line Args:", args) 309 | launch( 310 | main, 311 | args.num_gpus, 312 | num_machines=args.num_machines, 313 | machine_rank=args.machine_rank, 314 | dist_url=args.dist_url, 315 | args=(args,), 316 | ) 317 | --------------------------------------------------------------------------------
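A minimal launch sketch for tools/train_net.py, using the standard detectron2
argument parser it builds on (the config path, GPU count, and override values are
placeholders, not values prescribed by this repo):

    python tools/train_net.py \
        --config-file path/to/config.yaml \
        --num-gpus 4 \
        SOLVER.BASE_LR 0.0001

Evaluation-only runs go through the same entry point with --eval-only plus
MODEL.WEIGHTS path/to/checkpoint.pth appended as a 'KEY VALUE' override.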
296 | """ 297 | trainer = Trainer(cfg) 298 | trainer.resume_or_load(resume=args.resume) 299 | if cfg.TEST.AUG.ENABLED: 300 | trainer.register_hooks( 301 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 302 | ) 303 | return trainer.train() 304 | 305 | 306 | if __name__ == "__main__": 307 | args = default_argument_parser().parse_args() 308 | print("Command Line Args:", args) 309 | launch( 310 | main, 311 | args.num_gpus, 312 | num_machines=args.num_machines, 313 | machine_rank=args.machine_rank, 314 | dist_url=args.dist_url, 315 | args=(args,), 316 | ) 317 | -------------------------------------------------------------------------------- /adet/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | import torch 5 | from detectron2.utils.visualizer import Visualizer,VisImage 6 | import matplotlib.colors as mplc 7 | import matplotlib.font_manager as mfm 8 | import matplotlib as mpl 9 | import matplotlib.figure as mplfigure 10 | import random 11 | from shapely.geometry import LineString 12 | import math 13 | import operator 14 | from functools import reduce 15 | from torch import cat,device 16 | 17 | 18 | class TextVisualizer(Visualizer): 19 | def __init__(self, image, metadata, instance_mode, cfg , with_gt = False): 20 | Visualizer.__init__(self, image, metadata, instance_mode=instance_mode) 21 | self.voc_size = cfg.MODEL.TRANSFORMER.VOC_SIZE 22 | self.use_customer_dictionary = cfg.MODEL.TRANSFORMER.CUSTOM_DICT 23 | if with_gt : 24 | self.output_gt = VisImage(self.img, scale=1.0) 25 | if self.voc_size == 96: 26 | self.CTLABELS = [' ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~'] 27 | elif self.voc_size == 37: 28 | self.CTLABELS = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','0','1','2','3','4','5','6','7','8','9'] 29 | else: 30 | with open(self.use_customer_dictionary, 'rb') as fp: 31 | self.CTLABELS = pickle.load(fp) 32 | # voc_size includes the unknown class, which is not in self.CTABLES 33 | assert(int(self.voc_size - 1) == len(self.CTLABELS)), "voc_size is not matched dictionary size, got {} and {}.".format(int(self.voc_size - 1), len(self.CTLABELS)) 34 | 35 | def draw_instance_predictions(self, predictions): 36 | ctrl_pnts = predictions['ctrl_points'].numpy() 37 | scores = predictions['scores'].tolist() 38 | recs = predictions["recs"] 39 | bd_pts = np.asarray(predictions["bd_points"]) 40 | ctc_scores = predictions['ctc_score'] 41 | self.overlay_instances(ctrl_pnts, recs, bd_pts,scores=scores,ctc_scores = ctc_scores) 42 | 43 | return self.output 44 | def draw_instance_predictions_withGT(self,anno): 45 | self.output = self.output_gt#clear 46 | bd_gt =np.array([instance['boundary'].reshape(25,4) for instance in anno]) 47 | ctrl_gt = np.array([instance['polyline'].reshape(-1) for instance in anno]) 48 | recs_gt = np.array([instance['text'] for instance in anno]) 49 | self.overlay_instances(ctrl_gt, recs_gt, bd_gt) 50 | return self.output 51 | def draw_ts(self, predictions, img_size = None): 52 | if img_size is not None: 53 | img_size = img_size.__reversed__() 54 | ctrl_pnts = 
51 | def draw_ts(self, predictions, img_size=None): 52 | if img_size is not None: 53 | img_size = img_size.__reversed__() 54 | ctrl_pnts = (predictions['ctrl_points'].to(device('cpu')) * img_size).numpy() 55 | bd_pts = np.asarray(predictions['bd_points'].to(device('cpu')) * cat([img_size, img_size])) 56 | recs = predictions['texts'].to(device('cpu')).numpy() 57 | self.overlay_instances_ts(ctrl_pnts, bd_pts, recs) 58 | else: 59 | ctrl_pnts = (predictions['ctrl_points'].to(device('cpu'))).numpy() 60 | bd_pts = np.asarray(predictions['bd_points'].to(device('cpu'))) 61 | recs = predictions['texts'].to(device('cpu')).numpy() 62 | self.overlay_instances_ts(ctrl_pnts, bd_pts, recs) 63 | return self.output 64 | 65 | def draw_ref_points(self, ref_points, img_size): 66 | colors = [(0,0.5,0),(0,0.75,0),(1,0,1),(0.75,0,0.75),(0.5,0,0.5),(1,0,0),(0.75,0,0),(0.5,0,0), 67 | (0,0,1),(0,0,0.75),(0.75,0.25,0.25),(0.75,0.5,0.5),(0,0.75,0.75),(0,0.5,0.5),(0,0.3,0.75)] 68 | img_size = img_size.__reversed__() 69 | ref_points = (ref_points.to(device('cpu')) * img_size).numpy() 70 | for i, ref_point in enumerate(ref_points): 71 | color = random.choice(colors) 72 | line = self._process_ctrl_pnt(ref_point) 73 | self.draw_line( 74 | line[:, 0], 75 | line[:, 1], 76 | color=color, 77 | linewidth=2 78 | ) 79 | for pt in line: 80 | self.draw_circle(pt, 'r', radius=6) 81 | self.draw_text( 82 | f'{i}', 83 | line[-1] + np.array([0, 10]), 84 | color=color, 85 | horizontal_alignment='left', 86 | font_size=self._default_font_size, 87 | draw_chinese=not (self.voc_size == 37 or self.voc_size == 96) 88 | ) 89 | return self.output 90 | 91 | def _process_ctrl_pnt(self, pnt): 92 | points = pnt.reshape(-1, 2) 93 | return points 94 | 95 | def _ctc_decode_recognition(self, rec): 96 | last_char = '###' 97 | s = '' 98 | for c in rec: 99 | c = int(c) 100 | if c < self.voc_size - 1: 101 | if last_char != c: 102 | if self.voc_size == 37 or self.voc_size == 96: 103 | s += self.CTLABELS[c] 104 | last_char = c 105 | else: 106 | s += str(chr(self.CTLABELS[c])) 107 | last_char = c 108 | else: 109 | last_char = '###' 110 | return s 111 | 112 | def overlay_instances(self, ctrl_pnts, recs, bd_pnts, alpha=0.4, scores=None, ctc_scores=None): 113 | colors = [(0,0.5,0),(0,0.75,0),(1,0,1),(0.75,0,0.75),(0.5,0,0.5),(1,0,0),(0.75,0,0),(0.5,0,0), 114 | (0,0,1),(0,0,0.75),(0.75,0.25,0.25),(0.75,0.5,0.5),(0,0.75,0.75),(0,0.5,0.5),(0,0.3,0.75)] 115 | instance_num = ctrl_pnts.shape[0] 116 | scores = [1 for i in range(instance_num)] if scores is None else scores 117 | ctc_scores = [1 for i in range(instance_num)] if ctc_scores is None else ctc_scores 118 | for ctrl_pnt, rec, bd, sc, ctc_sc in zip(ctrl_pnts, recs, bd_pnts, scores, ctc_scores): 119 | color = random.choice(colors) 120 | 121 | # draw polygons 122 | if bd is not None: 123 | bd = np.hsplit(bd, 2) 124 | bd = np.vstack([bd[0], bd[1][::-1]]) 125 | self.draw_polygon(bd, color, alpha=alpha) 126 | 127 | # draw center lines 128 | line = self._process_ctrl_pnt(ctrl_pnt) 129 | line_ = LineString(line) 130 | center_point = np.array(line_.interpolate(0.5, normalized=True).coords[0], dtype=np.int32) 131 | self.draw_line( 132 | line[:, 0], 133 | line[:, 1], 134 | color=color, 135 | linewidth=2 136 | ) 137 | for pt in line: 138 | self.draw_circle(pt, 'w', radius=4) 139 | self.draw_circle(pt, 'r', radius=2) 140 | 141 | # draw text 142 | text = self._ctc_decode_recognition(rec) 143 | if self.voc_size == 37: 144 | text = text.upper()  # uppercase for display 145 | # text = "{}".format(text) 146 | text = f'{text} {ctc_sc:.4f}' 147 | det_text = f'{sc:.4f}' 148 | det_text_pos = center_point 149 | lighter_color = self._change_color_brightness(color, brightness_factor=0) 150 | if bd is not None: 151 | text_pos = bd[0] - np.array([0,15]) 152 | else: 153 | text_pos = center_point 154 | horiz_align = "left" 155 | font_size = self._default_font_size 156 | self.draw_text( 157 | text, 158 | text_pos, 159 | color=lighter_color, 160 | horizontal_alignment=horiz_align, 161 | font_size=font_size, 162 | draw_chinese=not (self.voc_size == 37 or self.voc_size == 96) 163 | ) 164 | self.draw_text( 165 | det_text, 166 | det_text_pos, 167 | color=lighter_color, 168 | horizontal_alignment=horiz_align, 169 | font_size=font_size, 170 | draw_chinese=not (self.voc_size == 37 or self.voc_size == 96) 171 | )
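# A worked sketch of the CTC-style decoding in _ctc_decode_recognition above,
# assuming the 37-symbol vocabulary (indices 0-35 map to 'a'-'z' and '0'-'9';
# index 36 is the unknown/blank class): the sequence [18, 18, 36, 0, 11, 11, 4]
# decodes to "sale" -- the repeated 18 is collapsed, 36 resets the last-char
# tracker, and the second 11 is dropped as a repeat.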
172 | 173 | def overlay_instances_ts(self, ctrl_pnts, bd_pnts, recs, alpha=0.4): 174 | colors = [(0,0.5,0),(0,0.75,0),(1,0,1),(0.75,0,0.75),(0.5,0,0.5),(1,0,0),(0.75,0,0),(0.5,0,0), 175 | (0,0,1),(0,0,0.75),(0.75,0.25,0.25),(0.75,0.5,0.5),(0,0.75,0.75),(0,0.5,0.5),(0,0.3,0.75)] 176 | 177 | for ctrl_pnt, rec, bd in zip(ctrl_pnts, recs, bd_pnts): 178 | color = random.choice(colors) 179 | 180 | # draw polygons 181 | if bd is not None: 182 | bd = np.hsplit(bd, 2) 183 | bd = np.vstack([bd[0], bd[1][::-1]]) 184 | self.draw_polygon(bd, color, alpha=alpha) 185 | 186 | # draw center lines 187 | line = self._process_ctrl_pnt(ctrl_pnt) 188 | line_ = LineString(line) 189 | center_point = np.array(line_.interpolate(0.5, normalized=True).coords[0], dtype=np.int32) 190 | self.draw_line( 191 | line[:, 0], 192 | line[:, 1], 193 | color=color, 194 | linewidth=2 195 | ) 196 | for pt in line: 197 | self.draw_circle(pt, 'w', radius=4) 198 | self.draw_circle(pt, 'r', radius=2) 199 | 200 | # draw text 201 | text = self._ctc_decode_recognition(rec) 202 | if self.voc_size == 37: 203 | text = text.upper()  # uppercase for display 204 | # text = "{:.2f}: {}".format(score, text) 205 | 206 | lighter_color = self._change_color_brightness(color, brightness_factor=0) 207 | if bd is not None: 208 | text_pos = bd[0] - np.array([0,15]) 209 | else: 210 | text_pos = center_point 211 | horiz_align = "left" 212 | font_size = self._default_font_size 213 | self.draw_text( 214 | text, 215 | text_pos, 216 | color=lighter_color, 217 | horizontal_alignment=horiz_align, 218 | font_size=font_size, 219 | draw_chinese=not (self.voc_size == 37 or self.voc_size == 96) 220 | ) 221 | 222 | def draw_text( 223 | self, 224 | text, 225 | position, 226 | *, 227 | font_size=None, 228 | color="g", 229 | horizontal_alignment="center", 230 | rotation=0, 231 | draw_chinese=False 232 | ): 233 | """ 234 | Args: 235 | text (str): class label 236 | position (tuple): a tuple of the x and y coordinates to place text on image. 237 | font_size (int, optional): font size of the text. If not provided, a font size 238 | proportional to the image width is calculated and used. 239 | color: color of the text. Refer to `matplotlib.colors` for the full list 240 | of formats that are accepted. 241 | horizontal_alignment (str): see `matplotlib.text.Text` 242 | rotation: rotation angle in degrees CCW 243 | Returns: 244 | output (VisImage): image object with text drawn.
245 | """ 246 | if not font_size: 247 | font_size = self._default_font_size 248 | 249 | # since the text background is dark, we don't want the text to be dark 250 | color = np.maximum(list(mplc.to_rgb(color)), 0.2) 251 | color[np.argmax(color)] = max(0.8, np.max(color)) 252 | 253 | x, y = position 254 | if draw_chinese: 255 | font_path = "./simsun.ttc" 256 | prop = mfm.FontProperties(fname=font_path) 257 | self.output.ax.text( 258 | x, 259 | y, 260 | text, 261 | size=font_size * self.output.scale, 262 | family="sans-serif", 263 | bbox={"facecolor": "white", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, 264 | verticalalignment="top", 265 | horizontalalignment=horizontal_alignment, 266 | color=color, 267 | zorder=10, 268 | rotation=rotation, 269 | fontproperties=prop 270 | ) 271 | else: 272 | self.output.ax.text( 273 | x, 274 | y, 275 | text, 276 | size=font_size * self.output.scale, 277 | family="sans-serif", 278 | bbox={"facecolor": "white", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, 279 | verticalalignment="top", 280 | horizontalalignment=horizontal_alignment, 281 | color=color, 282 | zorder=10, 283 | rotation=rotation, 284 | ) 285 | return self.output --------------------------------------------------------------------------------