├── Figure ├── Readme.md └── img.png ├── README.md └── Segmentation ├── .gitignore ├── 0.1.1 ├── CITATION.cff ├── Figure └── img.png ├── LICENSE ├── MANIFEST.in ├── Qtrick_architecture ├── __init__.py └── clock_driven │ ├── __init__.py │ ├── ann2snn │ ├── __init__.py │ ├── converter.py │ ├── examples │ │ ├── __init__.py │ │ ├── cnn_mnist.py │ │ └── resnet18_cifar10.py │ ├── modules.py │ └── utils.py │ ├── base.py │ ├── configure.py │ ├── cu_kernel_opt.py │ ├── encoding.py │ ├── examples │ ├── A2C.py │ ├── DQN_state.py │ ├── E_Spikingformer.py │ ├── PPO.py │ ├── Spiking_A2C.py │ ├── Spiking_DQN_state.py │ ├── Spiking_PPO.py │ ├── __init__.py │ ├── cifar10_r11_enabling_spikebased_backpropagation.py │ ├── classify_dvsg.py │ ├── common │ │ ├── __init__.py │ │ └── multiprocessing_env.py │ ├── conv_fashion_mnist.py │ ├── dqn_cart_pole.py │ ├── lif_fc_mnist.py │ ├── rsnn_sequential_fmnist.py │ ├── speechcommands.py │ ├── spiking_lstm_sequential_mnist.py │ └── spiking_lstm_text.py │ ├── functional.py │ ├── lava_exchange.py │ ├── layer.py │ ├── model │ ├── __init__.py │ ├── parametric_lif_net.py │ ├── sew_resnet.py │ ├── spiking_resnet.py │ ├── spiking_vgg.py │ ├── train_classify.py │ └── train_imagenet.py │ ├── monitor.py │ ├── neuron.py │ ├── neuron_kernel.py │ ├── rnn.py │ ├── spike_op.py │ ├── surrogate.py │ └── tensor_cache.py ├── README.md ├── configs ├── FPN │ ├── fpn_sdtv2_512x512_ade20k.py │ └── fpn_sdtv3_512x512_ade20k.py ├── Spike2Former │ ├── SDTv2_Spike2former_voc_512x512.py │ ├── SDTv2_maskformer_DCNPixelDecoder_CityScapes.py │ ├── SDTv2_maskformer_DCNpixelDecoder_ade20k.py │ ├── SDTv2_maskformer_cocostuff10k_512x512.py │ ├── SDTv2_maskformer_cocostuff164k_512x512.py │ ├── SDTv3_b_Spike2former_Cityscapes_512x1024.py │ ├── SDTv3_b_Spike2former_ade20k_512x512.py │ ├── SDTv3_b_Spike2former_voc_512x512.py │ ├── fpn_sdtv3_512x512_10M_ade20k.py │ └── fpn_sdtv3_512x512_19M_ade20k.py └── _base_ │ ├── datasets │ ├── ade20k.py │ ├── ade20k_640x640.py │ ├── chase_db1.py │ ├── cityscapes.py │ ├── cityscapes_1024x1024.py │ ├── cityscapes_768x768.py │ ├── cityscapes_769x769.py │ ├── cityscapes_832x832.py │ ├── coco-stuff10k.py │ ├── coco-stuff164k.py │ ├── ddd17.py │ ├── pascal_context.py │ ├── pascal_context_59.py │ ├── pascal_voc12.py │ ├── pascal_voc12_aug.py │ └── synapse.py │ ├── default_runtime.py │ ├── models │ ├── fpn_snn_r50.py │ └── snn_sdtv2_fpn.py │ └── schedules │ ├── schedule_160k.py │ ├── schedule_320k.py │ └── schedule_80k.py ├── dataset-index.yml ├── docs ├── en │ ├── Makefile │ ├── _static │ │ ├── css │ │ │ └── readthedocs.css │ │ └── images │ │ │ └── mmsegmentation.png │ ├── advanced_guides │ │ ├── add_datasets.md │ │ ├── add_metrics.md │ │ ├── add_models.md │ │ ├── add_transforms.md │ │ ├── customize_runtime.md │ │ ├── data_flow.md │ │ ├── datasets.md │ │ ├── engine.md │ │ ├── evaluation.md │ │ ├── index.rst │ │ ├── models.md │ │ ├── structures.md │ │ ├── training_tricks.md │ │ └── transforms.md │ ├── api.rst │ ├── conf.py │ ├── device │ │ └── npu.md │ ├── get_started.md │ ├── index.rst │ ├── make.bat │ ├── migration │ │ ├── index.rst │ │ ├── interface.md │ │ └── package.md │ ├── model_zoo.md │ ├── modelzoo_statistics.md │ ├── notes │ │ ├── changelog.md │ │ ├── changelog_v0.x.md │ │ └── faq.md │ ├── overview.md │ ├── stat.py │ ├── switch_language.md │ └── user_guides │ │ ├── 1_config.md │ │ ├── 2_dataset_prepare.md │ │ ├── 3_inference.md │ │ ├── 4_train_test.md │ │ ├── 5_deployment.md │ │ ├── index.rst │ │ ├── useful_tools.md │ │ ├── visualization.md │ │ └── 
visualization_feature_map.md └── zh_cn │ ├── Makefile │ ├── _static │ ├── css │ │ └── readthedocs.css │ └── images │ │ └── mmsegmentation.png │ ├── advanced_guides │ ├── add_datasets.md │ ├── add_metrics.md │ ├── add_models.md │ ├── add_transforms.md │ ├── contribute_dataset.md │ ├── customize_runtime.md │ ├── data_flow.md │ ├── datasets.md │ ├── engine.md │ ├── evaluation.md │ ├── index.rst │ ├── models.md │ ├── structures.md │ ├── training_tricks.md │ └── transforms.md │ ├── api.rst │ ├── conf.py │ ├── device │ └── npu.md │ ├── get_started.md │ ├── imgs │ ├── qq_group_qrcode.jpg │ ├── seggroup_qrcode.jpg │ └── zhihu_qrcode.jpg │ ├── index.rst │ ├── make.bat │ ├── migration │ ├── index.rst │ ├── interface.md │ └── package.md │ ├── model_zoo.md │ ├── modelzoo_statistics.md │ ├── notes │ └── faq.md │ ├── overview.md │ ├── stat.py │ ├── switch_language.md │ └── user_guides │ ├── 1_config.md │ ├── 2_dataset_prepare.md │ ├── 3_inference.md │ ├── 4_train_test.md │ ├── 5_deployment.md │ ├── index.rst │ ├── useful_tools.md │ ├── visualization.md │ └── visualization_feature_map.md ├── mmdet ├── __init__.py ├── models │ ├── __init__.py │ ├── dense_heads │ │ ├── __init__.py │ │ ├── anchor_free_head.py │ │ ├── base_dense_head.py │ │ ├── base_mask_head.py │ │ └── maskformer_head.py │ ├── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── brick_wrappers.py │ │ ├── conv_upsample.py │ │ ├── dropblock.py │ │ ├── matrix_nms.py │ │ ├── normed_predictor.py │ │ ├── pixel_decoder.py │ │ ├── positional_encoding.py │ │ └── transformer │ │ │ ├── Spike2former_layers.py │ │ │ ├── __init__.py │ │ │ ├── base_blocks.py │ │ │ ├── dab_detr_layers.py │ │ │ ├── deformable_detr_layers.py │ │ │ ├── detr_layers.py │ │ │ ├── mask2former_layers.py │ │ │ ├── mmcv_spike │ │ │ ├── BASE_Transformer.py │ │ │ ├── CycleMLP.py │ │ │ ├── SNN_core.py │ │ │ ├── __init__.py │ │ │ ├── base_blocks.py │ │ │ ├── drop.py │ │ │ ├── ext_loader.py │ │ │ ├── multi_scale_deform_attn.py │ │ │ ├── scale.py │ │ │ ├── spikeformer.py │ │ │ └── transformer.py │ │ │ ├── ops_dcnv3 │ │ │ ├── __init__.py │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ └── dcnv3_func.py │ │ │ ├── make.sh │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ └── dcnv3.py │ │ │ ├── setup.py │ │ │ ├── src │ │ │ │ ├── cpu │ │ │ │ │ ├── dcnv3_cpu.cpp │ │ │ │ │ └── dcnv3_cpu.h │ │ │ │ ├── cuda │ │ │ │ │ ├── dcnv3_cuda.cu │ │ │ │ │ ├── dcnv3_cuda.h │ │ │ │ │ └── dcnv3_im2col_cuda.cuh │ │ │ │ ├── dcnv3.h │ │ │ │ └── vision.cpp │ │ │ └── test.py │ │ │ └── utils.py │ ├── losses │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── cross_entropy_loss.py │ │ ├── dice_loss.py │ │ ├── focal_loss.py │ │ ├── iou_loss.py │ │ ├── mse_loss.py │ │ ├── multipos_cross_entropy_loss.py │ │ └── utils.py │ ├── task_modules │ │ ├── __init__.py │ │ ├── assigners │ │ │ ├── __init__.py │ │ │ ├── approx_max_iou_assigner.py │ │ │ ├── assign_result.py │ │ │ ├── atss_assigner.py │ │ │ ├── base_assigner.py │ │ │ ├── center_region_assigner.py │ │ │ ├── dynamic_soft_label_assigner.py │ │ │ ├── grid_assigner.py │ │ │ ├── hungarian_assigner.py │ │ │ ├── iou2d_calculator.py │ │ │ ├── match_cost.py │ │ │ ├── max_iou_assigner.py │ │ │ ├── multi_instance_assigner.py │ │ │ ├── point_assigner.py │ │ │ ├── region_assigner.py │ │ │ ├── sim_ota_assigner.py │ │ │ ├── task_aligned_assigner.py │ │ │ └── uniform_assigner.py │ │ ├── builder.py │ │ ├── prior_generators │ │ │ ├── __init__.py │ │ │ ├── anchor_generator.py │ │ │ ├── point_generator.py │ │ │ └── utils.py │ │ └── samplers │ │ │ ├── __init__.py │ │ │ ├── 
base_sampler.py │ │ │ ├── combined_sampler.py │ │ │ ├── instance_balanced_pos_sampler.py │ │ │ ├── iou_balanced_neg_sampler.py │ │ │ ├── mask_pseudo_sampler.py │ │ │ ├── mask_sampling_result.py │ │ │ ├── multi_instance_random_sampler.py │ │ │ ├── multi_instance_sampling_result.py │ │ │ ├── ohem_sampler.py │ │ │ ├── pseudo_sampler.py │ │ │ ├── random_sampler.py │ │ │ ├── sampling_result.py │ │ │ └── score_hlr_sampler.py │ ├── test_time_augs │ │ ├── __init__.py │ │ ├── det_tta.py │ │ └── merge_augs.py │ ├── tracking_heads │ │ ├── __init__.py │ │ └── mask2former_track_head.py │ └── utils │ │ ├── Qtrick.py │ │ ├── __init__.py │ │ ├── gaussian_target.py │ │ ├── image.py │ │ ├── make_divisible.py │ │ ├── misc.py │ │ ├── panoptic_gt_processing.py │ │ ├── point_sample.py │ │ └── vlfuse_helper.py ├── registry.py ├── structures │ ├── __init__.py │ ├── bbox │ │ ├── __init__.py │ │ ├── base_boxes.py │ │ ├── bbox_overlaps.py │ │ ├── box_type.py │ │ ├── horizontal_boxes.py │ │ └── transforms.py │ ├── det_data_sample.py │ ├── mask │ │ ├── __init__.py │ │ ├── mask_target.py │ │ ├── structures.py │ │ └── utils.py │ ├── reid_data_sample.py │ └── track_data_sample.py ├── utils │ ├── __init__.py │ ├── benchmark.py │ ├── collect_env.py │ ├── compat_config.py │ ├── contextmanagers.py │ ├── dist_utils.py │ ├── logger.py │ ├── memory.py │ ├── misc.py │ ├── mot_error_visualize.py │ ├── profiling.py │ ├── replace_cfg_vals.py │ ├── setup_env.py │ ├── split_batch.py │ ├── typing_utils.py │ ├── util_mixins.py │ └── util_random.py └── version.py ├── mmseg ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ └── mmseg_inferencer.py ├── datasets │ ├── __init__.py │ ├── ade.py │ ├── basesegdataset.py │ ├── chase_db1.py │ ├── cityscapes.py │ ├── coco_stuff.py │ ├── dataset_wrappers.py │ ├── ddd17.py │ ├── drive.py │ ├── pascal_context.py │ ├── synapse.py │ ├── transforms │ │ ├── __init__.py │ │ ├── formatting.py │ │ ├── loading.py │ │ └── transforms.py │ └── voc.py ├── engine │ ├── __init__.py │ ├── hooks │ │ ├── __init__.py │ │ ├── cal_firing_rate.py │ │ ├── resetmodel_hook.py │ │ └── visualization_hook.py │ └── optimizers │ │ ├── __init__.py │ │ └── layer_decay_optimizer_constructor.py ├── evaluation │ ├── __init__.py │ └── metrics │ │ ├── __init__.py │ │ ├── citys_metric.py │ │ └── iou_metric.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── MSResnet.py │ │ ├── __init__.py │ │ ├── sdtv2.py │ │ ├── sdtv3.py │ │ ├── sdtv3MAE.py │ │ └── spike-sdtv3.py │ ├── builder.py │ ├── data_preprocessor.py │ ├── decode_heads │ │ ├── __init__.py │ │ ├── anchor_free_head.py │ │ ├── base_dense_head.py │ │ ├── cascade_decode_head.py │ │ ├── decode_head.py │ │ ├── fpn_head.py │ │ ├── mask2former_head.py │ │ └── maskformer_head.py │ ├── losses │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── boundary_loss.py │ │ ├── cross_entropy_loss.py │ │ ├── dice_loss.py │ │ ├── focal_loss.py │ │ ├── huasdorff_distance_loss.py │ │ ├── lovasz_loss.py │ │ ├── ohem_cross_entropy_loss.py │ │ ├── tversky_loss.py │ │ └── utils.py │ ├── necks │ │ ├── __init__.py │ │ └── fpn.py │ ├── segmentors │ │ ├── __init__.py │ │ ├── base.py │ │ ├── cascade_encoder_decoder.py │ │ ├── encoder_decoder.py │ │ └── seg_tta.py │ └── utils │ │ ├── Qtrick.py │ │ ├── __init__.py │ │ ├── basic_block.py │ │ ├── embed.py │ │ ├── encoding.py │ │ ├── inverted_residual.py │ │ ├── make_divisible.py │ │ ├── misc.py │ │ ├── panoptic_gt_processing.py │ │ ├── point_sample.py │ │ ├── ppm.py │ │ ├── res_layer.py │ │ ├── se_layer.py │ │ ├── self_attention_block.py │ │ ├── 
shape_convert.py │ │ ├── up_conv_block.py │ │ └── wrappers.py ├── registry │ ├── __init__.py │ └── registry.py ├── structures │ ├── __init__.py │ ├── sampler │ │ ├── __init__.py │ │ ├── base_pixel_sampler.py │ │ ├── builder.py │ │ └── ohem_pixel_sampler.py │ └── seg_data_sample.py ├── utils │ ├── __init__.py │ ├── class_names.py │ ├── collect_env.py │ ├── io.py │ ├── misc.py │ ├── panoptic_gt_processing.py │ ├── set_env.py │ └── typing_utils.py ├── version.py └── visualization │ ├── __init__.py │ └── local_visualizer.py ├── model-index.yml ├── requirements.txt ├── requirements ├── albu.txt ├── docs.txt ├── mminstall.txt ├── optional.txt ├── readthedocs.txt ├── runtime.txt └── tests.txt ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── test_apis │ └── test_inferencer.py ├── test_config.py ├── test_datasets │ ├── test_dataset.py │ ├── test_dataset_builder.py │ ├── test_formatting.py │ ├── test_loading.py │ ├── test_transform.py │ └── test_tta.py ├── test_digit_version.py ├── test_engine │ ├── test_layer_decay_optimizer_constructor.py │ ├── test_optimizer.py │ └── test_visualization_hook.py ├── test_evaluation │ └── test_metrics │ │ ├── test_citys_metric.py │ │ └── test_iou_metric.py ├── test_models │ ├── __init__.py │ ├── test_backbones │ │ ├── __init__.py │ │ ├── test_beit.py │ │ ├── test_bisenetv1.py │ │ ├── test_bisenetv2.py │ │ ├── test_blocks.py │ │ ├── test_cgnet.py │ │ ├── test_erfnet.py │ │ ├── test_fast_scnn.py │ │ ├── test_hrnet.py │ │ ├── test_icnet.py │ │ ├── test_mae.py │ │ ├── test_mit.py │ │ ├── test_mobilenet_v3.py │ │ ├── test_mscan.py │ │ ├── test_pidnet.py │ │ ├── test_resnest.py │ │ ├── test_resnet.py │ │ ├── test_resnext.py │ │ ├── test_stdc.py │ │ ├── test_swin.py │ │ ├── test_timm_backbone.py │ │ ├── test_twins.py │ │ ├── test_unet.py │ │ ├── test_vit.py │ │ └── utils.py │ ├── test_data_preprocessor.py │ ├── test_forward.py │ ├── test_heads │ │ ├── __init__.py │ │ ├── test_ann_head.py │ │ ├── test_apc_head.py │ │ ├── test_aspp_head.py │ │ ├── test_cc_head.py │ │ ├── test_decode_head.py │ │ ├── test_dm_head.py │ │ ├── test_dnl_head.py │ │ ├── test_dpt_head.py │ │ ├── test_ema_head.py │ │ ├── test_fcn_head.py │ │ ├── test_gc_head.py │ │ ├── test_ham_head.py │ │ ├── test_isa_head.py │ │ ├── test_lraspp_head.py │ │ ├── test_mask2former_head.py │ │ ├── test_maskformer_head.py │ │ ├── test_nl_head.py │ │ ├── test_ocr_head.py │ │ ├── test_pidnet_head.py │ │ ├── test_psa_head.py │ │ ├── test_psp_head.py │ │ ├── test_segformer_head.py │ │ ├── test_segmenter_mask_head.py │ │ ├── test_setr_mla_head.py │ │ ├── test_setr_up_head.py │ │ ├── test_uper_head.py │ │ └── utils.py │ ├── test_losses │ │ ├── test_dice_loss.py │ │ ├── test_huasdorff_distance_loss.py │ │ └── test_tversky_loss.py │ ├── test_necks │ │ ├── __init__.py │ │ ├── test_feature2pyramid.py │ │ ├── test_fpn.py │ │ ├── test_ic_neck.py │ │ ├── test_jpu.py │ │ ├── test_mla_neck.py │ │ └── test_multilevel_neck.py │ ├── test_segmentors │ │ ├── __init__.py │ │ ├── test_cascade_encoder_decoder.py │ │ ├── test_encoder_decoder.py │ │ ├── test_seg_tta_model.py │ │ └── utils.py │ └── test_utils │ │ ├── __init__.py │ │ ├── test_embed.py │ │ └── test_shape_convert.py ├── test_sampler.py ├── test_structures │ └── test_seg_data_sample.py ├── test_utils │ ├── test_io.py │ └── test_set_env.py └── test_visualization │ └── test_local_visualizer.py └── tools ├── Calculation_tools.py ├── analysis_tools ├── analyze_logs.py ├── benchmark.py ├── browse_dataset.py ├── confusion_matrix.py ├── get_flops.py ├── profile.py └── utils.py ├── 
cal_firing_num.py ├── dataset_converters ├── chase_db1.py ├── cityscapes.py ├── coco_stuff10k.py ├── coco_stuff164k.py ├── drive.py ├── hrf.py ├── isaid.py ├── levircd.py ├── loveda.py ├── pascal_context.py ├── potsdam.py ├── pro_gen1.py ├── prophesee │ ├── __init__.py │ ├── io │ │ ├── __init__.py │ │ ├── box_filtering.py │ │ ├── box_loading.py │ │ ├── dat_events_tools.py │ │ ├── npy_events_tools.py │ │ └── psee_loader.py │ ├── metrics │ │ ├── __init__.py │ │ └── coco_eval.py │ ├── psee_evaluator.py │ └── visualize │ │ ├── __init__.py │ │ └── vis_utils.py ├── refuge.py ├── stare.py ├── synapse.py ├── vaihingen.py └── voc_aug.py ├── deployment └── pytorch2torchscript.py ├── dist_test.sh ├── dist_train.sh ├── firing_utils ├── dataset.py └── misc.py ├── misc ├── browse_dataset.py ├── print_config.py └── publish_model.py ├── model_converters ├── beit2mmseg.py ├── mit2mmseg.py ├── stdc2mmseg.py ├── swin2mmseg.py ├── twins2mmseg.py ├── vit2mmseg.py └── vitjax2mmseg.py ├── slurm_test.sh ├── slurm_train.sh ├── test.py ├── test.sh └── train.py /Figure/Readme.md: -------------------------------------------------------------------------------- 1 | Put Figure Here 2 | -------------------------------------------------------------------------------- /Figure/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Figure/img.png -------------------------------------------------------------------------------- /Segmentation/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/en/_build/
68 | docs/zh_cn/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 | 
76 | # pyenv
77 | .python-version
78 | 
79 | # celery beat schedule file
80 | celerybeat-schedule
81 | 
82 | # SageMath parsed files
83 | *.sage.py
84 | 
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | .DS_Store
94 | 
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 | 
99 | # Rope project settings
100 | .ropeproject
101 | 
102 | # mkdocs documentation
103 | /site
104 | 
105 | # mypy
106 | .mypy_cache/
107 | 
108 | data
109 | .vscode
110 | .idea
111 | 
112 | # custom
113 | *.pkl
114 | *.pkl.json
115 | *.log.json
116 | work_dirs/
117 | mmseg/.mim
118 | 
119 | # Pytorch
120 | *.pth
121 | 
--------------------------------------------------------------------------------
/Segmentation/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 |   - name: "MMSegmentation Contributors"
5 | title: "OpenMMLab Semantic Segmentation Toolbox and Benchmark"
6 | date-released: 2020-07-10
7 | url: "https://github.com/open-mmlab/mmsegmentation"
8 | license: Apache-2.0
9 | 
--------------------------------------------------------------------------------
/Segmentation/Figure/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Figure/img.png
--------------------------------------------------------------------------------
/Segmentation/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements/*.txt
2 | include mmseg/.mim/model-index.yml
3 | include mmseg/.mim/dataset-index.yml
4 | recursive-include mmseg/.mim/configs *.py *.yaml
5 | recursive-include mmseg/.mim/tools *.py *.sh
6 | 
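The MANIFEST.in rules above decide what ships in a source distribution. A quick, self-contained way to check them is to build an sdist and list the packaged `.mim` payload; a minimal sketch, assuming the `build` package is installed (this script is illustrative, not part of the repository):

import glob
import subprocess
import tarfile

# Build an sdist into dist/ and show a few of the .mim files the rules pull in.
subprocess.run(['python', '-m', 'build', '--sdist'], check=True)
with tarfile.open(sorted(glob.glob('dist/*.tar.gz'))[-1]) as sdist:
    print([name for name in sdist.getnames() if '.mim' in name][:10])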
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/__init__.py:
--------------------------------------------------------------------------------
1 | from . import clock_driven
2 | 
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/ann2snn/__init__.py:
--------------------------------------------------------------------------------
1 | # Import from the local vendored copy rather than the external spikingjelly package.
2 | from .converter import Converter
3 | from .utils import download_url
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/ann2snn/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/ann2snn/examples/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/ann2snn/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import os
3 | from tqdm import tqdm
4 | 
5 | def download_url(url, dst):
6 |     headers = {
7 |         'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:67.0) Gecko/20100101 Firefox/67.0'
8 |     }
9 | 
10 |     response = requests.get(url, headers=headers, stream=True)  # (1)
11 |     file_size = int(response.headers['content-length'])  # (2)
12 |     if os.path.exists(dst):
13 |         first_byte = os.path.getsize(dst)  # (3)
14 |     else:
15 |         first_byte = 0
16 |     if first_byte >= file_size:  # (4)
17 |         return file_size
18 | 
19 |     header = {"Range": f"bytes={first_byte}-{file_size - 1}"}
20 | 
21 |     pbar = tqdm(total=file_size, initial=first_byte, unit='B', unit_scale=True, desc=dst)
22 |     req = requests.get(url, headers=header, stream=True)  # (5)
23 |     with open(dst, 'ab') as f:
24 |         for chunk in req.iter_content(chunk_size=1024):  # (6)
25 |             if chunk:
26 |                 f.write(chunk)
27 |                 pbar.update(len(chunk))  # advance by the bytes actually written
28 |     pbar.close()
29 |     return file_size
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/examples/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/examples/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/examples/common/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/model/__init__.py
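The `download_url` helper above resumes interrupted downloads: it sizes the file with a first request, then issues an HTTP Range request for the missing tail and appends to the destination. A minimal usage sketch (the URL and path are placeholders, not assets of this repository):

from Qtrick_architecture.clock_driven.ann2snn.utils import download_url

# Re-running after an interruption fetches only the remaining bytes.
size = download_url('https://example.com/pretrained.pth', './pretrained.pth')
print(f'{size} bytes on disk')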
-------------------------------------------------------------------------------- /Segmentation/configs/_base_/datasets/cityscapes_1024x1024.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | crop_size = (1024, 1024) 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile'), 5 | dict(type='LoadAnnotations'), 6 | dict( 7 | type='RandomResize', 8 | scale=(2048, 1024), 9 | ratio_range=(0.5, 2.0), 10 | keep_ratio=True), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='PackSegInputs') 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict(type='Resize', scale=(2048, 1024), keep_ratio=True), 19 | # add loading annotation after ``Resize`` because ground truth 20 | # does not need to do resize data transform 21 | dict(type='LoadAnnotations'), 22 | dict(type='PackSegInputs') 23 | ] 24 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) 25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) 26 | test_dataloader = val_dataloader 27 | 28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 29 | test_evaluator = val_evaluator 30 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/datasets/cityscapes_768x768.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | crop_size = (768, 768) 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile'), 5 | dict(type='LoadAnnotations'), 6 | dict( 7 | type='RandomResize', 8 | scale=(2049, 1025), 9 | ratio_range=(0.5, 2.0), 10 | keep_ratio=True), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='PackSegInputs') 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict(type='Resize', scale=(2049, 1025), keep_ratio=True), 19 | # add loading annotation after ``Resize`` because ground truth 20 | # does not need to do resize data transform 21 | dict(type='LoadAnnotations'), 22 | dict(type='PackSegInputs') 23 | ] 24 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) 25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) 26 | test_dataloader = val_dataloader 27 | 28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 29 | test_evaluator = val_evaluator 30 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/datasets/cityscapes_769x769.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | crop_size = (769, 769) 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile'), 5 | dict(type='LoadAnnotations'), 6 | dict( 7 | type='RandomResize', 8 | scale=(2049, 1025), 9 | ratio_range=(0.5, 2.0), 10 | keep_ratio=True), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='PackSegInputs') 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict(type='Resize', scale=(2049, 1025), keep_ratio=True), 19 | # add loading annotation after ``Resize`` because ground truth 20 | # does not need to do resize data transform 21 | dict(type='LoadAnnotations'), 22 | dict(type='PackSegInputs') 23 | ] 24 | 
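# Note: the odd (2049, 1025) scale and 769x769 crop follow the 2^n + 1 sizing
# convention commonly paired with align_corners=True decoders; the 1024x1024
# and 832x832 variants use the even (2048, 1024) scale instead.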
train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) 25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) 26 | test_dataloader = val_dataloader 27 | 28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 29 | test_evaluator = val_evaluator 30 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/datasets/cityscapes_832x832.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | crop_size = (832, 832) 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile'), 5 | dict(type='LoadAnnotations'), 6 | dict( 7 | type='RandomResize', 8 | scale=(2048, 1024), 9 | ratio_range=(0.5, 2.0), 10 | keep_ratio=True), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='PackSegInputs') 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict(type='Resize', scale=(2048, 1024), keep_ratio=True), 19 | # add loading annotation after ``Resize`` because ground truth 20 | # does not need to do resize data transform 21 | dict(type='LoadAnnotations'), 22 | dict(type='PackSegInputs') 23 | ] 24 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) 25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) 26 | test_dataloader = val_dataloader 27 | 28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 29 | test_evaluator = val_evaluator 30 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/datasets/ddd17.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'DDD17Dataset' 3 | 4 | data_root = '/public/liguoqi/lzx/data/ddd17_seg/T4' 5 | 6 | crop_size = (200, 352) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromNpyFile'), 9 | dict(type='LoadAnnotations'), 10 | # dict(type='RandomCrop', crop_size=crop_size), 11 | dict(type='PackSegInputs') 12 | ] 13 | test_pipeline = [ 14 | dict(type='LoadImageFromNpyFile'), 15 | # add loading annotation after ``Resize`` because ground truth 16 | # does not need to do resize data transform 17 | dict(type='LoadAnnotations'), 18 | # dict(type='RandomCrop', crop_size=crop_size), 19 | dict(type='PackSegInputs') 20 | ] 21 | img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] 22 | # tta_pipeline = [ 23 | # dict(type='LoadImageFromFile', backend_args=None), 24 | # dict( 25 | # type='TestTimeAug', 26 | # transforms=[ 27 | # [ 28 | # dict(type='Resize', scale_factor=r, keep_ratio=True) 29 | # for r in img_ratios 30 | # ], 31 | # [ 32 | # dict(type='RandomFlip', prob=0., direction='horizontal'), 33 | # dict(type='RandomFlip', prob=1., direction='horizontal') 34 | # ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] 35 | # ]) 36 | # ] 37 | train_dataloader = dict( 38 | batch_size=6, 39 | num_workers=16, 40 | persistent_workers=True, 41 | sampler=dict(type='InfiniteSampler', shuffle=True), 42 | dataset=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | data_prefix=dict( 46 | img_path='images/training', seg_map_path='annotations/training'), 47 | pipeline=train_pipeline)) 48 | val_dataloader = dict( 49 | batch_size=1, 50 | num_workers=16, 51 | persistent_workers=True, 52 | sampler=dict(type='DefaultSampler', shuffle=False), 53 | dataset=dict( 54 | type=dataset_type, 55 | data_root=data_root, 56 | data_prefix=dict( 57 | 
img_path='images/validation', 58 | seg_map_path='annotations/validation'), 59 | pipeline=test_pipeline)) 60 | test_dataloader = val_dataloader 61 | 62 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 63 | test_evaluator = val_evaluator 64 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/datasets/pascal_context.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'PascalContextDataset' 3 | data_root = 'data/VOCdevkit/VOC2010/' 4 | 5 | img_scale = (520, 520) 6 | crop_size = (480, 480) 7 | 8 | train_pipeline = [ 9 | dict(type='LoadImageFromFile'), 10 | dict(type='LoadAnnotations'), 11 | dict( 12 | type='RandomResize', 13 | scale=img_scale, 14 | ratio_range=(0.5, 2.0), 15 | keep_ratio=True), 16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 17 | dict(type='RandomFlip', prob=0.5), 18 | dict(type='PhotoMetricDistortion'), 19 | dict(type='PackSegInputs') 20 | ] 21 | test_pipeline = [ 22 | dict(type='LoadImageFromFile'), 23 | dict(type='Resize', scale=img_scale, keep_ratio=True), 24 | # add loading annotation after ``Resize`` because ground truth 25 | # does not need to do resize data transform 26 | dict(type='LoadAnnotations'), 27 | dict(type='PackSegInputs') 28 | ] 29 | train_dataloader = dict( 30 | batch_size=4, 31 | num_workers=4, 32 | persistent_workers=True, 33 | sampler=dict(type='InfiniteSampler', shuffle=True), 34 | dataset=dict( 35 | type=dataset_type, 36 | data_root=data_root, 37 | data_prefix=dict( 38 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'), 39 | ann_file='ImageSets/SegmentationContext/train.txt', 40 | pipeline=train_pipeline)) 41 | val_dataloader = dict( 42 | batch_size=1, 43 | num_workers=4, 44 | persistent_workers=True, 45 | sampler=dict(type='DefaultSampler', shuffle=False), 46 | dataset=dict( 47 | type=dataset_type, 48 | data_root=data_root, 49 | data_prefix=dict( 50 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'), 51 | ann_file='ImageSets/SegmentationContext/val.txt', 52 | pipeline=test_pipeline)) 53 | test_dataloader = val_dataloader 54 | 55 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 56 | test_evaluator = val_evaluator 57 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/datasets/synapse.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'SynapseDataset' 2 | data_root = 'data/synapse/' 3 | img_scale = (224, 224) 4 | train_pipeline = [ 5 | dict(type='LoadImageFromFile'), 6 | dict(type='LoadAnnotations'), 7 | dict(type='Resize', scale=img_scale, keep_ratio=True), 8 | dict(type='RandomRotFlip', rotate_prob=0.5, flip_prob=0.5, degree=20), 9 | dict(type='PackSegInputs') 10 | ] 11 | test_pipeline = [ 12 | dict(type='LoadImageFromFile'), 13 | dict(type='Resize', scale=img_scale, keep_ratio=True), 14 | dict(type='LoadAnnotations'), 15 | dict(type='PackSegInputs') 16 | ] 17 | train_dataloader = dict( 18 | batch_size=6, 19 | num_workers=2, 20 | persistent_workers=True, 21 | sampler=dict(type='InfiniteSampler', shuffle=True), 22 | dataset=dict( 23 | type=dataset_type, 24 | data_root=data_root, 25 | data_prefix=dict( 26 | img_path='img_dir/train', seg_map_path='ann_dir/train'), 27 | pipeline=train_pipeline)) 28 | val_dataloader = dict( 29 | batch_size=1, 30 | num_workers=4, 31 | persistent_workers=True, 32 | sampler=dict(type='DefaultSampler', 
shuffle=False), 33 | dataset=dict( 34 | type=dataset_type, 35 | data_root=data_root, 36 | data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'), 37 | pipeline=test_pipeline)) 38 | test_dataloader = val_dataloader 39 | 40 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice']) 41 | test_evaluator = val_evaluator 42 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | default_scope = 'mmseg' 2 | env_cfg = dict( 3 | cudnn_benchmark=True, 4 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 5 | dist_cfg=dict(backend='nccl'), 6 | ) 7 | vis_backends = [dict(type='LocalVisBackend')] 8 | 9 | # vis_backends = [] 10 | visualizer = dict( 11 | type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') 12 | 13 | 14 | log_processor = dict(by_epoch=False) 15 | log_level = 'INFO' 16 | load_from = None 17 | resume = False 18 | custom_imports = dict(imports=['mmseg.engine.hooks.resetmodel_hook'], allow_failed_imports=False) 19 | custom_hooks = [ 20 | dict(type='ResetModelHook') 21 | ] 22 | 23 | tta_model = dict(type='SegTTAModel') 24 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/models/fpn_snn_r50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | data_preprocessor = dict( 4 | type='SegDataPreProcessor', 5 | mean=[123.675, 116.28, 103.53], 6 | std=[58.395, 57.12, 57.375], 7 | bgr_to_rgb=True, 8 | pad_val=0, 9 | seg_pad_val=255) 10 | model = dict( 11 | type='EncoderDecoder', 12 | data_preprocessor=data_preprocessor, 13 | # pretrained='open-mmlab://resnet50_v1c', 14 | backbone=dict( 15 | type='ResNetV1c', 16 | ), 17 | neck=dict( 18 | type='QFPN', 19 | in_channels=[256, 512, 1024, 2048], 20 | out_channels=256, 21 | num_outs=4, 22 | T=4, 23 | ), 24 | decode_head=dict( 25 | type='QFPNHead', 26 | in_channels=[256, 256, 256, 256], 27 | in_index=[0, 1, 2, 3], 28 | feature_strides=[4, 8, 16, 32], 29 | channels=128, 30 | dropout_ratio=0.1, 31 | num_classes=150, 32 | T=4, 33 | norm_cfg=norm_cfg, 34 | align_corners=False, 35 | loss_decode=dict( 36 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 37 | # model training and testing settings 38 | train_cfg=dict(), 39 | test_cfg=dict(mode='whole')) 40 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/models/snn_sdtv2_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | data_preprocessor = dict( 4 | type='SegDataPreProcessor', 5 | mean=[123.675, 116.28, 103.53], 6 | std=[58.395, 57.12, 57.375], 7 | bgr_to_rgb=True, 8 | pad_val=0, 9 | seg_pad_val=255) 10 | model = dict( 11 | type='EncoderDecoder', 12 | data_preprocessor=data_preprocessor, 13 | pretrained='/raid/ligq/lzx/spikeformerv2/seg/checkpoint/checkpoint-199.pth', 14 | backbone=dict( 15 | # init_cfg=dict(type='Pretrained', checkpoint="/raid/ligq/lzx/spikeformerv2/seg/checkpoint/checkpoint-199.pth"), 16 | type='Sdtv2', 17 | img_size_h=512, 18 | img_size_w=512, 19 | embed_dim=[128, 256, 512, 640], 20 | num_classes=150, 21 | T=1, 22 | qkv_bias=False, 23 | decode_mode='snn', 24 | ), 25 | neck=dict( 26 | type='FPN_SNN', 27 | in_channels=[256, 512, 
1024, 2048], 28 | out_channels=256, 29 | num_outs=4), 30 | decode_head=dict( 31 | type='FPNHead_SNN', 32 | in_channels=[256, 256, 256, 256], 33 | in_index=[0, 1, 2, 3], 34 | feature_strides=[4, 8, 16, 32], 35 | channels=128, 36 | dropout_ratio=0.1, 37 | num_classes=150, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/schedules/schedule_160k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None) 4 | # learning policy 5 | param_scheduler = [ 6 | dict( 7 | type='PolyLR', 8 | eta_min=1e-5, 9 | power=0.9, 10 | begin=0, 11 | end=160000, 12 | by_epoch=False) 13 | ] 14 | # training schedule for 160k 15 | train_cfg = dict( 16 | type='IterBasedTrainLoop', max_iters=160000, val_interval=16000) 17 | val_cfg = dict(type='ValLoop') 18 | test_cfg = dict(type='TestLoop') 19 | default_hooks = dict( 20 | timer=dict(type='IterTimerHook'), 21 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), 22 | param_scheduler=dict(type='ParamSchedulerHook'), 23 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=16000), 24 | sampler_seed=dict(type='DistSamplerSeedHook'), 25 | visualization=dict(type='SegVisualizationHook', draw=True, interval=1)) 26 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/schedules/schedule_320k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None) 4 | # learning policy 5 | param_scheduler = [ 6 | dict( 7 | type='PolyLR', 8 | eta_min=1e-4, 9 | power=0.9, 10 | begin=0, 11 | end=320000, 12 | by_epoch=False) 13 | ] 14 | # training schedule for 320k 15 | train_cfg = dict( 16 | type='IterBasedTrainLoop', max_iters=320000, val_interval=32000) 17 | val_cfg = dict(type='ValLoop') 18 | test_cfg = dict(type='TestLoop') 19 | default_hooks = dict( 20 | timer=dict(type='IterTimerHook'), 21 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), 22 | param_scheduler=dict(type='ParamSchedulerHook'), 23 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=32000), 24 | sampler_seed=dict(type='DistSamplerSeedHook'), 25 | visualization=dict(type='SegVisualizationHook')) 26 | -------------------------------------------------------------------------------- /Segmentation/configs/_base_/schedules/schedule_80k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None) 4 | # learning policy 5 | param_scheduler = [ 6 | dict( 7 | type='PolyLR', 8 | eta_min=1e-4, 9 | power=0.9, 10 | begin=0, 11 | end=80000, 12 | by_epoch=False) 13 | ] 14 | # training schedule for 80k 15 | train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000) 16 | val_cfg = dict(type='ValLoop') 17 | 
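# An 80k-iteration schedule: PolyLR decays the 0.01 base lr toward eta_min=1e-4
# across all 80000 iterations, and validation runs every 8000 iterations, the
# same cadence at which the CheckpointHook below saves weights.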
test_cfg = dict(type='TestLoop') 18 | default_hooks = dict( 19 | timer=dict(type='IterTimerHook'), 20 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), 21 | param_scheduler=dict(type='ParamSchedulerHook'), 22 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000), 23 | sampler_seed=dict(type='DistSamplerSeedHook'), 24 | visualization=dict(type='SegVisualizationHook')) 25 | -------------------------------------------------------------------------------- /Segmentation/dataset-index.yml: -------------------------------------------------------------------------------- 1 | ade20k: 2 | dataset: ADE20K_2016 3 | download_root: data 4 | data_root: data/ade 5 | 6 | cityscapes: 7 | dataset: CityScapes 8 | download_root: data 9 | data_root: data/cityscapes 10 | 11 | voc2012: 12 | dataset: PASCAL_VOC2012 13 | download_root: data 14 | data_root: data/VOCdevkit/VOC2012 15 | 16 | cocostuff: 17 | dataset: COCO-Stuff 18 | download_root: data 19 | data_root: data/coco_stuff164k 20 | 21 | mapillary: 22 | dataset: Mapillary 23 | download_root: data 24 | data_root: data/mapillary 25 | 26 | pascal_context: 27 | dataset: VOC2010 28 | download_root: data 29 | data_root: data/VOCdevkit/VOC2010 30 | 31 | isaid: 32 | dataset: iSAID 33 | download_root: data 34 | data_root: data/iSAID 35 | 36 | isprs_potsdam: 37 | dataset: ISPRS_Potsdam 38 | download_root: data 39 | data_root: data/potsdam 40 | 41 | loveda: 42 | dataset: LoveDA 43 | download_root: data 44 | data_root: data/loveDA 45 | 46 | chase_db1: 47 | dataset: CHASE_DB1 48 | download_root: data 49 | data_root: data/CHASE_DB1 50 | 51 | drive: 52 | dataset: DRIVE 53 | download_root: data 54 | data_root: data/DRIVE 55 | 56 | hrf: 57 | dataset: HRF 58 | download_root: data 59 | data_root: data/HRF 60 | 61 | stare: 62 | dataset: STARE 63 | download_root: data 64 | data_root: data/STARE 65 | 66 | synapse: 67 | dataset: SurgVisDom 68 | download_root: data 69 | data_root: data/synapse 70 | 71 | refuge: 72 | dataset: REFUGE_Challenge 73 | download_root: data 74 | data_root: data/REFUGE 75 | 76 | lip: 77 | dataset: LIP 78 | download_root: data 79 | data_root: data/LIP 80 | -------------------------------------------------------------------------------- /Segmentation/docs/en/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
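# For example, `make html` (with Sphinx installed) falls through to this
# catch-all target and writes the rendered docs to _build/html.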
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 | 
--------------------------------------------------------------------------------
/Segmentation/docs/en/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 |     background-image: url("../images/mmsegmentation.png");
3 |     background-size: 201px 40px;
4 |     height: 40px;
5 |     width: 201px;
6 | }
7 | 
--------------------------------------------------------------------------------
/Segmentation/docs/en/_static/images/mmsegmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/en/_static/images/mmsegmentation.png
--------------------------------------------------------------------------------
/Segmentation/docs/en/advanced_guides/add_transforms.md:
--------------------------------------------------------------------------------
1 | # Adding New Data Transforms
2 | 
3 | ## Customizing data transforms
4 | 
5 | A customized data transform must inherit from `BaseTransform` and implement the `transform` function.
6 | Here we use a simple flipping transform as an example:
7 | 
8 | ```python
9 | import random
10 | import mmcv
11 | from mmcv.transforms import BaseTransform, TRANSFORMS
12 | 
13 | @TRANSFORMS.register_module()
14 | class MyFlip(BaseTransform):
15 |     def __init__(self, direction: str):
16 |         super().__init__()
17 |         self.direction = direction
18 | 
19 |     def transform(self, results: dict) -> dict:
20 |         img = results['img']
21 |         results['img'] = mmcv.imflip(img, direction=self.direction)
22 |         return results
23 | ```
24 | 
25 | Then, import the new class.
26 | 
27 | ```python
28 | from .my_pipeline import MyFlip
29 | ```
30 | 
31 | Now we can instantiate a `MyFlip` object and use it to process the data dict.
32 | 
33 | ```python
34 | import numpy as np
35 | 
36 | transform = MyFlip(direction='horizontal')
37 | data_dict = {'img': np.random.rand(224, 224, 3)}
38 | data_dict = transform(data_dict)
39 | processed_img = data_dict['img']
40 | ```
41 | 
42 | Alternatively, we can use the `MyFlip` transform in the data pipeline in our config file.
43 | 
44 | ```python
45 | pipeline = [
46 |     ...
47 |     dict(type='MyFlip', direction='horizontal'),
48 |     ...
49 | ]
50 | ```
51 | 
52 | Note that if you want to use `MyFlip` in a config, you must ensure the file containing `MyFlip` is imported at runtime.
53 | 
--------------------------------------------------------------------------------
/Segmentation/docs/en/advanced_guides/index.rst:
--------------------------------------------------------------------------------
1 | Basic Concepts
2 | ***************
3 | 
4 | .. toctree::
5 |    :maxdepth: 1
6 | 
7 |    data_flow.md
8 |    structures.md
9 |    models.md
10 |    datasets.md
11 |    transforms.md
12 |    evaluation.md
13 |    engine.md
14 |    training_tricks.md
15 | 
16 | Component Customization
17 | ************************
18 | 
19 | .. toctree::
20 |    :maxdepth: 1
21 | 
22 |    add_models.md
23 |    add_datasets.md
24 |    add_transforms.md
25 |    add_metrics.md
26 |    customize_runtime.md
27 | 
--------------------------------------------------------------------------------
/Segmentation/docs/en/api.rst:
--------------------------------------------------------------------------------
1 | mmseg.apis
2 | --------------
3 | .. 
automodule:: mmseg.apis 4 | :members: 5 | 6 | mmseg.datasets 7 | -------------- 8 | 9 | datasets 10 | ^^^^^^^^^^ 11 | .. automodule:: mmseg.datasets 12 | :members: 13 | 14 | transforms 15 | ^^^^^^^^^^^^ 16 | .. automodule:: mmseg.datasets.transforms 17 | :members: 18 | 19 | mmseg.engine 20 | -------------- 21 | 22 | hooks 23 | ^^^^^^^^^^ 24 | .. automodule:: mmseg.engine.hooks 25 | :members: 26 | 27 | optimizers 28 | ^^^^^^^^^^^^^^^ 29 | .. automodule:: mmseg.engine.optimizers 30 | :members: 31 | 32 | mmseg.evaluation 33 | ----------------- 34 | 35 | metrics 36 | ^^^^^^^^^^ 37 | .. automodule:: mmseg.evaluation.metrics 38 | :members: 39 | 40 | mmseg.models 41 | -------------- 42 | 43 | backbones 44 | ^^^^^^^^^^^^^^^^^^ 45 | .. automodule:: mmseg.models.backbones 46 | :members: 47 | 48 | decode_heads 49 | ^^^^^^^^^^^^^^^ 50 | .. automodule:: mmseg.models.decode_heads 51 | :members: 52 | 53 | segmentors 54 | ^^^^^^^^^^ 55 | .. automodule:: mmseg.models.segmentors 56 | :members: 57 | 58 | losses 59 | ^^^^^^^^^^ 60 | .. automodule:: mmseg.models.losses 61 | :members: 62 | 63 | necks 64 | ^^^^^^^^^^^^ 65 | .. automodule:: mmseg.models.necks 66 | :members: 67 | 68 | utils 69 | ^^^^^^^^^^ 70 | .. automodule:: mmseg.models.utils 71 | :members: 72 | 73 | 74 | mmseg.structures 75 | -------------------- 76 | 77 | structures 78 | ^^^^^^^^^^^^^^^^^ 79 | .. automodule:: mmseg.structures 80 | :members: 81 | 82 | sampler 83 | ^^^^^^^^^^ 84 | .. automodule:: mmseg.structures.sampler 85 | :members: 86 | 87 | mmseg.visualization 88 | -------------------- 89 | .. automodule:: mmseg.visualization 90 | :members: 91 | 92 | mmseg.utils 93 | -------------- 94 | .. automodule:: mmseg.utils 95 | :members: 96 | -------------------------------------------------------------------------------- /Segmentation/docs/en/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to MMSegmentation's documentation! 2 | =========================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Get Started 7 | 8 | overview.md 9 | get_started.md 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: User Guides 14 | 15 | user_guides/index.rst 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Advanced Guides 20 | 21 | advanced_guides/index.rst 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | :caption: Migration 26 | 27 | migration/index.rst 28 | 29 | .. toctree:: 30 | :caption: API Reference 31 | 32 | api.rst 33 | 34 | .. toctree:: 35 | :maxdepth: 1 36 | :caption: Model Zoo 37 | 38 | model_zoo.md 39 | modelzoo_statistics.md 40 | 41 | .. toctree:: 42 | :maxdepth: 1 43 | :caption: Notes 44 | 45 | notes/changelog.md 46 | notes/faq.md 47 | 48 | .. toctree:: 49 | :caption: Device Support 50 | 51 | device/npu.md 52 | 53 | .. toctree:: 54 | :caption: Switch Language 55 | 56 | switch_language.md 57 | 58 | 59 | Indices and tables 60 | ================== 61 | 62 | * :ref:`genindex` 63 | * :ref:`search` 64 | -------------------------------------------------------------------------------- /Segmentation/docs/en/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 | 
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 | 
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 | 
34 | :end
35 | popd
36 | 
--------------------------------------------------------------------------------
/Segmentation/docs/en/migration/index.rst:
--------------------------------------------------------------------------------
1 | Migration
2 | ***************
3 | 
4 | .. toctree::
5 |    :maxdepth: 1
6 | 
7 |    interface.md
8 |    package.md
9 | 
--------------------------------------------------------------------------------
/Segmentation/docs/en/stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | import functools as func
4 | import glob
5 | import os.path as osp
6 | import re
7 | 
8 | import numpy as np
9 | 
10 | url_prefix = 'https://github.com/open-mmlab/mmsegmentation/blob/master/'
11 | 
12 | files = sorted(glob.glob('../../configs/*/README.md'))
13 | 
14 | stats = []
15 | titles = []
16 | num_ckpts = 0
17 | 
18 | for f in files:
19 |     url = osp.dirname(f.replace('../../', url_prefix))
20 | 
21 |     with open(f) as content_file:
22 |         content = content_file.read()
23 | 
24 |     title = content.split('\n')[0].replace('#', '').strip()
25 |     ckpts = {
26 |         x.lower().strip()
27 |         for x in re.findall(r'https?://download.*\.pth', content)
28 |         if 'mmsegmentation' in x
29 |     }
30 |     if len(ckpts) == 0:
31 |         continue
32 | 
33 |     _papertype = [
34 |         x for x in re.findall(r'<!--\s*\[([A-Z]*?)\]\s*-->', content)
35 |     ]
36 |     assert len(_papertype) > 0
37 |     papertype = _papertype[0]
38 | 
39 |     paper = {(papertype, title)}
40 | 
41 |     titles.append(title)
42 |     num_ckpts += len(ckpts)
43 |     statsmsg = f"""
44 | \t* [{papertype}] [{title}]({url}) ({len(ckpts)} ckpts)
45 | """
46 |     stats.append((paper, ckpts, statsmsg))
47 | 
48 | allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _ in stats])
49 | msglist = '\n'.join(x for _, _, x in stats)
50 | 
51 | papertypes, papercounts = np.unique([t for t, _ in allpapers],
52 |                                     return_counts=True)
53 | countstr = '\n'.join(
54 |     [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)])
55 | 
56 | modelzoo = f"""
57 | # Model Zoo Statistics
58 | 
59 | * Number of papers: {len(set(titles))}
60 | {countstr}
61 | 
62 | * Number of checkpoints: {num_ckpts}
63 | {msglist}
64 | """
65 | 
66 | with open('modelzoo_statistics.md', 'w') as f:
67 |     f.write(modelzoo)
68 | 
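stat.py above counts each README's checkpoints from `download ... .pth` links and reads the paper type from the OpenMMLab HTML-comment marker (e.g. `<!-- [ALGORITHM] -->`). A self-contained sketch of the parsing, with an invented README snippet for illustration:

import re

readme = ('# FPN\n'
          '<!-- [ALGORITHM] -->\n'
          'https://download.openmmlab.com/mmsegmentation/fpn_r50.pth\n')
ckpts = {x for x in re.findall(r'https?://download.*\.pth', readme)
         if 'mmsegmentation' in x}
papertype = re.findall(r'<!--\s*\[([A-Z]*?)\]\s*-->', readme)[0]
print(papertype, len(ckpts))  # -> ALGORITHM 1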
--------------------------------------------------------------------------------
/Segmentation/docs/en/switch_language.md:
--------------------------------------------------------------------------------
1 | ## English
2 | 
3 | ## 简体中文
4 | 
--------------------------------------------------------------------------------
/Segmentation/docs/en/user_guides/index.rst:
--------------------------------------------------------------------------------
1 | Train & Test
2 | **************
3 | 
4 | .. toctree::
5 |    :maxdepth: 1
6 | 
7 |    1_config.md
8 |    2_dataset_prepare.md
9 |    3_inference.md
10 |    4_train_test.md
11 | 
12 | Useful Tools
13 | *************
14 | 
15 | .. toctree::
16 |    :maxdepth: 2
17 | 
18 |    visualization.md
19 |    useful_tools.md
20 |    5_deployment.md
21 |    visualization_feature_map.md
22 | 
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 | 
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 |     background-image: url("../images/mmsegmentation.png");
3 |     background-size: 201px 40px;
4 |     height: 40px;
5 |     width: 201px;
6 | }
7 | 
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/_static/images/mmsegmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/_static/images/mmsegmentation.png
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/advanced_guides/add_transforms.md:
--------------------------------------------------------------------------------
1 | # 新增数据增强
2 | 
3 | ## 自定义数据增强
4 | 
5 | 自定义数据增强必须继承 `BaseTransform` 并实现 `transform` 函数。这里我们使用一个简单的翻转变换作为示例:
6 | 
7 | ```python
8 | import random
9 | import mmcv
10 | from mmcv.transforms import BaseTransform, TRANSFORMS
11 | 
12 | @TRANSFORMS.register_module()
13 | class MyFlip(BaseTransform):
14 |     def __init__(self, direction: str):
15 |         super().__init__()
16 |         self.direction = direction
17 | 
18 |     def transform(self, results: dict) -> dict:
19 |         img = results['img']
20 |         results['img'] = mmcv.imflip(img, direction=self.direction)
21 |         return results
22 | ```
23 | 
24 | 此外,新的类需要被导入。
25 | 
26 | ```python
27 | from .my_pipeline import MyFlip
28 | ```
29 | 
30 | 这样,我们就可以实例化一个 `MyFlip` 对象并使用它来处理数据字典。
31 | 
32 | ```python
33 | import numpy as np
34 | 
35 | transform = MyFlip(direction='horizontal')
36 | data_dict = {'img': np.random.rand(224, 224, 3)}
37 | data_dict = transform(data_dict)
38 | processed_img = data_dict['img']
39 | ```
40 | 
41 | 或者,我们可以在配置文件中的数据流程中使用 `MyFlip` 变换。
42 | 
43 | ```python
44 | pipeline = [
45 |     ...
46 |     dict(type='MyFlip', direction='horizontal'),
47 |     ...
48 | ]
49 | ```
50 | 
51 | 需要注意,如果要在配置文件中使用 `MyFlip`,必须确保在运行时导入了包含 `MyFlip` 的文件。
52 | 
toctree:: 5 | :maxdepth: 1 6 | 7 | data_flow.md 8 | structures.md 9 | models.md 10 | datasets.md 11 | transforms.md 12 | evaluation.md 13 | engine.md 14 | training_tricks.md 15 | 16 | 自定义组件 17 | ************************ 18 | 19 | .. toctree:: 20 | :maxdepth: 1 21 | 22 | add_models.md 23 | add_datasets.md 24 | add_transforms.md 25 | add_metrics.md 26 | customize_runtime.md 27 | -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/api.rst: -------------------------------------------------------------------------------- 1 | mmseg.apis 2 | -------------- 3 | .. automodule:: mmseg.apis 4 | :members: 5 | 6 | mmseg.datasets 7 | -------------- 8 | 9 | datasets 10 | ^^^^^^^^^^ 11 | .. automodule:: mmseg.datasets 12 | :members: 13 | 14 | transforms 15 | ^^^^^^^^^^^^ 16 | .. automodule:: mmseg.datasets.transforms 17 | :members: 18 | 19 | mmseg.engine 20 | -------------- 21 | 22 | hooks 23 | ^^^^^^^^^^ 24 | .. automodule:: mmseg.engine.hooks 25 | :members: 26 | 27 | optimizers 28 | ^^^^^^^^^^^^^^^ 29 | .. automodule:: mmseg.engine.optimizers 30 | :members: 31 | 32 | mmseg.evaluation 33 | -------------- 34 | 35 | metrics 36 | ^^^^^^^^^^ 37 | .. automodule:: mmseg.evaluation.metrics 38 | :members: 39 | 40 | mmseg.models 41 | -------------- 42 | 43 | backbones 44 | ^^^^^^^^^^^^^^^^^^ 45 | .. automodule:: mmseg.models.backbones 46 | :members: 47 | 48 | decode_heads 49 | ^^^^^^^^^^^^^^^ 50 | .. automodule:: mmseg.models.decode_heads 51 | :members: 52 | 53 | segmentors 54 | ^^^^^^^^^^ 55 | .. automodule:: mmseg.models.segmentors 56 | :members: 57 | 58 | losses 59 | ^^^^^^^^^^ 60 | .. automodule:: mmseg.models.losses 61 | :members: 62 | 63 | necks 64 | ^^^^^^^^^^^^ 65 | .. automodule:: mmseg.models.necks 66 | :members: 67 | 68 | utils 69 | ^^^^^^^^^^ 70 | .. automodule:: mmseg.models.utils 71 | :members: 72 | 73 | 74 | mmseg.structures 75 | -------------------- 76 | 77 | structures 78 | ^^^^^^^^^^^^^^^^^ 79 | .. automodule:: mmseg.structures 80 | :members: 81 | 82 | sampler 83 | ^^^^^^^^^^ 84 | .. automodule:: mmseg.structures.sampler 85 | :members: 86 | 87 | mmseg.visualization 88 | -------------------- 89 | .. automodule:: mmseg.visualization 90 | :members: 91 | 92 | mmseg.utils 93 | -------------- 94 | .. 
automodule:: mmseg.utils 95 | :members: 96 | -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/imgs/qq_group_qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/imgs/qq_group_qrcode.jpg -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/imgs/seggroup_qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/imgs/seggroup_qrcode.jpg -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/imgs/zhihu_qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/imgs/zhihu_qrcode.jpg -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/index.rst: -------------------------------------------------------------------------------- 1 | 欢迎来到 MMSegmentation 的文档! 2 | ======================================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: 开始你的第一步 7 | 8 | get_started.md 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: 用户指南 13 | 14 | user_guides/index.rst 15 | 16 | .. toctree:: 17 | :maxdepth: 2 18 | :caption: 进阶指南 19 | 20 | advanced_guides/index.rst 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | :caption: 迁移指引 25 | 26 | migration/index.rst 27 | 28 | .. toctree:: 29 | :caption: 接口文档(英文) 30 | 31 | api.rst 32 | 33 | .. toctree:: 34 | :maxdepth: 1 35 | :caption: 模型库 36 | 37 | model_zoo.md 38 | modelzoo_statistics.md 39 | 40 | .. toctree:: 41 | :maxdepth: 2 42 | :caption: 说明 43 | 44 | changelog.md 45 | faq.md 46 | 47 | .. toctree:: 48 | :caption: 语言切换 49 | 50 | switch_language.md 51 | 52 | 53 | Indices and tables 54 | ================== 55 | 56 | * :ref:`genindex` 57 | * :ref:`search` 58 | -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/migration/index.rst: -------------------------------------------------------------------------------- 1 | 迁移 2 | *************** 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | 7 | interface.md 8 | package.md 9 | -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/stat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import functools as func 4 | import glob 5 | import os.path as osp 6 | import re 7 | 8 | import numpy as np 9 | 10 | url_prefix = 'https://github.com/open-mmlab/mmsegmentation/blob/master/' 11 | 12 | files = sorted(glob.glob('../../configs/*/README.md')) 13 | 14 | stats = [] 15 | titles = [] 16 | num_ckpts = 0 17 | 18 | for f in files: 19 | url = osp.dirname(f.replace('../../', url_prefix)) 20 | 21 | with open(f) as content_file: 22 | content = content_file.read() 23 | 24 | title = content.split('\n')[0].replace('#', '').strip() 25 | ckpts = { 26 | x.lower().strip() 27 | for x in re.findall(r'https?://download.*\.pth', content) 28 | if 'mmsegmentation' in x 29 | } 30 | if len(ckpts) == 0: 31 | continue 32 | 33 | _papertype = [ 34 | x for x in re.findall(r'', content) 35 | ] 36 | assert len(_papertype) > 0 37 | papertype = _papertype[0] 38 | 39 | paper = {(papertype, title)} 40 | 41 | titles.append(title) 42 | num_ckpts += len(ckpts) 43 | statsmsg = f""" 44 | \t* [{papertype}] [{title}]({url}) ({len(ckpts)} ckpts) 45 | """ 46 | stats.append((paper, ckpts, statsmsg)) 47 | 48 | allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _ in stats]) 49 | msglist = '\n'.join(x for _, _, x in stats) 50 | 51 | papertypes, papercounts = np.unique([t for t, _ in allpapers], 52 | return_counts=True) 53 | countstr = '\n'.join( 54 | [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)]) 55 | 56 | modelzoo = f""" 57 | # 模型库统计数据 58 | 59 | * 论文数量: {len(set(titles))} 60 | {countstr} 61 | 62 | * 模型数量: {num_ckpts} 63 | {msglist} 64 | """ 65 | 66 | with open('modelzoo_statistics.md', 'w') as f: 67 | f.write(modelzoo) 68 | -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/switch_language.md: -------------------------------------------------------------------------------- 1 | ## English 2 | 3 | ## 简体中文 4 | -------------------------------------------------------------------------------- /Segmentation/docs/zh_cn/user_guides/index.rst: -------------------------------------------------------------------------------- 1 | 训练 & 测试 2 | ************** 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | 1_config.md 8 | 2_dataset_prepare.md 9 | 3_inference.md 10 | 4_train_test.md 11 | 12 | 实用工具 13 | ************* 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | 18 | visualization.md 19 | useful_tools.md 20 | deployment.md 21 | visualization_feature_map.md 22 | -------------------------------------------------------------------------------- /Segmentation/mmdet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import mmcv
3 | import mmengine
4 | from mmengine.utils import digit_version
5 | 
6 | from .version import __version__, version_info
7 | 
8 | mmcv_minimum_version = '2.0.0rc4'
9 | mmcv_maximum_version = '2.1.1'
10 | mmcv_version = digit_version(mmcv.__version__)
11 | 
12 | mmengine_minimum_version = '0.7.1'
13 | mmengine_maximum_version = '1.0.0'
14 | mmengine_version = digit_version(mmengine.__version__)
15 | 
16 | assert (mmcv_version >= digit_version(mmcv_minimum_version)
17 |         and mmcv_version < digit_version(mmcv_maximum_version)), \
18 |     f'MMCV=={mmcv.__version__} is used but incompatible. ' \
19 |     f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
20 | 
21 | assert (mmengine_version >= digit_version(mmengine_minimum_version)
22 |         and mmengine_version < digit_version(mmengine_maximum_version)), \
23 |     f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
24 |     f'Please install mmengine>={mmengine_minimum_version}, ' \
25 |     f'<{mmengine_maximum_version}.'
26 | 
27 | __all__ = ['__version__', 'version_info', 'digit_version']
28 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # from .data_preprocessors import *  # noqa: F401,F403
3 | from .dense_heads import *  # noqa: F401,F403
4 | from .layers import *  # noqa: F401,F403
5 | from .losses import *  # noqa: F401,F403
6 | from .task_modules import *
7 | from .test_time_augs import *
8 | from .tracking_heads import *  # noqa: F401,F403
9 | # from .vis import *  # noqa: F401,F403
10 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/dense_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .anchor_free_head import AnchorFreeHead
3 | from .maskformer_head import MaskFormerHead
4 | 
5 | 
6 | 
7 | __all__ = [
8 |     'AnchorFreeHead',
9 |     'MaskFormerHead',
10 | ]
11 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/activations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmengine.utils import digit_version
5 | 
6 | from mmdet.registry import MODELS
7 | 
8 | if digit_version(torch.__version__) >= digit_version('1.7.0'):
9 |     from torch.nn import SiLU
10 | else:
11 | 
12 |     class SiLU(nn.Module):
13 |         """Sigmoid Weighted Linear Unit."""
14 | 
15 |         def __init__(self, inplace=True):
16 |             super().__init__()  # `inplace` kept only for API parity; the op is out-of-place
17 | 
18 |         def forward(self, inputs) -> torch.Tensor:
19 |             return inputs * torch.sigmoid(inputs)
20 | 
21 | 
22 | MODELS.register_module(module=SiLU, name='SiLU')
23 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/brick_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
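# Usage sketch (illustrative): these wrappers exist so that zero-sample
# batches can flow through pooling on older PyTorch versions without
# raising, e.g.
#
#   x = torch.zeros(0, 256, 7, 7)        # empty batch
#   y = adaptive_avg_pool2d(x, 1)        # -> shape (0, 256, 1, 1)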
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version
6 | 
7 | if torch.__version__ == 'parrots':
8 |     TORCH_VERSION = torch.__version__
9 | else:
10 |     # torch.__version__ could be 1.3.1+cu92; we only need the first two
11 |     # components for comparison
12 |     TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
13 | 
14 | 
15 | def adaptive_avg_pool2d(input, output_size):
16 |     """Handle an empty batch dimension for adaptive_avg_pool2d.
17 | 
18 |     Args:
19 |         input (tensor): 4D tensor.
20 |         output_size (int, tuple[int,int]): the target output size.
21 |     """
22 |     if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
23 |         if isinstance(output_size, int):
24 |             output_size = [output_size, output_size]
25 |         output_size = [*input.shape[:2], *output_size]
26 |         empty = NewEmptyTensorOp.apply(input, output_size)
27 |         return empty
28 |     else:
29 |         return F.adaptive_avg_pool2d(input, output_size)
30 | 
31 | 
32 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
33 |     """Handle an empty batch dimension for AdaptiveAvgPool2d."""
34 | 
35 |     def forward(self, x):
36 |         # PyTorch < 1.9 does not support empty tensor inference yet
37 |         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
38 |             output_size = self.output_size
39 |             if isinstance(output_size, int):
40 |                 output_size = [output_size, output_size]
41 |             else:
42 |                 output_size = [
43 |                     v if v is not None else d
44 |                     for v, d in zip(output_size,
45 |                                     x.size()[-2:])
46 |                 ]
47 |             output_size = [*x.shape[:2], *output_size]
48 |             empty = NewEmptyTensorOp.apply(x, output_size)
49 |             return empty
50 | 
51 |         return super().forward(x)
52 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/mmcv_spike/__init__.py:
--------------------------------------------------------------------------------
1 | from .multi_scale_deform_attn import SpikeMultiScaleDeformableAttention
2 | from .transformer import MultiheadAttention, FFN, MSDA_FFN, MS_MLP
3 | from .spikeformer import MSTransformerDecoder, CrossAttention, SelfAttention, MLP
4 | from .BASE_Transformer import Transformer
5 | # NOTE: Move the mmcv function here to change the basic version
6 | __all__ = [
7 |     "SpikeMultiScaleDeformableAttention", "MultiheadAttention",
8 |     "MSTransformerDecoder", "CrossAttention", "SelfAttention", "MLP", "MSDA_FFN",
9 |     "Transformer", "MS_MLP"
10 | ]
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/mmcv_spike/scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class Scale(nn.Module):
7 |     """A learnable scale parameter.
8 | 
9 |     This layer scales the input by a learnable factor. It multiplies a
10 |     learnable scale parameter of shape (1,) with input of any shape.
11 | 
12 |     Args:
13 |         scale (float): Initial value of scale factor. Default: 1.0
14 |     """
15 | 
16 |     def __init__(self, scale: float = 1.0):
17 |         super().__init__()
18 |         self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
19 | 
20 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
21 |         return x * self.scale
22 | 
23 | 
24 | class LayerScale(nn.Module):
25 |     """LayerScale layer.
26 | 
27 |     Args:
28 |         dim (int): Dimension of input features.
29 |         inplace (bool): Whether to perform the operation in-place.
30 | Default: `False`. 31 | data_format (str): The input data format, could be 'channels_last' 32 | or 'channels_first', representing (B, C, H, W) and 33 | (B, N, C) format data respectively. Default: 'channels_last'. 34 | scale (float): Initial value of scale factor. Default: 1.0 35 | """ 36 | 37 | def __init__(self, 38 | dim: int, 39 | inplace: bool = False, 40 | data_format: str = 'channels_last', 41 | scale: float = 1e-5): 42 | super().__init__() 43 | assert data_format in ('channels_last', 'channels_first'), \ 44 | "'data_format' could only be channels_last or channels_first." 45 | self.inplace = inplace 46 | self.data_format = data_format 47 | self.weight = nn.Parameter(torch.ones(dim) * scale) 48 | 49 | def forward(self, x) -> torch.Tensor: 50 | if self.data_format == 'channels_first': 51 | shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2)))) 52 | else: 53 | shape = tuple((*(1 for _ in range(x.dim() - 1)), -1)) 54 | if self.inplace: 55 | return x.mul_(self.weight.view(*shape)) 56 | else: 57 | return x * self.weight.view(*shape) 58 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/layers/transformer/ops_dcnv3/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules.dcnv3 import DCNv3_pytorch 2 | __all__ = [ 3 | "DCNv3_pytorch" 4 | ] -------------------------------------------------------------------------------- /Segmentation/mmdet/models/layers/transformer/ops_dcnv3/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3_func import dcnv3_core_pytorch 8 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/layers/transformer/ops_dcnv3/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -------------------------------------------------------- 3 | # InternImage 4 | # Copyright (c) 2022 OpenGVLab 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # -------------------------------------------------------- 7 | 8 | python setup.py build install 9 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/layers/transformer/ops_dcnv3/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3 import DCNv3_pytorch -------------------------------------------------------------------------------- /Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/cpu/dcnv3_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 | 
12 | #include <vector>
13 | 
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | 
17 | at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
18 |                              const at::Tensor &mask, const int kernel_h,
19 |                              const int kernel_w, const int stride_h,
20 |                              const int stride_w, const int pad_h,
21 |                              const int pad_w, const int dilation_h,
22 |                              const int dilation_w, const int group,
23 |                              const int group_channels, const float offset_scale,
24 |                              const int im2col_step) {
25 |     AT_ERROR("Not implemented on CPU");
26 | }
27 | 
28 | std::vector<at::Tensor>
29 | dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
30 |                    const at::Tensor &mask, const int kernel_h,
31 |                    const int kernel_w, const int stride_h, const int stride_w,
32 |                    const int pad_h, const int pad_w, const int dilation_h,
33 |                    const int dilation_w, const int group,
34 |                    const int group_channels, const float offset_scale,
35 |                    const at::Tensor &grad_output, const int im2col_step) {
36 |     AT_ERROR("Not implemented on CPU");
37 | }
38 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/cpu/dcnv3_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 | 
12 | #pragma once
13 | #include <torch/extension.h>
14 | 
15 | at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
16 |                              const at::Tensor &mask, const int kernel_h,
17 |                              const int kernel_w, const int stride_h,
18 |                              const int stride_w, const int pad_h,
19 |                              const int pad_w, const int dilation_h,
20 |                              const int dilation_w, const int group,
21 |                              const int group_channels, const float offset_scale,
22 |                              const int im2col_step);
23 | 
24 | std::vector<at::Tensor>
25 | dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
26 |                    const at::Tensor &mask, const int kernel_h,
27 |                    const int kernel_w, const int stride_h, const int stride_w,
28 |                    const int pad_h, const int pad_w, const int dilation_h,
29 |                    const int dilation_w, const int group,
30 |                    const int group_channels, const float offset_scale,
31 |                    const at::Tensor &grad_output, const int im2col_step);
32 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/cuda/dcnv3_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 | 
12 | #pragma once
13 | #include <torch/extension.h>
14 | 
15 | at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
16 |                               const at::Tensor &mask, const int kernel_h,
17 |                               const int kernel_w, const int stride_h,
18 |                               const int stride_w, const int pad_h,
19 |                               const int pad_w, const int dilation_h,
20 |                               const int dilation_w, const int group,
21 |                               const int group_channels,
22 |                               const float offset_scale, const int im2col_step);
23 | 
24 | std::vector<at::Tensor>
25 | dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
26 |                     const at::Tensor &mask, const int kernel_h,
27 |                     const int kernel_w, const int stride_h, const int stride_w,
28 |                     const int pad_h, const int pad_w, const int dilation_h,
29 |                     const int dilation_w, const int group,
30 |                     const int group_channels, const float offset_scale,
31 |                     const at::Tensor &grad_output, const int im2col_step);
32 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 | 
12 | #include "dcnv3.h"
13 | 
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 |     m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward");
16 |     m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward");
17 | }
18 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .accuracy import Accuracy, accuracy 3 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 4 | cross_entropy, mask_cross_entropy) 5 | from .dice_loss import DiceLoss 6 | from .focal_loss import FocalLoss, sigmoid_focal_loss 7 | from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, EIoULoss, GIoULoss, 8 | IoULoss, SIoULoss, bounded_iou_loss, iou_loss) 9 | from .mse_loss import MSELoss, mse_loss 10 | from .multipos_cross_entropy_loss import MultiPosCrossEntropyLoss 11 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 12 | 13 | 14 | __all__ = [ 15 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 16 | 'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss', 17 | 'FocalLoss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss', 18 | 'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss', 19 | 'EIoULoss', 'SIoULoss', 'reduce_loss', 20 | 'weight_reduce_loss', 'weighted_loss', 21 | 'DiceLoss', 'MultiPosCrossEntropyLoss', 22 | ] 23 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/task_modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import (ANCHOR_GENERATORS, BBOX_ASSIGNERS, BBOX_CODERS, 4 | BBOX_SAMPLERS, IOU_CALCULATORS, MATCH_COSTS, 5 | PRIOR_GENERATORS, build_anchor_generator, build_assigner, 6 | build_bbox_coder, build_iou_calculator, build_match_cost, 7 | build_prior_generator, build_sampler) 8 | 9 | from .prior_generators import * # noqa: F401,F403 10 | from .assigners import * 11 | from .samplers import * 12 | 13 | 14 | __all__ = [ 15 | 'ANCHOR_GENERATORS', 'PRIOR_GENERATORS', 'BBOX_ASSIGNERS', 'BBOX_SAMPLERS', 16 | 'MATCH_COSTS', 'BBOX_CODERS', 'IOU_CALCULATORS', 'build_anchor_generator', 17 | 'build_prior_generator', 'build_assigner', 'build_sampler', 18 | 'build_iou_calculator', 'build_match_cost', 'build_bbox_coder' 19 | ] 20 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/task_modules/assigners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .approx_max_iou_assigner import ApproxMaxIoUAssigner
3 | from .assign_result import AssignResult
4 | from .atss_assigner import ATSSAssigner
5 | from .base_assigner import BaseAssigner
6 | from .center_region_assigner import CenterRegionAssigner
7 | from .dynamic_soft_label_assigner import DynamicSoftLabelAssigner
8 | from .grid_assigner import GridAssigner
9 | from .hungarian_assigner import HungarianAssigner
10 | from .iou2d_calculator import BboxOverlaps2D
11 | from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost,
12 |                          DiceCost, FocalLossCost, IoUCost)
13 | from .max_iou_assigner import MaxIoUAssigner
14 | from .multi_instance_assigner import MultiInstanceAssigner
15 | from .point_assigner import PointAssigner
16 | from .region_assigner import RegionAssigner
17 | from .sim_ota_assigner import SimOTAAssigner
18 | from .task_aligned_assigner import TaskAlignedAssigner
19 | from .uniform_assigner import UniformAssigner
20 | 
21 | __all__ = [
22 |     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
23 |     'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
24 |     'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
25 |     'TaskAlignedAssigner', 'BBoxL1Cost', 'ClassificationCost',
26 |     'CrossEntropyLossCost', 'DiceCost', 'FocalLossCost', 'IoUCost',
27 |     'BboxOverlaps2D', 'DynamicSoftLabelAssigner', 'MultiInstanceAssigner'
28 | ]
29 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/assigners/base_assigner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Optional
4 | 
5 | from mmengine.structures import InstanceData
6 | 
7 | 
8 | class BaseAssigner(metaclass=ABCMeta):
9 |     """Base assigner that assigns boxes to ground truth boxes."""
10 | 
11 |     @abstractmethod
12 |     def assign(self,
13 |                pred_instances: InstanceData,
14 |                gt_instances: InstanceData,
15 |                gt_instances_ignore: Optional[InstanceData] = None,
16 |                **kwargs):
17 |         """Assign each predicted box to a ground-truth box or mark it as negative."""
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/prior_generators/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
3 |                                SSDAnchorGenerator, YOLOAnchorGenerator)
4 | from .point_generator import MlvlPointGenerator, PointGenerator
5 | from .utils import anchor_inside_flags, calc_region
6 | 
7 | __all__ = [
8 |     'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
9 |     'PointGenerator', 'calc_region', 'YOLOAnchorGenerator',
10 |     'MlvlPointGenerator', 'SSDAnchorGenerator'
11 | ]
12 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_sampler import BaseSampler 3 | from .combined_sampler import CombinedSampler 4 | from .instance_balanced_pos_sampler import InstanceBalancedPosSampler 5 | from .iou_balanced_neg_sampler import IoUBalancedNegSampler 6 | from .mask_pseudo_sampler import MaskPseudoSampler 7 | from .mask_sampling_result import MaskSamplingResult 8 | from .multi_instance_random_sampler import MultiInsRandomSampler 9 | from .multi_instance_sampling_result import MultiInstanceSamplingResult 10 | from .ohem_sampler import OHEMSampler 11 | from .pseudo_sampler import PseudoSampler 12 | from .random_sampler import RandomSampler 13 | from .sampling_result import SamplingResult 14 | from .score_hlr_sampler import ScoreHLRSampler 15 | 16 | __all__ = [ 17 | 'BaseSampler', 'PseudoSampler', 'RandomSampler', 18 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 19 | 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'MaskPseudoSampler', 20 | 'MaskSamplingResult', 'MultiInstanceSamplingResult', 21 | 'MultiInsRandomSampler' 22 | ] 23 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/task_modules/samplers/combined_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmdet.registry import TASK_UTILS 3 | from .base_sampler import BaseSampler 4 | 5 | 6 | @TASK_UTILS.register_module() 7 | class CombinedSampler(BaseSampler): 8 | """A sampler that combines positive sampler and negative sampler.""" 9 | 10 | def __init__(self, pos_sampler, neg_sampler, **kwargs): 11 | super(CombinedSampler, self).__init__(**kwargs) 12 | self.pos_sampler = TASK_UTILS.build(pos_sampler, default_args=kwargs) 13 | self.neg_sampler = TASK_UTILS.build(neg_sampler, default_args=kwargs) 14 | 15 | def _sample_pos(self, **kwargs): 16 | """Sample positive samples.""" 17 | raise NotImplementedError 18 | 19 | def _sample_neg(self, **kwargs): 20 | """Sample negative samples.""" 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/test_time_augs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .det_tta import DetTTAModel 3 | from .merge_augs import (merge_aug_bboxes, merge_aug_masks, 4 | merge_aug_proposals, merge_aug_results, 5 | merge_aug_scores) 6 | 7 | __all__ = [ 8 | 'merge_aug_bboxes', 'merge_aug_masks', 'merge_aug_proposals', 9 | 'merge_aug_scores', 'merge_aug_results', 'DetTTAModel' 10 | ] 11 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/tracking_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .mask2former_track_head import Mask2FormerTrackHead
3 | __all__ = [
4 |     'Mask2FormerTrackHead',
5 | ]
6 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/utils/Qtrick.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | quant = 4
5 | T = quant
6 | 
7 | # Straight-through quantizer: the forward pass clamps to [0, quant] and rounds; the backward pass lets gradients through unchanged inside the clamp range.
8 | class Quant(torch.autograd.Function):
9 |     @staticmethod
10 |     @torch.cuda.amp.custom_fwd
11 |     def forward(ctx, i, min_value=0, max_value=quant):
12 |         ctx.min = min_value
13 |         ctx.max = max_value
14 |         ctx.save_for_backward(i)
15 |         return torch.round(torch.clamp(i, min=min_value, max=max_value))
16 | 
17 |     @staticmethod
18 |     @torch.cuda.amp.custom_bwd
19 |     def backward(ctx, grad_output):
20 |         grad_input = grad_output.clone()
21 |         i, = ctx.saved_tensors
22 |         grad_input[i < ctx.min] = 0
23 |         grad_input[i > ctx.max] = 0
24 |         return grad_input, None, None
25 | 
26 | 
27 | class MultiSpike_norm4(nn.Module):
28 |     def __init__(
29 |         self,
30 |         Vth=1.0,
31 |         T=T,  # outputs are normalized by T (the number of time steps)
32 |     ):
33 |         super().__init__()
34 |         self.spike = Quant()
35 |         self.Vth = Vth
36 |         self.T = T
37 | 
38 |     def forward(self, x):
39 |         return self.spike.apply(x) / self.T
40 | 
41 | 
42 | # Same quantizer without the 1/T rescaling: emits raw integer spike counts.
43 | class MultiSpike_4(nn.Module):
44 |     def __init__(
45 |         self,
46 |         T=T,  # number of quantization levels (time steps)
47 |     ):
48 |         super().__init__()
49 |         self.spike = Quant()
50 |         self.T = T
51 | 
52 |     def forward(self, x):
53 |         return self.spike.apply(x)
54 | 
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .gaussian_target import (gather_feat, gaussian_radius,
3 |                               gen_gaussian_target, get_local_maximum,
4 |                               get_topk_from_heatmap, transpose_and_gather_feat)
5 | from .image import imrenormalize
6 | from .make_divisible import make_divisible
7 | from .misc import (aligned_bilinear, center_of_mass, empty_instances,
8 |                    filter_gt_instances, filter_scores_and_topk, flip_tensor,
9 |                    generate_coordinate, images_to_levels, interpolate_as,
10 |                    levels_to_images, mask2ndarray, multi_apply,
11 |                    relative_coordinate_maps, rename_loss_dict,
12 |                    reweight_loss_dict, samplelist_boxtype2tensor,
13 |                    select_single_mlvl, sigmoid_geometric_mean,
14 |                    unfold_wo_center, unmap, unpack_gt_instances)
15 | from .panoptic_gt_processing import preprocess_panoptic_gt
16 | from .point_sample import (get_uncertain_point_coords_with_randomness,
17 |                            get_uncertainty)
18 | from .vlfuse_helper import BertEncoderLayer, VLFuse, permute_and_flatten
19 | from .Qtrick import MultiSpike_norm4
20 | 
21 | __all__ = [
22 |     'gaussian_radius', 'gen_gaussian_target', 'make_divisible',
23 |     'get_local_maximum', 'get_topk_from_heatmap', 'transpose_and_gather_feat',
24 |     'interpolate_as', 'sigmoid_geometric_mean', 'gather_feat',
25 |     'preprocess_panoptic_gt', 'get_uncertain_point_coords_with_randomness',
26 |     'get_uncertainty', 'unpack_gt_instances', 'empty_instances',
27 |     'center_of_mass', 'filter_scores_and_topk', 'flip_tensor',
28 |     'generate_coordinate', 'levels_to_images', 'mask2ndarray', 'multi_apply',
29 |     'select_single_mlvl', 'unmap', 'images_to_levels',
30 |     'samplelist_boxtype2tensor', 'filter_gt_instances', 'rename_loss_dict',
31 |     'reweight_loss_dict', 'relative_coordinate_maps', 'aligned_bilinear',
32 |     'unfold_wo_center', 'imrenormalize', 'VLFuse', 'permute_and_flatten',
33 |     'BertEncoderLayer', 'MultiSpike_norm4'
34 | ]
35 | 
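# Usage sketch (illustrative, not part of the module): MultiSpike_norm4 from
# .Qtrick quantizes activations into integer spike counts in [0, T] with a
# straight-through gradient, then rescales them by 1/T:
#
#   import torch
#   from mmdet.models.utils import MultiSpike_norm4
#
#   spike = MultiSpike_norm4()           # T = 4 by default
#   out = spike(torch.randn(2, 8))       # values in {0, 0.25, 0.5, 0.75, 1.0}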
"MultiSpike_norm4" 34 | ] 35 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/utils/image.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Union 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | from torch import Tensor 8 | 9 | 10 | def imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict, 11 | new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]: 12 | """Re-normalize the image. 13 | 14 | Args: 15 | img (Tensor | ndarray): Input image. If the input is a Tensor, the 16 | shape is (1, C, H, W). If the input is a ndarray, the shape 17 | is (H, W, C). 18 | img_norm_cfg (dict): Original configuration for the normalization. 19 | new_img_norm_cfg (dict): New configuration for the normalization. 20 | 21 | Returns: 22 | Tensor | ndarray: Output image with the same type and shape of 23 | the input. 24 | """ 25 | if isinstance(img, torch.Tensor): 26 | assert img.ndim == 4 and img.shape[0] == 1 27 | new_img = img.squeeze(0).cpu().numpy().transpose(1, 2, 0) 28 | new_img = _imrenormalize(new_img, img_norm_cfg, new_img_norm_cfg) 29 | new_img = new_img.transpose(2, 0, 1)[None] 30 | return torch.from_numpy(new_img).to(img) 31 | else: 32 | return _imrenormalize(img, img_norm_cfg, new_img_norm_cfg) 33 | 34 | 35 | def _imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict, 36 | new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]: 37 | """Re-normalize the image.""" 38 | img_norm_cfg = img_norm_cfg.copy() 39 | new_img_norm_cfg = new_img_norm_cfg.copy() 40 | for k, v in img_norm_cfg.items(): 41 | if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray): 42 | img_norm_cfg[k] = np.array(v, dtype=img.dtype) 43 | # reverse cfg 44 | if 'bgr_to_rgb' in img_norm_cfg: 45 | img_norm_cfg['rgb_to_bgr'] = img_norm_cfg['bgr_to_rgb'] 46 | img_norm_cfg.pop('bgr_to_rgb') 47 | for k, v in new_img_norm_cfg.items(): 48 | if (k == 'mean' or k == 'std') and not isinstance(v, np.ndarray): 49 | new_img_norm_cfg[k] = np.array(v, dtype=img.dtype) 50 | img = mmcv.imdenormalize(img, **img_norm_cfg) 51 | img = mmcv.imnormalize(img, **new_img_norm_cfg) 52 | return img 53 | -------------------------------------------------------------------------------- /Segmentation/mmdet/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number to the nearest value that can be 6 | divisible by the divisor. It is taken from the original tf repo. It ensures 7 | that all layers have a channel number that is divisible by divisor. It can 8 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 9 | 10 | Args: 11 | value (int): The original channel number. 12 | divisor (int): The divisor to fully divide the channel number. 13 | min_value (int): The minimum value of the output channel. 14 | Default: None, means that the minimum value equal to the divisor. 15 | min_ratio (float): The minimum ratio of the rounded channel number to 16 | the original channel number. Default: 0.9. 17 | 18 | Returns: 19 | int: The modified output channel number. 
20 | """ 21 | 22 | if min_value is None: 23 | min_value = divisor 24 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 25 | # Make sure that round down does not go down by more than (1-min_ratio). 26 | if new_value < min_ratio * value: 27 | new_value += divisor 28 | return new_value 29 | -------------------------------------------------------------------------------- /Segmentation/mmdet/structures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .det_data_sample import DetDataSample, OptSampleList, SampleList 3 | from .reid_data_sample import ReIDDataSample 4 | from .track_data_sample import (OptTrackSampleList, TrackDataSample, 5 | TrackSampleList) 6 | 7 | __all__ = [ 8 | 'DetDataSample', 'SampleList', 'OptSampleList', 'TrackDataSample', 9 | 'TrackSampleList', 'OptTrackSampleList', 'ReIDDataSample' 10 | ] 11 | -------------------------------------------------------------------------------- /Segmentation/mmdet/structures/bbox/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_boxes import BaseBoxes 3 | from .bbox_overlaps import bbox_overlaps 4 | from .box_type import (autocast_box_type, convert_box_type, get_box_type, 5 | register_box, register_box_converter) 6 | from .horizontal_boxes import HorizontalBoxes 7 | from .transforms import bbox_cxcyah_to_xyxy # noqa: E501 8 | from .transforms import (bbox2corner, bbox2distance, bbox2result, bbox2roi, 9 | bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping, 10 | bbox_mapping_back, bbox_project, bbox_rescale, 11 | bbox_xyxy_to_cxcyah, bbox_xyxy_to_cxcywh, cat_boxes, 12 | corner2bbox, distance2bbox, empty_box_as, 13 | find_inside_bboxes, get_box_tensor, get_box_wh, 14 | roi2bbox, scale_boxes, stack_boxes) 15 | 16 | __all__ = [ 17 | 'bbox_overlaps', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back', 18 | 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance', 19 | 'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh', 20 | 'find_inside_bboxes', 'bbox2corner', 'corner2bbox', 'bbox_project', 21 | 'BaseBoxes', 'convert_box_type', 'get_box_type', 'register_box', 22 | 'register_box_converter', 'HorizontalBoxes', 'autocast_box_type', 23 | 'cat_boxes', 'stack_boxes', 'scale_boxes', 'get_box_wh', 'get_box_tensor', 24 | 'empty_box_as', 'bbox_xyxy_to_cxcyah', 'bbox_cxcyah_to_xyxy' 25 | ] 26 | -------------------------------------------------------------------------------- /Segmentation/mmdet/structures/mask/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .mask_target import mask_target 3 | from .structures import (BaseInstanceMasks, BitmapMasks, PolygonMasks, 4 | bitmap_to_polygon, polygon_to_bitmap) 5 | from .utils import encode_mask_results, mask2bbox, split_combined_polys 6 | 7 | __all__ = [ 8 | 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks', 9 | 'PolygonMasks', 'encode_mask_results', 'mask2bbox', 'polygon_to_bitmap', 10 | 'bitmap_to_polygon' 11 | ] 12 | -------------------------------------------------------------------------------- /Segmentation/mmdet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .collect_env import collect_env 3 | from .compat_config import compat_cfg 4 | from .dist_utils import (all_reduce_dict, allreduce_grads, reduce_mean, 5 | sync_random_seed) 6 | from .logger import get_caller_name, log_img_scale 7 | from .memory import AvoidCUDAOOM, AvoidOOM 8 | from .misc import (find_latest_checkpoint, get_test_pipeline_cfg, 9 | update_data_root) 10 | from .mot_error_visualize import imshow_mot_errors 11 | from .replace_cfg_vals import replace_cfg_vals 12 | from .setup_env import (register_all_modules, setup_cache_size_limit_of_dynamo, 13 | setup_multi_processes) 14 | from .split_batch import split_batch 15 | from .typing_utils import (ConfigType, InstanceList, MultiConfig, 16 | OptConfigType, OptInstanceList, OptMultiConfig, 17 | OptPixelList, PixelList, RangeType) 18 | 19 | __all__ = [ 20 | 'collect_env', 'find_latest_checkpoint', 'update_data_root', 21 | 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', 22 | 'split_batch', 'register_all_modules', 'replace_cfg_vals', 'AvoidOOM', 23 | 'AvoidCUDAOOM', 'all_reduce_dict', 'allreduce_grads', 'reduce_mean', 24 | 'sync_random_seed', 'ConfigType', 'InstanceList', 'MultiConfig', 25 | 'OptConfigType', 'OptInstanceList', 'OptMultiConfig', 'OptPixelList', 26 | 'PixelList', 'RangeType', 'get_test_pipeline_cfg', 27 | 'setup_cache_size_limit_of_dynamo', 'imshow_mot_errors' 28 | ] 29 | -------------------------------------------------------------------------------- /Segmentation/mmdet/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmengine.utils import get_git_hash 3 | from mmengine.utils.dl_utils import collect_env as collect_base_env 4 | 5 | import mmdet 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMDetection'] = mmdet.__version__ + '+' + get_git_hash()[:7] 12 | return env_info 13 | 14 | 15 | if __name__ == '__main__': 16 | for name, val in collect_env().items(): 17 | print(f'{name}: {val}') 18 | -------------------------------------------------------------------------------- /Segmentation/mmdet/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import inspect 3 | 4 | from mmengine.logging import print_log 5 | 6 | 7 | def get_caller_name(): 8 | """Get name of caller method.""" 9 | # this_func_frame = inspect.stack()[0][0] # i.e., get_caller_name 10 | # callee_frame = inspect.stack()[1][0] # e.g., log_img_scale 11 | caller_frame = inspect.stack()[2][0] # e.g., caller of log_img_scale 12 | caller_method = caller_frame.f_code.co_name 13 | try: 14 | caller_class = caller_frame.f_locals['self'].__class__.__name__ 15 | return f'{caller_class}.{caller_method}' 16 | except KeyError: # caller is a function 17 | return caller_method 18 | 19 | 20 | def log_img_scale(img_scale, shape_order='hw', skip_square=False): 21 | """Log image size. 22 | 23 | Args: 24 | img_scale (tuple): Image size to be logged. 25 | shape_order (str, optional): The order of image shape. 26 | 'hw' for (height, width) and 'wh' for (width, height). 27 | Defaults to 'hw'. 28 | skip_square (bool, optional): Whether to skip logging for square 29 | img_scale. Defaults to False. 30 | 31 | Returns: 32 | bool: Whether to have done logging. 
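    A minimal sketch of both return paths (values follow directly from the
    logic below):

    Example:
        >>> log_img_scale((640, 480), shape_order='hw')   # logs and returns
        True
        >>> log_img_scale((512, 512), skip_square=True)   # square scale skipped
        False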
33 | """ 34 | if shape_order == 'hw': 35 | height, width = img_scale 36 | elif shape_order == 'wh': 37 | width, height = img_scale 38 | else: 39 | raise ValueError(f'Invalid shape_order {shape_order}.') 40 | 41 | if skip_square and (height == width): 42 | return False 43 | 44 | caller = get_caller_name() 45 | print_log( 46 | f'image shape: height={height}, width={width} in {caller}', 47 | logger='current') 48 | 49 | return True 50 | -------------------------------------------------------------------------------- /Segmentation/mmdet/utils/profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import contextlib 3 | import sys 4 | import time 5 | 6 | import torch 7 | 8 | if sys.version_info >= (3, 7): 9 | 10 | @contextlib.contextmanager 11 | def profile_time(trace_name, 12 | name, 13 | enabled=True, 14 | stream=None, 15 | end_stream=None): 16 | """Print time spent by CPU and GPU. 17 | 18 | Useful as a temporary context manager to find sweet spots of code 19 | suitable for async implementation. 20 | """ 21 | if (not enabled) or not torch.cuda.is_available(): 22 | yield 23 | return 24 | stream = stream if stream else torch.cuda.current_stream() 25 | end_stream = end_stream if end_stream else stream 26 | start = torch.cuda.Event(enable_timing=True) 27 | end = torch.cuda.Event(enable_timing=True) 28 | stream.record_event(start) 29 | try: 30 | cpu_start = time.monotonic() 31 | yield 32 | finally: 33 | cpu_end = time.monotonic() 34 | end_stream.record_event(end) 35 | end.synchronize() 36 | cpu_time = (cpu_end - cpu_start) * 1000 37 | gpu_time = start.elapsed_time(end) 38 | msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms ' 39 | msg += f'gpu_time {gpu_time:.2f} ms stream {stream}' 40 | print(msg, end_stream) 41 | -------------------------------------------------------------------------------- /Segmentation/mmdet/utils/split_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | 5 | def split_batch(img, img_metas, kwargs): 6 | """Split data_batch by tags. 7 | 8 | Code is modified from 9 | # noqa: E501 10 | 11 | Args: 12 | img (Tensor): of shape (N, C, H, W) encoding input images. 13 | Typically these should be mean centered and std scaled. 14 | img_metas (list[dict]): List of image info dict where each dict 15 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 16 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 17 | For details on the values of these keys, see 18 | :class:`mmdet.datasets.pipelines.Collect`. 19 | kwargs (dict): Specific to concrete implementation. 20 | 21 | Returns: 22 | data_groups (dict): a dict that data_batch splited by tags, 23 | such as 'sup', 'unsup_teacher', and 'unsup_student'. 
24 | """ 25 | 26 | # only stack img in the batch 27 | def fuse_list(obj_list, obj): 28 | return torch.stack(obj_list) if isinstance(obj, 29 | torch.Tensor) else obj_list 30 | 31 | # select data with tag from data_batch 32 | def select_group(data_batch, current_tag): 33 | group_flag = [tag == current_tag for tag in data_batch['tag']] 34 | return { 35 | k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v) 36 | for k, v in data_batch.items() 37 | } 38 | 39 | kwargs.update({'img': img, 'img_metas': img_metas}) 40 | kwargs.update({'tag': [meta['tag'] for meta in img_metas]}) 41 | tags = list(set(kwargs['tag'])) 42 | data_groups = {tag: select_group(kwargs, tag) for tag in tags} 43 | for tag, group in data_groups.items(): 44 | group.pop('tag') 45 | return data_groups 46 | -------------------------------------------------------------------------------- /Segmentation/mmdet/utils/typing_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """Collecting some commonly used type hint in mmdetection.""" 3 | from typing import List, Optional, Sequence, Tuple, Union 4 | 5 | from mmengine.config import ConfigDict 6 | from mmengine.structures import InstanceData, PixelData 7 | 8 | # TODO: Need to avoid circular import with assigner and sampler 9 | # Type hint of config data 10 | ConfigType = Union[ConfigDict, dict] 11 | OptConfigType = Optional[ConfigType] 12 | # Type hint of one or more config data 13 | MultiConfig = Union[ConfigType, List[ConfigType]] 14 | OptMultiConfig = Optional[MultiConfig] 15 | 16 | InstanceList = List[InstanceData] 17 | OptInstanceList = Optional[InstanceList] 18 | 19 | PixelList = List[PixelData] 20 | OptPixelList = Optional[PixelList] 21 | 22 | RangeType = Sequence[Tuple[int, int]] 23 | -------------------------------------------------------------------------------- /Segmentation/mmdet/utils/util_random.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """Helpers for random number generators.""" 3 | import numpy as np 4 | 5 | 6 | def ensure_rng(rng=None): 7 | """Coerces input into a random number generator. 8 | 9 | If the input is None, then a global random state is returned. 10 | 11 | If the input is a numeric value, then that is used as a seed to construct a 12 | random state. Otherwise the input is returned as-is. 13 | 14 | Adapted from [1]_. 15 | 16 | Args: 17 | rng (int | numpy.random.RandomState | None): 18 | if None, then defaults to the global rng. Otherwise this can be an 19 | integer or a RandomState class 20 | Returns: 21 | (numpy.random.RandomState) : rng - 22 | a numpy random number generator 23 | 24 | References: 25 | .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501 26 | """ 27 | 28 | if rng is None: 29 | rng = np.random.mtrand._rand 30 | elif isinstance(rng, int): 31 | rng = np.random.RandomState(rng) 32 | else: 33 | rng = rng 34 | return rng 35 | -------------------------------------------------------------------------------- /Segmentation/mmdet/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | __version__ = '3.1.0' 4 | short_version = __version__ 5 | 6 | 7 | def parse_version_info(version_str): 8 | """Parse a version string into a tuple. 9 | 10 | Args: 11 | version_str (str): The version string. 
12 | Returns: 13 | tuple[int | str]: The version info, e.g., "1.3.0" is parsed into 14 | (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 15 | """ 16 | version_info = [] 17 | for x in version_str.split('.'): 18 | if x.isdigit(): 19 | version_info.append(int(x)) 20 | elif x.find('rc') != -1: 21 | patch_version = x.split('rc') 22 | version_info.append(int(patch_version[0])) 23 | version_info.append(f'rc{patch_version[1]}') 24 | return tuple(version_info) 25 | 26 | 27 | version_info = parse_version_info(__version__) 28 | -------------------------------------------------------------------------------- /Segmentation/mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .inference import inference_model, init_model, show_result_pyplot 3 | from .mmseg_inferencer import MMSegInferencer 4 | 5 | __all__ = [ 6 | 'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer' 7 | ] 8 | -------------------------------------------------------------------------------- /Segmentation/mmseg/datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmengine.fileio as fileio 3 | 4 | from mmseg.registry import DATASETS 5 | from .basesegdataset import BaseSegDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class ChaseDB1Dataset(BaseSegDataset): 10 | """Chase_db1 dataset. 11 | 12 | In segmentation map annotation for Chase_db1, 0 stands for background, 13 | which is included in 2 categories. ``reduce_zero_label`` is fixed to False. 14 | The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 15 | '_1stHO.png'. 16 | """ 17 | METAINFO = dict( 18 | classes=('background', 'vessel'), 19 | palette=[[120, 120, 120], [6, 230, 230]]) 20 | 21 | def __init__(self, 22 | img_suffix='.png', 23 | seg_map_suffix='_1stHO.png', 24 | reduce_zero_label=False, 25 | **kwargs) -> None: 26 | super().__init__( 27 | img_suffix=img_suffix, 28 | seg_map_suffix=seg_map_suffix, 29 | reduce_zero_label=reduce_zero_label, 30 | **kwargs) 31 | assert fileio.exists( 32 | self.data_prefix['img_path'], backend_args=self.backend_args) 33 | -------------------------------------------------------------------------------- /Segmentation/mmseg/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmseg.registry import DATASETS 3 | from .basesegdataset import BaseSegDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class CityscapesDataset(BaseSegDataset): 8 | """Cityscapes dataset. 9 | 10 | The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is 11 | fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. 
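    A minimal construction sketch (paths are illustrative; see
    configs/_base_/datasets/cityscapes.py for the full pipeline config):

    Example:
        >>> dataset = CityscapesDataset(
        ...     data_root='data/cityscapes',
        ...     data_prefix=dict(img_path='leftImg8bit/train',
        ...                      seg_map_path='gtFine/train'),
        ...     pipeline=[])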
12 | """ 13 | METAINFO = dict( 14 | classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 15 | 'traffic light', 'traffic sign', 'vegetation', 'terrain', 16 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 17 | 'motorcycle', 'bicycle'), 18 | palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], 19 | [190, 153, 153], [153, 153, 153], [250, 170, 20 | 30], [220, 220, 0], 21 | [107, 142, 35], [152, 251, 152], [70, 130, 180], 22 | [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], 23 | [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]) 24 | 25 | def __init__(self, 26 | img_suffix='_leftImg8bit.png', 27 | seg_map_suffix='_gtFine_labelTrainIds.png', 28 | **kwargs) -> None: 29 | super().__init__( 30 | img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) 31 | -------------------------------------------------------------------------------- /Segmentation/mmseg/datasets/ddd17.py: -------------------------------------------------------------------------------- 1 | # 大概的转换思路是,首先将数据集转化为ade20k的存储格式,然后按照ade20k的存储方法存放 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | from mmseg.registry import DATASETS 4 | from .basesegdataset import BaseSegDataset 5 | import copy 6 | import os.path as osp 7 | from typing import Callable, Dict, List, Optional, Sequence, Union 8 | 9 | import mmengine 10 | import mmengine.fileio as fileio 11 | import numpy as np 12 | from mmengine.dataset import BaseDataset, Compose 13 | 14 | from mmseg.registry import DATASETS 15 | 16 | @DATASETS.register_module() 17 | class DDD17Dataset(BaseSegDataset): 18 | """ADE20K dataset. 19 | 20 | In segmentation map annotation for ADE20K, 0 stands for background, which 21 | is not included in 150 categories. ``reduce_zero_label`` is fixed to True. 22 | The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to 23 | '.png'. 24 | """ 25 | METAINFO = dict( 26 | classes=('flat', 'construction+sky', 'object', 'nature', 'human', 'vehicle'), 27 | palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 28 | [4, 200, 3], [120, 120, 80]]) 29 | 30 | def __init__(self, 31 | img_suffix='.npy', 32 | seg_map_suffix='.png', 33 | **kwargs) -> None: 34 | super().__init__( 35 | img_suffix=img_suffix, 36 | seg_map_suffix=seg_map_suffix, 37 | **kwargs) 38 | -------------------------------------------------------------------------------- /Segmentation/mmseg/datasets/drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmengine.fileio as fileio 3 | 4 | from mmseg.registry import DATASETS 5 | from .basesegdataset import BaseSegDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class DRIVEDataset(BaseSegDataset): 10 | """DRIVE dataset. 11 | 12 | In segmentation map annotation for DRIVE, 0 stands for background, which is 13 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 14 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 15 | '_manual1.png'. 
16 | """ 17 | METAINFO = dict( 18 | classes=('background', 'vessel'), 19 | palette=[[120, 120, 120], [6, 230, 230]]) 20 | 21 | def __init__(self, 22 | img_suffix='.png', 23 | seg_map_suffix='_manual1.png', 24 | reduce_zero_label=False, 25 | **kwargs) -> None: 26 | super().__init__( 27 | img_suffix=img_suffix, 28 | seg_map_suffix=seg_map_suffix, 29 | reduce_zero_label=reduce_zero_label, 30 | **kwargs) 31 | assert fileio.exists( 32 | self.data_prefix['img_path'], backend_args=self.backend_args) 33 | -------------------------------------------------------------------------------- /Segmentation/mmseg/datasets/synapse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmseg.registry import DATASETS 3 | from .basesegdataset import BaseSegDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class SynapseDataset(BaseSegDataset): 8 | """Synapse dataset. 9 | 10 | Before dataset preprocess of Synapse, there are total 13 categories of 11 | foreground which does not include background. After preprocessing, 8 12 | foreground categories are kept while the other 5 foreground categories are 13 | handled as background. The ``img_suffix`` is fixed to '.jpg' and 14 | ``seg_map_suffix`` is fixed to '.png'. 15 | """ 16 | METAINFO = dict( 17 | classes=('background', 'aorta', 'gallbladder', 'left_kidney', 18 | 'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach'), 19 | palette=[[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], 20 | [0, 255, 255], [255, 0, 255], [255, 255, 0], [60, 255, 255], 21 | [240, 240, 240]]) 22 | 23 | def __init__(self, 24 | img_suffix='.jpg', 25 | seg_map_suffix='.png', 26 | **kwargs) -> None: 27 | super().__init__( 28 | img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) 29 | -------------------------------------------------------------------------------- /Segmentation/mmseg/datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .formatting import PackSegInputs 3 | from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, 4 | LoadBiomedicalData, LoadBiomedicalImageFromFile, 5 | LoadImageFromNDArray, LoadMultipleRSImageFromFile, 6 | LoadSingleRSImageFromFile, LoadImageFromNpyFile) 7 | # yapf: disable 8 | from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, 9 | BioMedical3DRandomCrop, BioMedical3DRandomFlip, 10 | BioMedicalGaussianBlur, BioMedicalGaussianNoise, 11 | BioMedicalRandomGamma, ConcatCDInput, GenerateEdge, 12 | PhotoMetricDistortion, RandomCrop, RandomCutOut, 13 | RandomMosaic, RandomRotate, RandomRotFlip, Rerange, 14 | ResizeShortestEdge, ResizeToMultiple, RGB2Gray, 15 | SegRescale) 16 | 17 | # yapf: enable 18 | __all__ = [ 19 | 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', 20 | 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 21 | 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 22 | 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 23 | 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'LoadImageFromNpyFile', 24 | 'GenerateEdge', 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', 25 | 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', 26 | 'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput', 27 | 'LoadMultipleRSImageFromFile' 28 | ] 29 | -------------------------------------------------------------------------------- /Segmentation/mmseg/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import mmengine.fileio as fileio 5 | 6 | from mmseg.registry import DATASETS 7 | from .basesegdataset import BaseSegDataset 8 | 9 | 10 | @DATASETS.register_module() 11 | class PascalVOCDataset(BaseSegDataset): 12 | """Pascal VOC dataset. 13 | 14 | Args: 15 | split (str): Split txt file for Pascal VOC. 16 | """ 17 | METAINFO = dict( 18 | classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat', 19 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 20 | 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 21 | 'sofa', 'train', 'tvmonitor'), 22 | palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 23 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 24 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 25 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 26 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 27 | [0, 64, 128]]) 28 | 29 | def __init__(self, 30 | ann_file, 31 | img_suffix='.jpg', 32 | seg_map_suffix='.png', 33 | **kwargs) -> None: 34 | super().__init__( 35 | img_suffix=img_suffix, 36 | seg_map_suffix=seg_map_suffix, 37 | ann_file=ann_file, 38 | **kwargs) 39 | assert fileio.exists(self.data_prefix['img_path'], 40 | self.backend_args) and osp.isfile(self.ann_file) 41 | -------------------------------------------------------------------------------- /Segmentation/mmseg/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
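# Hedged usage sketch: the visualization hook exported below is normally
# enabled through the runner config rather than instantiated directly
# (the interval value is illustrative):
# default_hooks = dict(
#     visualization=dict(type='SegVisualizationHook', draw=True, interval=50))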
2 | from .hooks import SegVisualizationHook 3 | from .optimizers import (LayerDecayOptimizerConstructor, 4 | LearningRateDecayOptimizerConstructor) 5 | 6 | __all__ = [ 7 | 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', 8 | 'SegVisualizationHook' 9 | ] 10 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/engine/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .visualization_hook import SegVisualizationHook 3 | from .resetmodel_hook import ResetModelHook 4 | 5 | __all__ = ['SegVisualizationHook', 'ResetModelHook']
-------------------------------------------------------------------------------- /Segmentation/mmseg/engine/hooks/cal_firing_rate.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | from mmengine.hooks import Hook 3 | 4 | from mmseg.registry import HOOKS 5 | 6 | 7 | @HOOKS.register_module() 8 | class Get_lif_firing_num(Hook): 9 | """Debugging hook for inspecting LIF firing statistics. 10 | 11 | This is a development stub: it drops into ``pdb`` before every test 12 | iteration so the spiking activity of the model can be inspected by hand. 13 | """ 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) 16 | 17 | def before_test_iter(self, 18 | runner, 19 | batch_idx: int, 20 | data_batch: Optional[Sequence[dict]] = None) -> None: 21 | # Interactive breakpoint, intended for manual inspection only. 22 | import pdb; pdb.set_trace() 23 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/engine/hooks/resetmodel_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spikingjelly.clock_driven import functional 3 | from typing import Optional, Sequence 4 | from mmengine.hooks import Hook 5 | 6 | from mmseg.registry import HOOKS 7 | 8 | 9 | @HOOKS.register_module() 10 | class ResetModelHook(Hook): 11 | """Reset the states of all spiking neurons before each iteration. 12 | 13 | Spiking layers keep membrane state across forward passes, so 14 | ``functional.reset_net`` is called before every train, val and test 15 | iteration. 16 | """ 17 | def __init__(self, **kwargs): 18 | super().__init__(**kwargs) 19 | 20 | def before_train_iter(self, 21 | runner, 22 | batch_idx: int, 23 | data_batch: Optional[Sequence[dict]] = None) -> None: 24 | torch.cuda.synchronize() 25 | functional.reset_net(runner.model) 26 | 27 | def before_val_iter(self, 28 | runner, 29 | batch_idx: int, 30 | data_batch: Optional[Sequence[dict]] = None) -> None: 31 | torch.cuda.synchronize() 32 | functional.reset_net(runner.model) 33 | 34 | def before_test_iter(self, 35 | runner, 36 | batch_idx: int, 37 | data_batch: Optional[Sequence[dict]] = None) -> None: 38 | torch.cuda.synchronize() 39 | functional.reset_net(runner.model) 40 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/engine/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
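# Hedged config sketch (hyper-parameters illustrative): the constructors
# exported below plug into mmengine's optim wrapper via `constructor`, e.g.
# optim_wrapper = dict(
#     optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
#     constructor='LearningRateDecayOptimizerConstructor',
#     paramwise_cfg=dict(num_layers=12, decay_rate=0.9,
#                        decay_type='stage_wise'))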
2 | from .layer_decay_optimizer_constructor import ( 3 | LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) 4 | 5 | __all__ = [ 6 | 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' 7 | ] 8 | -------------------------------------------------------------------------------- /Segmentation/mmseg/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .metrics import CityscapesMetric, IoUMetric 3 | 4 | __all__ = ['IoUMetric', 'CityscapesMetric'] 5 | -------------------------------------------------------------------------------- /Segmentation/mmseg/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .citys_metric import CityscapesMetric 3 | from .iou_metric import IoUMetric 4 | 5 | __all__ = ['IoUMetric', 'CityscapesMetric'] 6 | -------------------------------------------------------------------------------- /Segmentation/mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, 4 | build_head, build_loss, build_segmentor) 5 | from .data_preprocessor import SegDataPreProcessor 6 | from .decode_heads import * # noqa: F401,F403 7 | from .losses import * # noqa: F401,F403 8 | from .necks import * # noqa: F401,F403 9 | from .segmentors import * # noqa: F401,F403 10 | 11 | __all__ = [ 12 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', 13 | 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor' 14 | ] 15 | -------------------------------------------------------------------------------- /Segmentation/mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .sdtv2 import Spiking_vit_MetaFormer 4 | from .sdtv3 import Spiking_vit_MetaFormerv2 5 | from .sdtv3MAE import Spiking_vit_MetaFormerv3 6 | 7 | __all__ = [ 8 | 'Spiking_vit_MetaFormer', 'Spiking_vit_MetaFormerv2', 'Spiking_vit_MetaFormerv3' 9 | ] 10 | -------------------------------------------------------------------------------- /Segmentation/mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
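# Hedged usage sketch: the helpers below are thin, deprecated wrappers around
# the MODELS registry; the underlying registry call looks like this (the
# backbone type is one exported by this repo, constructor arguments are
# omitted because they are model-specific):
# from mmseg.registry import MODELS
# backbone = MODELS.build(dict(type='Spiking_vit_MetaFormer'))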
2 | import warnings 3 | 4 | from mmseg.registry import MODELS 5 | 6 | BACKBONES = MODELS 7 | NECKS = MODELS 8 | HEADS = MODELS 9 | LOSSES = MODELS 10 | SEGMENTORS = MODELS 11 | 12 | 13 | def build_backbone(cfg): 14 | """Build backbone.""" 15 | warnings.warn('``build_backbone`` will be deprecated soon, please use ' 16 | '``mmseg.registry.MODELS.build()`` ') 17 | return BACKBONES.build(cfg) 18 | 19 | 20 | def build_neck(cfg): 21 | """Build neck.""" 22 | warnings.warn('``build_neck`` will be deprecated soon, please use ' 23 | '``mmseg.registry.MODELS.build()`` ') 24 | return NECKS.build(cfg) 25 | 26 | 27 | def build_head(cfg): 28 | """Build head.""" 29 | warnings.warn('``build_head`` will be deprecated soon, please use ' 30 | '``mmseg.registry.MODELS.build()`` ') 31 | return HEADS.build(cfg) 32 | 33 | 34 | def build_loss(cfg): 35 | """Build loss.""" 36 | warnings.warn('``build_loss`` will be deprecated soon, please use ' 37 | '``mmseg.registry.MODELS.build()`` ') 38 | return LOSSES.build(cfg) 39 | 40 | 41 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 42 | """Build segmentor.""" 43 | if train_cfg is not None or test_cfg is not None: 44 | warnings.warn( 45 | 'train_cfg and test_cfg are deprecated, ' 46 | 'please specify them in model', UserWarning) 47 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 48 | 'train_cfg specified in both outer field and model field ' 49 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 50 | 'test_cfg specified in both outer field and model field ' 51 | return SEGMENTORS.build( 52 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 53 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .fpn_head import * 3 | from .mask2former_head import Mask2FormerHead 4 | from .maskformer_head import MaskFormerHead 5 | from mmdet.models import * 6 | 7 | 8 | __all__ = [ 9 | 'FPNHead', 10 | 'MaskFormerHead', 'Mask2FormerHead', 11 | 'FPNHead_SNN', 'QFPNHead' 12 | ] 13 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
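# Hedged config sketch: losses are usually built from a decode head's config
# rather than imported directly (weights illustrative):
# loss_decode = dict(type='CrossEntropyLoss', use_sigmoid=False,
#                    loss_weight=1.0)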
2 | from .accuracy import Accuracy, accuracy 3 | from .boundary_loss import BoundaryLoss 4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 5 | cross_entropy, mask_cross_entropy) 6 | from .dice_loss import DiceLoss 7 | from .focal_loss import FocalLoss 8 | from .huasdorff_distance_loss import HuasdorffDisstanceLoss 9 | from .lovasz_loss import LovaszLoss 10 | from .ohem_cross_entropy_loss import OhemCrossEntropy 11 | from .tversky_loss import TverskyLoss 12 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 13 | 14 | __all__ = [ 15 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 16 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 17 | 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', 18 | 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss', 19 | 'HuasdorffDisstanceLoss' 20 | ] 21 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/losses/boundary_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from mmseg.registry import MODELS 8 | 9 | 10 | @MODELS.register_module() 11 | class BoundaryLoss(nn.Module): 12 | """Boundary loss. 13 | 14 | This function is modified from 15 | `PIDNet <https://github.com/XuJiacong/PIDNet>`_. # noqa 16 | Licensed under the MIT License. 17 | 18 | 19 | Args: 20 | loss_weight (float): Weight of the loss. Defaults to 1.0. 21 | loss_name (str): Name of the loss item. If you want this loss 22 | item to be included into the backward graph, `loss_` must be the 23 | prefix of the name. Defaults to 'loss_boundary'. 24 | """ 25 | 26 | def __init__(self, 27 | loss_weight: float = 1.0, 28 | loss_name: str = 'loss_boundary'): 29 | super().__init__() 30 | self.loss_weight = loss_weight 31 | self.loss_name_ = loss_name 32 | 33 | def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor: 34 | """Forward function. 35 | Args: 36 | bd_pre (Tensor): Predictions of the boundary head. 37 | bd_gt (Tensor): Ground truth of the boundary. 38 | 39 | Returns: 40 | Tensor: Loss tensor. 41 | """ 42 | log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1) 43 | target_t = bd_gt.view(1, -1).float() 44 | 45 | pos_index = (target_t == 1) 46 | neg_index = (target_t == 0) 47 | 48 | weight = torch.zeros_like(log_p) 49 | pos_num = pos_index.sum() 50 | neg_num = neg_index.sum() 51 | sum_num = pos_num + neg_num 52 | weight[pos_index] = neg_num * 1.0 / sum_num 53 | weight[neg_index] = pos_num * 1.0 / sum_num 54 | 55 | loss = F.binary_cross_entropy_with_logits( 56 | log_p, target_t, weight, reduction='mean') 57 | 58 | return self.loss_weight * loss 59 | 60 | @property 61 | def loss_name(self): 62 | return self.loss_name_ 63 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .fpn import * 3 | 4 | __all__ = [ 5 | 'FPN', 'FPN_SNN', 'QFPN' 6 | ] 7 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
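# Hedged config sketch (fields illustrative, component arguments omitted):
# a segmentor is assembled from registered components, e.g.
# model = dict(
#     type='EncoderDecoder',
#     backbone=dict(type='Spiking_vit_MetaFormer'),
#     neck=dict(type='FPN_SNN'),
#     decode_head=dict(type='FPNHead_SNN'))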
2 | from .base import BaseSegmentor 3 | from .cascade_encoder_decoder import CascadeEncoderDecoder 4 | from .encoder_decoder import EncoderDecoder 5 | from .seg_tta import SegTTAModel 6 | 7 | __all__ = [ 8 | 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel' 9 | ] 10 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/segmentors/seg_tta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List 3 | 4 | import torch 5 | from mmengine.model import BaseTTAModel 6 | from mmengine.structures import PixelData 7 | 8 | from mmseg.registry import MODELS 9 | from mmseg.utils import SampleList 10 | 11 | 12 | @MODELS.register_module() 13 | class SegTTAModel(BaseTTAModel): 14 | 15 | def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList: 16 | """Merge predictions of enhanced data to one prediction. 17 | 18 | Args: 19 | data_samples_list (List[SampleList]): List of predictions 20 | of all enhanced data. 21 | 22 | Returns: 23 | SampleList: Merged prediction. 24 | """ 25 | predictions = [] 26 | for data_samples in data_samples_list: 27 | seg_logits = data_samples[0].seg_logits.data 28 | logits = torch.zeros(seg_logits.shape).to(seg_logits) 29 | for data_sample in data_samples: 30 | seg_logit = data_sample.seg_logits.data 31 | if self.module.out_channels > 1: 32 | logits += seg_logit.softmax(dim=0) 33 | else: 34 | logits += seg_logit.sigmoid() 35 | logits /= len(data_samples) 36 | if self.module.out_channels == 1: 37 | seg_pred = (logits > self.module.decode_head.threshold 38 | ).to(logits).squeeze(1) 39 | else: 40 | seg_pred = logits.argmax(dim=0) 41 | data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)}) 42 | if hasattr(data_samples[0], 'gt_sem_seg'): 43 | data_sample.set_data( 44 | {'gt_sem_seg': data_samples[0].gt_sem_seg}) 45 | data_sample.set_metainfo({'img_path': data_samples[0].img_path}) 46 | predictions.append(data_sample) 47 | return predictions 48 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/utils/Qtrick.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | quant = 4 5 | T = quant 6 | 7 | 8 | class Quant(torch.autograd.Function): 9 | """Integer quantiser with a straight-through gradient estimator.""" 10 | 11 | @staticmethod 12 | @torch.cuda.amp.custom_fwd 13 | def forward(ctx, i, min_value=0, max_value=quant): 14 | ctx.min = min_value 15 | ctx.max = max_value 16 | ctx.save_for_backward(i) 17 | # Round-and-clamp to integer levels in [min_value, max_value]. 18 | return torch.round(torch.clamp(i, min=min_value, max=max_value)) 19 | 20 | @staticmethod 21 | @torch.cuda.amp.custom_bwd 22 | def backward(ctx, grad_output): 23 | grad_input = grad_output.clone() 24 | i, = ctx.saved_tensors 25 | # Straight-through estimator: pass gradients only inside the clamp range. 26 | grad_input[i < ctx.min] = 0 27 | grad_input[i > ctx.max] = 0 28 | return grad_input, None, None 29 | 30 | 31 | class Multispike_norm(nn.Module): 32 | def __init__( 33 | self, 34 | T=T,  # normalise the spike counts over T time steps 35 | ): 36 | super().__init__() 37 | self.T = T 38 | 39 | def forward(self, x): 40 | return Quant.apply(x) / self.T 41 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
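# Hedged behaviour sketch for the spike quantiser re-exported below (with the
# default T=4; the printed tensor illustrates round/clamp/normalise):
# >>> import torch
# >>> from mmseg.models.utils import Multispike_norm
# >>> x = torch.tensor([-0.2, 0.3, 2.6, 9.0])
# >>> Multispike_norm()(x)  # round(clamp(x, 0, 4)) / 4
# tensor([0.0000, 0.0000, 0.7500, 1.0000])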
2 | from .basic_block import BasicBlock, Bottleneck 3 | from .embed import PatchEmbed 4 | from .encoding import Encoding 5 | from .inverted_residual import InvertedResidual, InvertedResidualV3 6 | from .make_divisible import make_divisible 7 | from .ppm import DAPPM, PAPPM 8 | from .res_layer import ResLayer 9 | from .se_layer import SELayer 10 | from .self_attention_block import SelfAttentionBlock 11 | from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, 12 | nlc_to_nchw) 13 | from .up_conv_block import UpConvBlock 14 | from .wrappers import Upsample, resize 15 | from .point_sample import (get_uncertain_point_coords_with_randomness, 16 | get_uncertainty) 17 | from .Qtrick import Multispike_norm 18 | 19 | # NOTE: from mmdet 20 | 21 | from .panoptic_gt_processing import preprocess_panoptic_gt 22 | from .misc import (aligned_bilinear, center_of_mass, empty_instances, 23 | filter_gt_instances, filter_scores_and_topk, flip_tensor, 24 | generate_coordinate, images_to_levels, interpolate_as, 25 | levels_to_images, mask2ndarray, multi_apply, 26 | relative_coordinate_maps, rename_loss_dict, 27 | reweight_loss_dict, samplelist_boxtype2tensor, 28 | select_single_mlvl, sigmoid_geometric_mean, 29 | unfold_wo_center, unmap, unpack_gt_instances) 30 | 31 | __all__ = [ 32 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 33 | 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', 34 | 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding', 35 | 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck', 36 | 'get_uncertain_point_coords_with_randomness', 'get_uncertainty', 37 | 'multi_apply', 'preprocess_panoptic_gt', "Multispike_norm" 38 | ] 39 | -------------------------------------------------------------------------------- /Segmentation/mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number to the nearest value that can be 6 | divisible by the divisor. It is taken from the original tf repo. It ensures 7 | that all layers have a channel number that is divisible by divisor. It can 8 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 9 | 10 | Args: 11 | value (int): The original channel number. 12 | divisor (int): The divisor to fully divide the channel number. 13 | min_value (int): The minimum value of the output channel. 14 | Default: None, means that the minimum value equal to the divisor. 15 | min_ratio (float): The minimum ratio of the rounded channel number to 16 | the original channel number. Default: 0.9. 17 | 18 | Returns: 19 | int: The modified output channel number. 20 | """ 21 | 22 | if min_value is None: 23 | min_value = divisor 24 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 25 | # Make sure that round down does not go down by more than (1-min_ratio). 26 | if new_value < min_ratio * value: 27 | new_value += divisor 28 | return new_value 29 | -------------------------------------------------------------------------------- /Segmentation/mmseg/models/utils/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
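# Hedged usage sketch: `resize` below is a thin wrapper over F.interpolate
# that only adds an alignment warning, e.g.
# >>> import torch
# >>> out = resize(torch.randn(1, 3, 32, 32), size=(64, 64),
# ...              mode='bilinear', align_corners=False)
# >>> out.shape
# torch.Size([1, 3, 64, 64])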
2 | import warnings 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > input_w: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would be more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | 30 | 31 | class Upsample(nn.Module): 32 | 33 | def __init__(self, 34 | size=None, 35 | scale_factor=None, 36 | mode='nearest', 37 | align_corners=None): 38 | super().__init__() 39 | self.size = size 40 | if isinstance(scale_factor, tuple): 41 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 42 | else: 43 | self.scale_factor = float(scale_factor) if scale_factor else None 44 | self.mode = mode 45 | self.align_corners = align_corners 46 | 47 | def forward(self, x): 48 | if not self.size: 49 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 50 | else: 51 | size = self.size 52 | return resize(x, size, None, self.mode, self.align_corners) 53 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/registry/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS, 3 | LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, 4 | OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, 5 | PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, 6 | TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, 7 | WEIGHT_INITIALIZERS) 8 | 9 | __all__ = [ 10 | 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 11 | 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 12 | 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 13 | 'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', 14 | 'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS' 15 | ] 16 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/structures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler 3 | from .seg_data_sample import SegDataSample 4 | 5 | __all__ = [ 6 | 'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler', 7 | 'build_pixel_sampler' 8 | ] 9 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/structures/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
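# Hedged config sketch (threshold values illustrative): the OHEM sampler is
# typically attached to a decode head, e.g.
# decode_head = dict(
#     sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000))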
2 | from .base_pixel_sampler import BasePixelSampler 3 | from .builder import build_pixel_sampler 4 | from .ohem_pixel_sampler import OHEMPixelSampler 5 | 6 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 7 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/structures/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BasePixelSampler(metaclass=ABCMeta): 6 | """Base class of pixel sampler.""" 7 | 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | @abstractmethod 12 | def sample(self, seg_logit, seg_label): 13 | """Placeholder for sample function.""" 14 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/structures/sampler/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | from mmseg.registry import TASK_UTILS 5 | 6 | PIXEL_SAMPLERS = TASK_UTILS 7 | 8 | 9 | def build_pixel_sampler(cfg, **default_args): 10 | """Build pixel sampler for segmentation map.""" 11 | warnings.warn( 12 | '``build_pixel_sampler`` will be deprecated soon, please use ' 13 | '``mmseg.registry.TASK_UTILS.build()`` ') 14 | return TASK_UTILS.build(cfg, default_args=default_args) 15 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # yapf: disable 3 | from .class_names import (ade_classes, ade_palette, cityscapes_classes, 4 | cityscapes_palette, cocostuff_classes, 5 | cocostuff_palette, dataset_aliases, get_classes, 6 | get_palette, isaid_classes, isaid_palette, 7 | loveda_classes, loveda_palette, potsdam_classes, 8 | potsdam_palette, stare_classes, stare_palette, 9 | synapse_classes, synapse_palette, vaihingen_classes, 10 | vaihingen_palette, voc_classes, voc_palette) 11 | # yapf: enable 12 | from .collect_env import collect_env 13 | from .io import datafrombytes 14 | from .misc import add_prefix, stack_batch, multi_apply 15 | from .set_env import register_all_modules 16 | from .typing_utils import (ConfigType, ForwardResults, MultiConfig, 17 | OptConfigType, OptMultiConfig, OptSampleList, 18 | SampleList, TensorDict, TensorList) 19 | 20 | __all__ = [ 21 | 'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix', 22 | 'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig', 23 | 'SampleList', 'OptSampleList', 'TensorDict', 'TensorList', 24 | 'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes', 25 | 'cocostuff_classes', 'loveda_classes', 'potsdam_classes', 26 | 'vaihingen_classes', 'isaid_classes', 'stare_classes', 27 | 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette', 28 | 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', 29 | 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', 30 | 'datafrombytes', 'synapse_palette', 'synapse_classes', 'multi_apply', 31 | ] 32 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab.
All rights reserved. 2 | from mmengine.utils import get_git_hash 3 | from mmengine.utils.dl_utils import collect_env as collect_base_env 4 | 5 | import mmseg 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 12 | 13 | return env_info 14 | 15 | 16 | if __name__ == '__main__': 17 | for name, val in collect_env().items(): 18 | print(f'{name}: {val}') 19 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import gzip 3 | import io 4 | import pickle 5 | 6 | import numpy as np 7 | 8 | 9 | def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray: 10 | """Data decoding from bytes. 11 | 12 | Args: 13 | content (bytes): The data bytes got from files or other streams. 14 | backend (str): The data decoding backend type. Options are 'numpy', 15 | 'nifti' and 'pickle'. Defaults to 'numpy'. 16 | 17 | Returns: 18 | numpy.ndarray: Loaded data array. 19 | """ 20 | if backend == 'pickle': 21 | data = pickle.loads(content) 22 | else: 23 | with io.BytesIO(content) as f: 24 | if backend == 'nifti': 25 | f = gzip.open(f) 26 | try: 27 | from nibabel import FileHolder, Nifti1Image 28 | except ImportError: 29 | raise ImportError('nifti file io depends on nibabel, please run ' 30 | '`pip install nibabel` to install it') 31 | fh = FileHolder(fileobj=f) 32 | data = Nifti1Image.from_file_map({'header': fh, 'image': fh}) 33 | data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata() 34 | elif backend == 'numpy': 35 | data = np.load(f) 36 | else: 37 | raise ValueError 38 | return data 39 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/utils/set_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import datetime 3 | import warnings 4 | 5 | from mmengine import DefaultScope 6 | 7 | 8 | def register_all_modules(init_default_scope: bool = True) -> None: 9 | """Register all modules in mmseg into the registries. 10 | 11 | Args: 12 | init_default_scope (bool): Whether to initialize the mmseg default scope. 13 | When `init_default_scope=True`, the global default scope will be 14 | set to `mmseg`, and all registries will build modules from mmseg's 15 | registry node. To understand more about the registry, please refer 16 | to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md 17 | Defaults to True. 18 | """ # noqa 19 | import mmseg.datasets # noqa: F401,F403 20 | import mmseg.engine # noqa: F401,F403 21 | import mmseg.evaluation # noqa: F401,F403 22 | import mmseg.models # noqa: F401,F403 23 | import mmseg.structures # noqa: F401,F403 24 | 25 | if init_default_scope: 26 | never_created = DefaultScope.get_current_instance() is None \ 27 | or not DefaultScope.check_instance_created('mmseg') 28 | if never_created: 29 | DefaultScope.get_instance('mmseg', scope_name='mmseg') 30 | return 31 | current_scope = DefaultScope.get_current_instance() 32 | if current_scope.scope_name != 'mmseg': 33 | warnings.warn('The current default scope ' 34 | f'"{current_scope.scope_name}" is not "mmseg", ' 35 | '`register_all_modules` will force the current ' 36 | 'default scope to be "mmseg".
If this is not ' 37 | 'expected, please set `init_default_scope=False`.') 38 | # avoid name conflict 39 | new_instance_name = f'mmseg-{datetime.datetime.now()}' 40 | DefaultScope.get_instance(new_instance_name, scope_name='mmseg') 41 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/utils/typing_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """Commonly used type hints in mmseg.""" 3 | from typing import Dict, List, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | from mmengine.config import ConfigDict 7 | 8 | from mmseg.structures import SegDataSample 9 | 10 | # Type hint of config data 11 | ConfigType = Union[ConfigDict, dict] 12 | OptConfigType = Optional[ConfigType] 13 | # Type hint of one or more config data 14 | MultiConfig = Union[ConfigType, Sequence[ConfigType]] 15 | OptMultiConfig = Optional[MultiConfig] 16 | 17 | SampleList = Sequence[SegDataSample] 18 | OptSampleList = Optional[SampleList] 19 | 20 | # Type hint of Tensor 21 | TensorDict = Dict[str, torch.Tensor] 22 | TensorList = Sequence[torch.Tensor] 23 | 24 | ForwardResults = Union[Dict[str, torch.Tensor], List[SegDataSample], 25 | Tuple[torch.Tensor], torch.Tensor] 26 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '1.1.1' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 |
-------------------------------------------------------------------------------- /Segmentation/mmseg/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
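# Hedged usage sketch (save_dir illustrative):
# from mmseg.visualization import SegLocalVisualizer
# vis = SegLocalVisualizer(
#     save_dir='work_dirs/vis',
#     vis_backends=[dict(type='LocalVisBackend')])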
2 | from .local_visualizer import SegLocalVisualizer 3 | 4 | __all__ = ['SegLocalVisualizer'] 5 | -------------------------------------------------------------------------------- /Segmentation/model-index.yml: -------------------------------------------------------------------------------- 1 | Import: 2 | - configs/ann/metafile.yaml 3 | - configs/apcnet/metafile.yaml 4 | - configs/beit/metafile.yaml 5 | - configs/bisenetv1/metafile.yaml 6 | - configs/bisenetv2/metafile.yaml 7 | - configs/ccnet/metafile.yaml 8 | - configs/cgnet/metafile.yaml 9 | - configs/convnext/metafile.yaml 10 | - configs/danet/metafile.yaml 11 | - configs/ddrnet/metafile.yaml 12 | - configs/deeplabv3/metafile.yaml 13 | - configs/deeplabv3plus/metafile.yaml 14 | - configs/dmnet/metafile.yaml 15 | - configs/dnlnet/metafile.yaml 16 | - configs/dpt/metafile.yaml 17 | - configs/emanet/metafile.yaml 18 | - configs/encnet/metafile.yaml 19 | - configs/erfnet/metafile.yaml 20 | - configs/fastfcn/metafile.yaml 21 | - configs/fastscnn/metafile.yaml 22 | - configs/fcn/metafile.yaml 23 | - configs/gcnet/metafile.yaml 24 | - configs/hrnet/metafile.yaml 25 | - configs/icnet/metafile.yaml 26 | - configs/isanet/metafile.yaml 27 | - configs/knet/metafile.yaml 28 | - configs/mae/metafile.yaml 29 | - configs/mask2former/metafile.yaml 30 | - configs/maskformer/metafile.yaml 31 | - configs/mobilenet_v2/metafile.yaml 32 | - configs/mobilenet_v3/metafile.yaml 33 | - configs/nonlocal_net/metafile.yaml 34 | - configs/ocrnet/metafile.yaml 35 | - configs/pidnet/metafile.yaml 36 | - configs/point_rend/metafile.yaml 37 | - configs/poolformer/metafile.yaml 38 | - configs/psanet/metafile.yaml 39 | - configs/pspnet/metafile.yaml 40 | - configs/resnest/metafile.yaml 41 | - configs/segformer/metafile.yaml 42 | - configs/segmenter/metafile.yaml 43 | - configs/segnext/metafile.yaml 44 | - configs/sem_fpn/metafile.yaml 45 | - configs/setr/metafile.yaml 46 | - configs/stdc/metafile.yaml 47 | - configs/swin/metafile.yaml 48 | - configs/twins/metafile.yaml 49 | - configs/unet/metafile.yaml 50 | - configs/upernet/metafile.yaml 51 | - configs/vit/metafile.yaml 52 | -------------------------------------------------------------------------------- /Segmentation/requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/optional.txt 2 | -r requirements/runtime.txt 3 | -r requirements/tests.txt 4 | -------------------------------------------------------------------------------- /Segmentation/requirements/albu.txt: -------------------------------------------------------------------------------- 1 | albumentations>=0.3.2 --no-binary qudida,albumentations 2 | -------------------------------------------------------------------------------- /Segmentation/requirements/docs.txt: -------------------------------------------------------------------------------- 1 | docutils==0.16.0 2 | myst-parser 3 | -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 4 | sphinx==4.0.2 5 | sphinx_copybutton 6 | sphinx_markdown_tables 7 | urllib3<2.0.0 8 | -------------------------------------------------------------------------------- /Segmentation/requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcv>=2.0.0rc4 2 | mmengine>=0.5.0,<1.0.0 3 | -------------------------------------------------------------------------------- /Segmentation/requirements/optional.txt: 
-------------------------------------------------------------------------------- 1 | cityscapesscripts 2 | nibabel 3 | -------------------------------------------------------------------------------- /Segmentation/requirements/readthedocs.txt: -------------------------------------------------------------------------------- 1 | mmcv>=2.0.0rc1,<2.1.0 2 | mmengine>=0.4.0,<1.0.0 3 | prettytable 4 | scipy 5 | torch 6 | torchvision 7 | -------------------------------------------------------------------------------- /Segmentation/requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | packaging 4 | prettytable 5 | scipy 6 | -------------------------------------------------------------------------------- /Segmentation/requirements/tests.txt: -------------------------------------------------------------------------------- 1 | codecov 2 | flake8 3 | interrogate 4 | pytest 5 | xdoctest>=0.10.0 6 | yapf 7 | -------------------------------------------------------------------------------- /Segmentation/setup.cfg: -------------------------------------------------------------------------------- 1 | [yapf] 2 | based_on_style = pep8 3 | blank_line_before_nested_class_or_def = true 4 | split_before_expression_after_opening_paren = true 5 | 6 | [isort] 7 | line_length = 79 8 | multi_line_output = 0 9 | extra_standard_library = setuptools 10 | known_first_party = mmseg 11 | known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,packaging,prettytable,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,torch,ts 12 | no_lines_before = STDLIB,LOCALFOLDER 13 | default_section = THIRDPARTY 14 | 15 | [codespell] 16 | skip = *.po,*.ts,*.ipynb 17 | count = 18 | quiet-level = 3 19 | ignore-words-list = formating,sur,hist,dota,warmup,damon 20 | -------------------------------------------------------------------------------- /Segmentation/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /Segmentation/tests/test_digit_version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
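# Hedged reading of mmengine's digit_version, consistent with the assertions
# below: versions normalise to a 6-tuple (major, minor, patch, extra,
# pre-release flag, pre-release number); releases use zeros and pre-releases
# use negative flags, so '1.2.3rc1' -> (1, 2, 3, 0, -1, 1) sorts before
# '1.2.3' -> (1, 2, 3, 0, 0, 0).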
2 | from mmseg import digit_version 3 | 4 | 5 | def test_digit_version(): 6 | assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0) 7 | assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0) 8 | assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0) 9 | assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1) 10 | assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0) 11 | assert digit_version('1.0') == digit_version('1.0.0') 12 | assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5') 13 | assert digit_version('1.0.0dev') < digit_version('1.0.0a') 14 | assert digit_version('1.0.0a') < digit_version('1.0.0a1') 15 | assert digit_version('1.0.0a') < digit_version('1.0.0b') 16 | assert digit_version('1.0.0b') < digit_version('1.0.0rc') 17 | assert digit_version('1.0.0rc1') < digit_version('1.0.0') 18 | assert digit_version('1.0.0') < digit_version('1.0.0post') 19 | assert digit_version('1.0.0post') < digit_version('1.0.0post1') 20 | assert digit_version('v1') == (1, 0, 0, 0, 0, 0) 21 | assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0) 22 | -------------------------------------------------------------------------------- /Segmentation/tests/test_engine/test_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmengine.optim import build_optim_wrapper 5 | 6 | 7 | class ExampleModel(nn.Module): 8 | 9 | def __init__(self): 10 | super().__init__() 11 | self.param1 = nn.Parameter(torch.ones(1)) 12 | self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False) 13 | self.conv2 = nn.Conv2d(4, 2, kernel_size=1) 14 | self.bn = nn.BatchNorm2d(2) 15 | 16 | def forward(self, x): 17 | return x 18 | 19 | 20 | base_lr = 0.01 21 | base_wd = 0.0001 22 | momentum = 0.9 23 | 24 | 25 | def test_build_optimizer(): 26 | model = ExampleModel() 27 | optim_wrapper_cfg = dict( 28 | type='OptimWrapper', 29 | optimizer=dict( 30 | type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)) 31 | optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) 32 | # test whether optimizer is successfully built from parent. 33 | assert isinstance(optim_wrapper.optimizer, torch.optim.SGD) 34 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/test_bisenetv2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmseg.models.backbones import BiSeNetV2 6 | from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch, 7 | SemanticBranch) 8 | 9 | 10 | def test_bisenetv2_backbone(): 11 | # Test BiSeNetV2 Standard Forward 12 | model = BiSeNetV2() 13 | model.init_weights() 14 | model.train() 15 | batch_size = 2 16 | imgs = torch.randn(batch_size, 3, 128, 256) 17 | feat = model(imgs) 18 | 19 | assert len(feat) == 5 20 | # output for segment Head 21 | assert feat[0].shape == torch.Size([batch_size, 128, 16, 32]) 22 | # for auxiliary head 1 23 | assert feat[1].shape == torch.Size([batch_size, 16, 32, 64]) 24 | # for auxiliary head 2 25 | assert feat[2].shape == torch.Size([batch_size, 32, 16, 32]) 26 | # for auxiliary head 3 27 | assert feat[3].shape == torch.Size([batch_size, 64, 8, 16]) 28 | # for auxiliary head 4 29 | assert feat[4].shape == torch.Size([batch_size, 128, 4, 8]) 30 | 31 | # Test input with rare shape 32 | batch_size = 2 33 | imgs = torch.randn(batch_size, 3, 95, 27) 34 | feat = model(imgs) 35 | assert len(feat) == 5 36 | 37 | 38 | def test_bisenetv2_DetailBranch(): 39 | x = torch.randn(1, 3, 32, 64) 40 | detail_branch = DetailBranch(detail_channels=(64, 16, 32)) 41 | assert isinstance(detail_branch.detail_branch[0][0], ConvModule) 42 | x_out = detail_branch(x) 43 | assert x_out.shape == torch.Size([1, 32, 4, 8]) 44 | 45 | 46 | def test_bisenetv2_SemanticBranch(): 47 | semantic_branch = SemanticBranch(semantic_channels=(16, 32, 64, 128)) 48 | assert semantic_branch.stage1.pool.stride == 2 49 | 50 | 51 | def test_bisenetv2_BGALayer(): 52 | x_a = torch.randn(1, 8, 8, 16) 53 | x_b = torch.randn(1, 8, 2, 4) 54 | bga = BGALayer(out_channels=8) 55 | assert isinstance(bga.conv, ConvModule) 56 | x_out = bga(x_a, x_b) 57 | assert x_out.shape == torch.Size([1, 8, 8, 16]) 58 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/test_fast_scnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.backbones import FastSCNN 6 | 7 | 8 | def test_fastscnn_backbone(): 9 | with pytest.raises(AssertionError): 10 | # Fast-SCNN channel constraints. 11 | FastSCNN( 12 | 3, (32, 48), 13 | 64, (64, 96, 128), (2, 2, 1), 14 | global_out_channels=127, 15 | higher_in_channels=64, 16 | lower_in_channels=128) 17 | 18 | # Test FastSCNN Standard Forward 19 | model = FastSCNN( 20 | in_channels=3, 21 | downsample_dw_channels=(4, 6), 22 | global_in_channels=8, 23 | global_block_channels=(8, 12, 16), 24 | global_block_strides=(2, 2, 1), 25 | global_out_channels=16, 26 | higher_in_channels=8, 27 | lower_in_channels=16, 28 | fusion_out_channels=16, 29 | ) 30 | model.init_weights() 31 | model.train() 32 | batch_size = 4 33 | imgs = torch.randn(batch_size, 3, 64, 128) 34 | feat = model(imgs) 35 | 36 | assert len(feat) == 3 37 | # higher-res 38 | assert feat[0].shape == torch.Size([batch_size, 8, 8, 16]) 39 | # lower-res 40 | assert feat[1].shape == torch.Size([batch_size, 16, 2, 4]) 41 | # FFM output 42 | assert feat[2].shape == torch.Size([batch_size, 16, 8, 16]) 43 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/test_icnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.backbones import ICNet 6 | 7 | 8 | def test_icnet_backbone(): 9 | with pytest.raises(TypeError): 10 | # Must give backbone dict in config file. 11 | ICNet( 12 | in_channels=3, 13 | layer_channels=(128, 512), 14 | light_branch_middle_channels=8, 15 | psp_out_channels=128, 16 | out_channels=(16, 128, 128), 17 | backbone_cfg=None) 18 | 19 | # Test ICNet Standard Forward 20 | model = ICNet( 21 | layer_channels=(128, 512), 22 | backbone_cfg=dict( 23 | type='ResNetV1c', 24 | in_channels=3, 25 | depth=18, 26 | num_stages=4, 27 | out_indices=(0, 1, 2, 3), 28 | dilations=(1, 1, 2, 4), 29 | strides=(1, 2, 1, 1), 30 | norm_cfg=dict(type='BN', requires_grad=True), 31 | norm_eval=False, 32 | style='pytorch', 33 | contract_dilation=True), 34 | ) 35 | assert hasattr(model.backbone, 36 | 'maxpool') and model.backbone.maxpool.ceil_mode is True 37 | model.init_weights() 38 | model.train() 39 | batch_size = 2 40 | imgs = torch.randn(batch_size, 3, 32, 64) 41 | feat = model(imgs) 42 | 43 | assert model.psp_modules[0][0].output_size == 1 44 | assert model.psp_modules[1][0].output_size == 2 45 | assert model.psp_modules[2][0].output_size == 3 46 | assert model.psp_bottleneck.padding == 1 47 | assert model.conv_sub1[0].padding == 1 48 | 49 | assert len(feat) == 3 50 | assert feat[0].shape == torch.Size([batch_size, 64, 4, 8]) 51 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/test_mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.backbones import MobileNetV3 6 | 7 | 8 | def test_mobilenet_v3(): 9 | with pytest.raises(AssertionError): 10 | # check invalid arch 11 | MobileNetV3('big') 12 | 13 | with pytest.raises(AssertionError): 14 | # check invalid reduction_factor 15 | MobileNetV3(reduction_factor=0) 16 | 17 | with pytest.raises(ValueError): 18 | # check invalid out_indices 19 | MobileNetV3(out_indices=(0, 1, 15)) 20 | 21 | with pytest.raises(ValueError): 22 | # check invalid frozen_stages 23 | MobileNetV3(frozen_stages=15) 24 | 25 | with pytest.raises(TypeError): 26 | # check invalid pretrained 27 | model = MobileNetV3() 28 | model.init_weights(pretrained=8) 29 | 30 | # Test MobileNetV3 with default settings 31 | model = MobileNetV3() 32 | model.init_weights() 33 | model.train() 34 | 35 | imgs = torch.randn(2, 3, 56, 56) 36 | feat = model(imgs) 37 | assert len(feat) == 3 38 | assert feat[0].shape == (2, 16, 28, 28) 39 | assert feat[1].shape == (2, 16, 14, 14) 40 | assert feat[2].shape == (2, 576, 7, 7) 41 | 42 | # Test MobileNetV3 with arch = 'large' 43 | model = MobileNetV3(arch='large', out_indices=(1, 3, 16)) 44 | model.init_weights() 45 | model.train() 46 | 47 | imgs = torch.randn(2, 3, 56, 56) 48 | feat = model(imgs) 49 | assert len(feat) == 3 50 | assert feat[0].shape == (2, 16, 28, 28) 51 | assert feat[1].shape == (2, 24, 14, 14) 52 | assert feat[2].shape == (2, 960, 7, 7) 53 | 54 | # Test MobileNetV3 with norm_eval True, with_cp True and frozen_stages=5 55 | model = MobileNetV3(norm_eval=True, with_cp=True, frozen_stages=5) 56 | with pytest.raises(TypeError): 57 | # check invalid pretrained 58 | model.init_weights(pretrained=8) 59 | model.init_weights() 60 | model.train() 61 | 62 | imgs = torch.randn(2, 3, 56, 56) 63 | feat = model(imgs) 64 | assert len(feat) == 3 65 | assert feat[0].shape == (2, 16, 28, 
28) 66 | assert feat[1].shape == (2, 16, 14, 14) 67 | assert feat[2].shape == (2, 576, 7, 7) 68 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/test_resnest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.backbones import ResNeSt 6 | from mmseg.models.backbones.resnest import Bottleneck as BottleneckS 7 | 8 | 9 | def test_resnest_bottleneck(): 10 | with pytest.raises(AssertionError): 11 | # Style must be in ['pytorch', 'caffe'] 12 | BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow') 13 | 14 | # Test ResNeSt Bottleneck structure 15 | block = BottleneckS( 16 | 64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch') 17 | assert block.avd_layer.stride == 2 18 | assert block.conv2.channels == 256 19 | 20 | # Test ResNeSt Bottleneck forward 21 | block = BottleneckS(64, 16, radix=2, reduction_factor=4) 22 | x = torch.randn(2, 64, 56, 56) 23 | x_out = block(x) 24 | assert x_out.shape == torch.Size([2, 64, 56, 56]) 25 | 26 | 27 | def test_resnest_backbone(): 28 | with pytest.raises(KeyError): 29 | # ResNeSt depth should be in [50, 101, 152, 200] 30 | ResNeSt(depth=18) 31 | 32 | # Test ResNeSt with radix 2, reduction_factor 4 33 | model = ResNeSt( 34 | depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3)) 35 | model.init_weights() 36 | model.train() 37 | 38 | imgs = torch.randn(2, 3, 224, 224) 39 | feat = model(imgs) 40 | assert len(feat) == 4 41 | assert feat[0].shape == torch.Size([2, 256, 56, 56]) 42 | assert feat[1].shape == torch.Size([2, 512, 28, 28]) 43 | assert feat[2].shape == torch.Size([2, 1024, 14, 14]) 44 | assert feat[3].shape == torch.Size([2, 2048, 7, 7]) 45 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/test_resnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.backbones import ResNeXt 6 | from mmseg.models.backbones.resnext import Bottleneck as BottleneckX 7 | from .utils import is_block 8 | 9 | 10 | def test_resnext_bottleneck(): 11 | with pytest.raises(AssertionError): 12 | # Style must be in ['pytorch', 'caffe'] 13 | BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow') 14 | 15 | # Test ResNeXt Bottleneck structure 16 | block = BottleneckX( 17 | 64, 64, groups=32, base_width=4, stride=2, style='pytorch') 18 | assert block.conv2.stride == (2, 2) 19 | assert block.conv2.groups == 32 20 | assert block.conv2.out_channels == 128 21 | 22 | # Test ResNeXt Bottleneck with DCN 23 | dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False) 24 | with pytest.raises(AssertionError): 25 | # conv_cfg must be None if dcn is not None 26 | BottleneckX( 27 | 64, 28 | 64, 29 | groups=32, 30 | base_width=4, 31 | dcn=dcn, 32 | conv_cfg=dict(type='Conv')) 33 | BottleneckX(64, 64, dcn=dcn) 34 | 35 | # Test ResNeXt Bottleneck forward 36 | block = BottleneckX(64, 16, groups=32, base_width=4) 37 | x = torch.randn(1, 64, 56, 56) 38 | x_out = block(x) 39 | assert x_out.shape == torch.Size([1, 64, 56, 56]) 40 | 41 | 42 | def test_resnext_backbone(): 43 | with pytest.raises(KeyError): 44 | # ResNeXt depth should be in [50, 101, 152] 45 | ResNeXt(depth=18) 46 | 47 | # Test ResNeXt with group 32, base_width 4 48 | model = ResNeXt(depth=50, groups=32, base_width=4) 49 | for m in model.modules(): 50 | if is_block(m): 51 | assert m.conv2.groups == 32 52 | model.init_weights() 53 | model.train() 54 | 55 | imgs = torch.randn(1, 3, 224, 224) 56 | feat = model(imgs) 57 | assert len(feat) == 4 58 | assert feat[0].shape == torch.Size([1, 256, 56, 56]) 59 | assert feat[1].shape == torch.Size([1, 512, 28, 28]) 60 | assert feat[2].shape == torch.Size([1, 1024, 14, 14]) 61 | assert feat[3].shape == torch.Size([1, 2048, 7, 7]) 62 |
-------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
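# Hedged usage sketch for the helpers below (backbone kwargs illustrative):
# >>> from mmseg.models.backbones import ResNeXt
# >>> model = ResNeXt(depth=50, norm_eval=True)
# >>> _ = model.train()  # with norm_eval=True, BN layers stay in eval mode
# >>> check_norm_state(model.modules(), train_state=False)
# True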
2 | import torch 3 | from torch.nn.modules import GroupNorm 4 | from torch.nn.modules.batchnorm import _BatchNorm 5 | 6 | from mmseg.models.backbones.resnet import BasicBlock, Bottleneck 7 | from mmseg.models.backbones.resnext import Bottleneck as BottleneckX 8 | 9 | 10 | def is_block(modules): 11 | """Check if a module is a ResNet building block.""" 12 | if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX)): 13 | return True 14 | return False 15 | 16 | 17 | def is_norm(modules): 18 | """Check if a module is one of the supported norm layers.""" 19 | if isinstance(modules, (GroupNorm, _BatchNorm)): 20 | return True 21 | return False 22 | 23 | 24 | def all_zeros(modules): 25 | """Check if the weight (and bias) is all zero.""" 26 | weight_zero = torch.allclose(modules.weight.data, 27 | torch.zeros_like(modules.weight.data)) 28 | if hasattr(modules, 'bias'): 29 | bias_zero = torch.allclose(modules.bias.data, 30 | torch.zeros_like(modules.bias.data)) 31 | else: 32 | bias_zero = True 33 | 34 | return weight_zero and bias_zero 35 | 36 | 37 | def check_norm_state(modules, train_state): 38 | """Check if norm layers are in the correct train state.""" 39 | for mod in modules: 40 | if isinstance(mod, _BatchNorm): 41 | if mod.training != train_state: 42 | return False 43 | return True 44 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_ann_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmseg.models.decode_heads import ANNHead 5 | from .utils import to_cuda 6 | 7 | 8 | def test_ann_head(): 9 | 10 | inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 8, 21, 21)] 11 | head = ANNHead( 12 | in_channels=[4, 8], 13 | channels=2, 14 | num_classes=19, 15 | in_index=[-2, -1], 16 | project_channels=8) 17 | if torch.cuda.is_available(): 18 | head, inputs = to_cuda(head, inputs) 19 | outputs = head(inputs) 20 | assert outputs.shape == (1, head.num_classes, 21, 21) 21 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_apc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import APCHead 6 | from .utils import _conv_has_norm, to_cuda 7 | 8 | 9 | def test_apc_head(): 10 | 11 | with pytest.raises(AssertionError): 12 | # pool_scales must be list|tuple 13 | APCHead(in_channels=8, channels=2, num_classes=19, pool_scales=1) 14 | 15 | # test no norm_cfg 16 | head = APCHead(in_channels=8, channels=2, num_classes=19) 17 | assert not _conv_has_norm(head, sync_bn=False) 18 | 19 | # test with norm_cfg 20 | head = APCHead( 21 | in_channels=8, 22 | channels=2, 23 | num_classes=19, 24 | norm_cfg=dict(type='SyncBN')) 25 | assert _conv_has_norm(head, sync_bn=True) 26 | 27 | # fusion=True 28 | inputs = [torch.randn(1, 8, 45, 45)] 29 | head = APCHead( 30 | in_channels=8, 31 | channels=2, 32 | num_classes=19, 33 | pool_scales=(1, 2, 3), 34 | fusion=True) 35 | if torch.cuda.is_available(): 36 | head, inputs = to_cuda(head, inputs) 37 | assert head.fusion is True 38 | assert head.acm_modules[0].pool_scale == 1 39 | assert head.acm_modules[1].pool_scale == 2 40 | assert head.acm_modules[2].pool_scale == 3 41 | outputs = head(inputs) 42 | assert outputs.shape == (1, head.num_classes, 45, 45) 43 | 44 | # fusion=False 45 | inputs = [torch.randn(1, 8, 45, 45)] 46 | head = APCHead( 47 | in_channels=8, 48 | channels=2, 49 | num_classes=19, 50 | pool_scales=(1, 2, 3), 51 | fusion=False) 52 | if torch.cuda.is_available(): 53 | head, inputs = to_cuda(head, inputs) 54 | assert head.fusion is False 55 | assert head.acm_modules[0].pool_scale == 1 56 | assert head.acm_modules[1].pool_scale == 2 57 | assert head.acm_modules[2].pool_scale == 3 58 | outputs = head(inputs) 59 | assert outputs.shape == (1, head.num_classes, 45, 45) 60 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_cc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import CCHead 6 | from .utils import to_cuda 7 | 8 | 9 | def test_cc_head(): 10 | head = CCHead(in_channels=16, channels=8, num_classes=19) 11 | assert len(head.convs) == 2 12 | assert hasattr(head, 'cca') 13 | if not torch.cuda.is_available(): 14 | pytest.skip('CCHead requires CUDA') 15 | inputs = [torch.randn(1, 16, 23, 23)] 16 | head, inputs = to_cuda(head, inputs) 17 | outputs = head(inputs) 18 | assert outputs.shape == (1, head.num_classes, 23, 23) 19 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_dm_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import DMHead 6 | from .utils import _conv_has_norm, to_cuda 7 | 8 | 9 | def test_dm_head(): 10 | 11 | with pytest.raises(AssertionError): 12 | # filter_sizes must be list|tuple 13 | DMHead(in_channels=8, channels=4, num_classes=19, filter_sizes=1) 14 | 15 | # test no norm_cfg 16 | head = DMHead(in_channels=8, channels=4, num_classes=19) 17 | assert not _conv_has_norm(head, sync_bn=False) 18 | 19 | # test with norm_cfg 20 | head = DMHead( 21 | in_channels=8, 22 | channels=4, 23 | num_classes=19, 24 | norm_cfg=dict(type='SyncBN')) 25 | assert _conv_has_norm(head, sync_bn=True) 26 | 27 | # fusion=True 28 | inputs = [torch.randn(1, 8, 23, 23)] 29 | head = DMHead( 30 | in_channels=8, 31 | channels=4, 32 | num_classes=19, 33 | filter_sizes=(1, 3, 5), 34 | fusion=True) 35 | if torch.cuda.is_available(): 36 | head, inputs = to_cuda(head, inputs) 37 | assert head.fusion is True 38 | assert head.dcm_modules[0].filter_size == 1 39 | assert head.dcm_modules[1].filter_size == 3 40 | assert head.dcm_modules[2].filter_size == 5 41 | outputs = head(inputs) 42 | assert outputs.shape == (1, head.num_classes, 23, 23) 43 | 44 | # fusion=False 45 | inputs = [torch.randn(1, 8, 23, 23)] 46 | head = DMHead( 47 | in_channels=8, 48 | channels=4, 49 | num_classes=19, 50 | filter_sizes=(1, 3, 5), 51 | fusion=False) 52 | if torch.cuda.is_available(): 53 | head, inputs = to_cuda(head, inputs) 54 | assert head.fusion is False 55 | assert head.dcm_modules[0].filter_size == 1 56 | assert head.dcm_modules[1].filter_size == 3 57 | assert head.dcm_modules[2].filter_size == 5 58 | outputs = head(inputs) 59 | assert outputs.shape == (1, head.num_classes, 23, 23) 60 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_dnl_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | 4 | from mmseg.models.decode_heads import DNLHead 5 | from .utils import to_cuda 6 | 7 | 8 | def test_dnl_head(): 9 | # DNL with 'embedded_gaussian' mode 10 | head = DNLHead(in_channels=8, channels=4, num_classes=19) 11 | assert len(head.convs) == 2 12 | assert hasattr(head, 'dnl_block') 13 | assert head.dnl_block.temperature == 0.05 14 | inputs = [torch.randn(1, 8, 23, 23)] 15 | if torch.cuda.is_available(): 16 | head, inputs = to_cuda(head, inputs) 17 | outputs = head(inputs) 18 | assert outputs.shape == (1, head.num_classes, 23, 23) 19 | 20 | # NonLocal2d with 'dot_product' mode 21 | head = DNLHead( 22 | in_channels=8, channels=4, num_classes=19, mode='dot_product') 23 | inputs = [torch.randn(1, 8, 23, 23)] 24 | if torch.cuda.is_available(): 25 | head, inputs = to_cuda(head, inputs) 26 | outputs = head(inputs) 27 | assert outputs.shape == (1, head.num_classes, 23, 23) 28 | 29 | # NonLocal2d with 'gaussian' mode 30 | head = DNLHead(in_channels=8, channels=4, num_classes=19, mode='gaussian') 31 | inputs = [torch.randn(1, 8, 23, 23)] 32 | if torch.cuda.is_available(): 33 | head, inputs = to_cuda(head, inputs) 34 | outputs = head(inputs) 35 | assert outputs.shape == (1, head.num_classes, 23, 23) 36 | 37 | # NonLocal2d with 'concatenation' mode 38 | head = DNLHead( 39 | in_channels=8, channels=4, num_classes=19, mode='concatenation') 40 | inputs = [torch.randn(1, 8, 23, 23)] 41 | if torch.cuda.is_available(): 42 | head, inputs = to_cuda(head, inputs) 43 | outputs = head(inputs) 44 | assert outputs.shape == (1, head.num_classes, 23, 23) 45 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_dpt_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import DPTHead 6 | 7 | 8 | def test_dpt_head(): 9 | 10 | with pytest.raises(AssertionError): 11 | # input_transform must be 'multiple_select' 12 | head = DPTHead( 13 | in_channels=[768, 768, 768, 768], 14 | channels=4, 15 | num_classes=19, 16 | in_index=[0, 1, 2, 3]) 17 | 18 | head = DPTHead( 19 | in_channels=[768, 768, 768, 768], 20 | channels=4, 21 | num_classes=19, 22 | in_index=[0, 1, 2, 3], 23 | input_transform='multiple_select') 24 | 25 | inputs = [[torch.randn(4, 768, 2, 2), 26 | torch.randn(4, 768)] for _ in range(4)] 27 | output = head(inputs) 28 | assert output.shape == torch.Size((4, 19, 16, 16)) 29 | 30 | # test readout operation 31 | head = DPTHead( 32 | in_channels=[768, 768, 768, 768], 33 | channels=4, 34 | num_classes=19, 35 | in_index=[0, 1, 2, 3], 36 | input_transform='multiple_select', 37 | readout_type='add') 38 | output = head(inputs) 39 | assert output.shape == torch.Size((4, 19, 16, 16)) 40 | 41 | head = DPTHead( 42 | in_channels=[768, 768, 768, 768], 43 | channels=4, 44 | num_classes=19, 45 | in_index=[0, 1, 2, 3], 46 | input_transform='multiple_select', 47 | readout_type='project') 48 | output = head(inputs) 49 | assert output.shape == torch.Size((4, 19, 16, 16)) 50 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_ema_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | 4 | from mmseg.models.decode_heads import EMAHead 5 | from .utils import to_cuda 6 | 7 | 8 | def test_emanet_head(): 9 | head = EMAHead( 10 | in_channels=4, 11 | ema_channels=3, 12 | channels=2, 13 | num_stages=3, 14 | num_bases=2, 15 | num_classes=19) 16 | for param in head.ema_mid_conv.parameters(): 17 | assert not param.requires_grad 18 | assert hasattr(head, 'ema_module') 19 | inputs = [torch.randn(1, 4, 23, 23)] 20 | if torch.cuda.is_available(): 21 | head, inputs = to_cuda(head, inputs) 22 | outputs = head(inputs) 23 | assert outputs.shape == (1, head.num_classes, 23, 23) 24 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_gc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmseg.models.decode_heads import GCHead 5 | from .utils import to_cuda 6 | 7 | 8 | def test_gc_head(): 9 | head = GCHead(in_channels=4, channels=4, num_classes=19) 10 | assert len(head.convs) == 2 11 | assert hasattr(head, 'gc_block') 12 | inputs = [torch.randn(1, 4, 23, 23)] 13 | if torch.cuda.is_available(): 14 | head, inputs = to_cuda(head, inputs) 15 | outputs = head(inputs) 16 | assert outputs.shape == (1, head.num_classes, 23, 23) 17 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_ham_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmseg.models.decode_heads import LightHamHead 5 | from .utils import _conv_has_norm, to_cuda 6 | 7 | ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) 8 | 9 | 10 | def test_ham_head(): 11 | 12 | # test without sync_bn 13 | head = LightHamHead( 14 | in_channels=[16, 32, 64], 15 | in_index=[1, 2, 3], 16 | channels=64, 17 | ham_channels=64, 18 | dropout_ratio=0.1, 19 | num_classes=19, 20 | norm_cfg=ham_norm_cfg, 21 | align_corners=False, 22 | loss_decode=dict( 23 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 24 | ham_kwargs=dict( 25 | MD_S=1, 26 | MD_R=64, 27 | train_steps=6, 28 | eval_steps=7, 29 | inv_t=100, 30 | rand_init=True)) 31 | assert not _conv_has_norm(head, sync_bn=False) 32 | 33 | inputs = [ 34 | torch.randn(1, 8, 32, 32), 35 | torch.randn(1, 16, 16, 16), 36 | torch.randn(1, 32, 8, 8), 37 | torch.randn(1, 64, 4, 4) 38 | ] 39 | if torch.cuda.is_available(): 40 | head, inputs = to_cuda(head, inputs) 41 | assert head.in_channels == [16, 32, 64] 42 | assert head.hamburger.ham_in.in_channels == 64 43 | outputs = head(inputs) 44 | assert outputs.shape == (1, head.num_classes, 16, 16) 45 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_isa_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | 4 | from mmseg.models.decode_heads import ISAHead 5 | from .utils import to_cuda 6 | 7 | 8 | def test_isa_head(): 9 | 10 | inputs = [torch.randn(1, 8, 23, 23)] 11 | isa_head = ISAHead( 12 | in_channels=8, 13 | channels=4, 14 | num_classes=19, 15 | isa_channels=4, 16 | down_factor=(8, 8)) 17 | if torch.cuda.is_available(): 18 | isa_head, inputs = to_cuda(isa_head, inputs) 19 | output = isa_head(inputs) 20 | assert output.shape == (1, isa_head.num_classes, 23, 23) 21 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_maskformer_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from os.path import dirname, join 3 | 4 | import torch 5 | from mmengine import Config 6 | from mmengine.registry import init_default_scope 7 | from mmengine.structures import PixelData 8 | 9 | from mmseg.registry import MODELS 10 | from mmseg.structures import SegDataSample 11 | 12 | 13 | def test_maskformer_head(): 14 | init_default_scope('mmseg') 15 | repo_dpath = dirname(dirname(__file__)) 16 | cfg = Config.fromfile( 17 | join( 18 | repo_dpath, 19 | '../../configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py' # noqa 20 | )) 21 | cfg.model.train_cfg = None 22 | decode_head = MODELS.build(cfg.model.decode_head) 23 | inputs = (torch.randn(1, 256, 32, 32), torch.randn(1, 512, 16, 16), 24 | torch.randn(1, 1024, 8, 8), torch.randn(1, 2048, 4, 4)) 25 | # test inference 26 | batch_img_metas = [ 27 | dict( 28 | scale_factor=(1.0, 1.0), 29 | img_shape=(512, 683), 30 | ori_shape=(512, 683)) 31 | ] 32 | test_cfg = dict(mode='whole') 33 | output = decode_head.predict(inputs, batch_img_metas, test_cfg) 34 | assert output.shape == (1, 150, 512, 683) 35 | 36 | # test training 37 | inputs = (torch.randn(2, 256, 32, 32), torch.randn(2, 512, 16, 16), 38 | torch.randn(2, 1024, 8, 8), torch.randn(2, 2048, 4, 4)) 39 | batch_data_samples = [] 40 | img_meta = { 41 | 'img_shape': (512, 512), 42 | 'ori_shape': (480, 640), 43 | 'pad_shape': (512, 512), 44 | 'scale_factor': (1.425, 1.425), 45 | } 46 | for _ in range(2): 47 | data_sample = SegDataSample( 48 | gt_sem_seg=PixelData(data=torch.ones(512, 512).long())) 49 | data_sample.set_metainfo(img_meta) 50 | batch_data_samples.append(data_sample) 51 | train_cfg = {} 52 | losses = decode_head.loss(inputs, batch_data_samples, train_cfg) 53 | assert all(loss in losses 54 | for loss in ('loss_cls', 'loss_mask', 'loss_dice')) 55 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_nl_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch 3 | 4 | from mmseg.models.decode_heads import NLHead 5 | from .utils import to_cuda 6 | 7 | 8 | def test_nl_head(): 9 | head = NLHead(in_channels=8, channels=4, num_classes=19) 10 | assert len(head.convs) == 2 11 | assert hasattr(head, 'nl_block') 12 | inputs = [torch.randn(1, 8, 23, 23)] 13 | if torch.cuda.is_available(): 14 | head, inputs = to_cuda(head, inputs) 15 | outputs = head(inputs) 16 | assert outputs.shape == (1, head.num_classes, 23, 23) 17 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_ocr_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmseg.models.decode_heads import FCNHead, OCRHead 5 | from .utils import to_cuda 6 | 7 | 8 | def test_ocr_head(): 9 | 10 | inputs = [torch.randn(1, 8, 23, 23)] 11 | ocr_head = OCRHead( 12 | in_channels=8, channels=4, num_classes=19, ocr_channels=8) 13 | fcn_head = FCNHead(in_channels=8, channels=4, num_classes=19) 14 | if torch.cuda.is_available(): 15 | head, inputs = to_cuda(ocr_head, inputs) 16 | head, inputs = to_cuda(fcn_head, inputs) 17 | prev_output = fcn_head(inputs) 18 | output = ocr_head(inputs, prev_output) 19 | assert output.shape == (1, ocr_head.num_classes, 23, 23) 20 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_psp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import PSPHead 6 | from .utils import _conv_has_norm, to_cuda 7 | 8 | 9 | def test_psp_head(): 10 | 11 | with pytest.raises(AssertionError): 12 | # pool_scales must be list|tuple 13 | PSPHead(in_channels=4, channels=2, num_classes=19, pool_scales=1) 14 | 15 | # test no norm_cfg 16 | head = PSPHead(in_channels=4, channels=2, num_classes=19) 17 | assert not _conv_has_norm(head, sync_bn=False) 18 | 19 | # test with norm_cfg 20 | head = PSPHead( 21 | in_channels=4, 22 | channels=2, 23 | num_classes=19, 24 | norm_cfg=dict(type='SyncBN')) 25 | assert _conv_has_norm(head, sync_bn=True) 26 | 27 | inputs = [torch.randn(1, 4, 23, 23)] 28 | head = PSPHead( 29 | in_channels=4, channels=2, num_classes=19, pool_scales=(1, 2, 3)) 30 | if torch.cuda.is_available(): 31 | head, inputs = to_cuda(head, inputs) 32 | assert head.psp_modules[0][0].output_size == 1 33 | assert head.psp_modules[1][0].output_size == 2 34 | assert head.psp_modules[2][0].output_size == 3 35 | outputs = head(inputs) 36 | assert outputs.shape == (1, head.num_classes, 23, 23) 37 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_segformer_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import SegformerHead 6 | 7 | 8 | def test_segformer_head(): 9 | with pytest.raises(AssertionError): 10 | # `in_channels` must have same length as `in_index` 11 | SegformerHead( 12 | in_channels=(1, 2, 3), in_index=(0, 1), channels=5, num_classes=2) 13 | 14 | H, W = (64, 64) 15 | in_channels = (32, 64, 160, 256) 16 | shapes = [(H // 2**(i + 2), W // 2**(i + 2)) 17 | for i in range(len(in_channels))] 18 | model = SegformerHead( 19 | in_channels=in_channels, 20 | in_index=[0, 1, 2, 3], 21 | channels=256, 22 | num_classes=19) 23 | 24 | with pytest.raises(IndexError): 25 | # in_index must match the input feature maps. 26 | inputs = [ 27 | torch.randn((1, in_channel, *shape)) 28 | for in_channel, shape in zip(in_channels, shapes) 29 | ][:3] 30 | temp = model(inputs) 31 | 32 | # Normal Input 33 | # ((1, 32, 16, 16), (1, 64, 8, 8), (1, 160, 4, 4), (1, 256, 2, 2)) 34 | inputs = [ 35 | torch.randn((1, in_channel, *shape)) 36 | for in_channel, shape in zip(in_channels, shapes) 37 | ] 38 | temp = model(inputs) 39 | 40 | assert temp.shape == (1, 19, H // 4, W // 4) 41 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_segmenter_mask_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmseg.models.decode_heads import SegmenterMaskTransformerHead 5 | from .utils import _conv_has_norm, to_cuda 6 | 7 | 8 | def test_segmenter_mask_transformer_head(): 9 | head = SegmenterMaskTransformerHead( 10 | in_channels=2, 11 | channels=2, 12 | num_classes=150, 13 | num_layers=2, 14 | num_heads=3, 15 | embed_dims=192, 16 | dropout_ratio=0.0) 17 | assert _conv_has_norm(head, sync_bn=True) 18 | head.init_weights() 19 | 20 | inputs = [torch.randn(1, 2, 32, 32)] 21 | if torch.cuda.is_available(): 22 | head, inputs = to_cuda(head, inputs) 23 | outputs = head(inputs) 24 | assert outputs.shape == (1, head.num_classes, 32, 32) 25 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_setr_mla_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import SETRMLAHead 6 | from .utils import to_cuda 7 | 8 | 9 | def test_setr_mla_head(capsys): 10 | 11 | with pytest.raises(AssertionError): 12 | # MLA requires multi-stage features as input. 13 | SETRMLAHead(in_channels=8, channels=4, num_classes=19, in_index=1) 14 | 15 | with pytest.raises(AssertionError): 16 | # multiple in_index values require multiple in_channels.
17 | SETRMLAHead( 18 | in_channels=8, channels=4, num_classes=19, in_index=(0, 1, 2, 3)) 19 | 20 | with pytest.raises(AssertionError): 21 | # channels should be len(in_channels) * mla_channels 22 | SETRMLAHead( 23 | in_channels=(8, 8, 8, 8), 24 | channels=8, 25 | mla_channels=4, 26 | in_index=(0, 1, 2, 3), 27 | num_classes=19) 28 | 29 | # test inference of MLA head 30 | img_size = (8, 8) 31 | patch_size = 4 32 | head = SETRMLAHead( 33 | in_channels=(8, 8, 8, 8), 34 | channels=16, 35 | mla_channels=4, 36 | in_index=(0, 1, 2, 3), 37 | num_classes=19, 38 | norm_cfg=dict(type='BN')) 39 | 40 | h, w = img_size[0] // patch_size, img_size[1] // patch_size 41 | # Input square NCHW format feature information 42 | x = [ 43 | torch.randn(1, 8, h, w), 44 | torch.randn(1, 8, h, w), 45 | torch.randn(1, 8, h, w), 46 | torch.randn(1, 8, h, w) 47 | ] 48 | if torch.cuda.is_available(): 49 | head, x = to_cuda(head, x) 50 | out = head(x) 51 | assert out.shape == (1, head.num_classes, h * 4, w * 4) 52 | 53 | # Input non-square NCHW format feature information 54 | x = [ 55 | torch.randn(1, 8, h, w * 2), 56 | torch.randn(1, 8, h, w * 2), 57 | torch.randn(1, 8, h, w * 2), 58 | torch.randn(1, 8, h, w * 2) 59 | ] 60 | if torch.cuda.is_available(): 61 | head, x = to_cuda(head, x) 62 | out = head(x) 63 | assert out.shape == (1, head.num_classes, h * 4, w * 8) 64 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_setr_up_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import SETRUPHead 6 | from .utils import to_cuda 7 | 8 | 9 | def test_setr_up_head(capsys): 10 | 11 | with pytest.raises(AssertionError): 12 | # kernel_size must be [1/3] 13 | SETRUPHead(num_classes=19, kernel_size=2) 14 | 15 | with pytest.raises(AssertionError): 16 | # in_channels must be int type and in_channels must be same 17 | # as embed_dim. 18 | SETRUPHead(in_channels=(4, 4), channels=2, num_classes=19) 19 | 20 | # test init_cfg of head 21 | head = SETRUPHead( 22 | in_channels=4, 23 | channels=2, 24 | norm_cfg=dict(type='SyncBN'), 25 | num_classes=19, 26 | init_cfg=dict(type='Kaiming')) 27 | super(SETRUPHead, head).init_weights() 28 | 29 | # test inference of Naive head 30 | # the auxiliary head of Naive head is same as Naive head 31 | img_size = (4, 4) 32 | patch_size = 2 33 | head = SETRUPHead( 34 | in_channels=4, 35 | channels=2, 36 | num_classes=19, 37 | num_convs=1, 38 | up_scale=4, 39 | kernel_size=1, 40 | norm_cfg=dict(type='BN')) 41 | 42 | h, w = img_size[0] // patch_size, img_size[1] // patch_size 43 | 44 | # Input square NCHW format feature information 45 | x = [torch.randn(1, 4, h, w)] 46 | if torch.cuda.is_available(): 47 | head, x = to_cuda(head, x) 48 | out = head(x) 49 | assert out.shape == (1, head.num_classes, h * 4, w * 4) 50 | 51 | # Input non-square NCHW format feature information 52 | x = [torch.randn(1, 4, h, w * 2)] 53 | if torch.cuda.is_available(): 54 | head, x = to_cuda(head, x) 55 | out = head(x) 56 | assert out.shape == (1, head.num_classes, h * 4, w * 8) 57 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/test_uper_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.decode_heads import UPerHead 6 | from .utils import _conv_has_norm, to_cuda 7 | 8 | 9 | def test_uper_head(): 10 | 11 | with pytest.raises(AssertionError): 12 | # fpn_in_channels must be list|tuple 13 | UPerHead(in_channels=4, channels=2, num_classes=19) 14 | 15 | # test no norm_cfg 16 | head = UPerHead( 17 | in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1]) 18 | assert not _conv_has_norm(head, sync_bn=False) 19 | 20 | # test with norm_cfg 21 | head = UPerHead( 22 | in_channels=[4, 2], 23 | channels=2, 24 | num_classes=19, 25 | norm_cfg=dict(type='SyncBN'), 26 | in_index=[-2, -1]) 27 | assert _conv_has_norm(head, sync_bn=True) 28 | 29 | inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 2, 21, 21)] 30 | head = UPerHead( 31 | in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1]) 32 | if torch.cuda.is_available(): 33 | head, inputs = to_cuda(head, inputs) 34 | outputs = head(inputs) 35 | assert outputs.shape == (1, head.num_classes, 45, 45) 36 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_heads/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import ConvModule 3 | from mmengine.utils.dl_utils.parrots_wrapper import SyncBatchNorm 4 | 5 | 6 | def _conv_has_norm(module, sync_bn): 7 | for m in module.modules(): 8 | if isinstance(m, ConvModule): 9 | if not m.with_norm: 10 | return False 11 | if sync_bn: 12 | if not isinstance(m.bn, SyncBatchNorm): 13 | return False 14 | return True 15 | 16 | 17 | def to_cuda(module, data): 18 | module = module.cuda() 19 | if isinstance(data, list): 20 | for i in range(len(data)): 21 | data[i] = data[i].cuda() 22 | return module, data 23 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_losses/test_huasdorff_distance_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.losses import HuasdorffDisstanceLoss 6 | 7 | 8 | def test_huasdorff_distance_loss(): 9 | loss_class = HuasdorffDisstanceLoss 10 | pred = torch.rand((10, 8, 6, 6)) 11 | target = torch.rand((10, 6, 6)) 12 | class_weight = torch.rand(8) 13 | 14 | # Test loss forward 15 | loss = loss_class()(pred, target) 16 | assert isinstance(loss, torch.Tensor) 17 | 18 | # Test loss forward with avg_factor 19 | loss = loss_class()(pred, target, avg_factor=10) 20 | assert isinstance(loss, torch.Tensor) 21 | 22 | # Test loss forward with avg_factor and reduction is None, 'sum' and 'mean' 23 | for reduction in [None, 'sum', 'mean']: 24 | loss = loss_class()(pred, target, avg_factor=10, reduction=reduction) 25 | assert isinstance(loss, torch.Tensor) 26 | 27 | # Test loss forward with class_weight 28 | with pytest.raises(AssertionError): 29 | loss_class(class_weight=class_weight)(pred, target) 30 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_necks/test_feature2pyramid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models import Feature2Pyramid 6 | 7 | 8 | def test_feature2pyramid(): 9 | # test 10 | rescales = [4, 2, 1, 0.5] 11 | embed_dim = 64 12 | inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))] 13 | 14 | fpn = Feature2Pyramid( 15 | embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True)) 16 | outputs = fpn(inputs) 17 | assert outputs[0].shape == torch.Size([1, 64, 128, 128]) 18 | assert outputs[1].shape == torch.Size([1, 64, 64, 64]) 19 | assert outputs[2].shape == torch.Size([1, 64, 32, 32]) 20 | assert outputs[3].shape == torch.Size([1, 64, 16, 16]) 21 | 22 | # test rescales = [2, 1, 0.5, 0.25] 23 | rescales = [2, 1, 0.5, 0.25] 24 | inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))] 25 | 26 | fpn = Feature2Pyramid( 27 | embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True)) 28 | outputs = fpn(inputs) 29 | assert outputs[0].shape == torch.Size([1, 64, 64, 64]) 30 | assert outputs[1].shape == torch.Size([1, 64, 32, 32]) 31 | assert outputs[2].shape == torch.Size([1, 64, 16, 16]) 32 | assert outputs[3].shape == torch.Size([1, 64, 8, 8]) 33 | 34 | # test rescales = [4, 2, 0.25, 0] 35 | rescales = [4, 2, 0.25, 0] 36 | with pytest.raises(KeyError): 37 | fpn = Feature2Pyramid( 38 | embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True)) 39 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_necks/test_fpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmseg.models import FPN 5 | 6 | 7 | def test_fpn(): 8 | in_channels = [64, 128, 256, 512] 9 | inputs = [ 10 | torch.randn(1, c, 56 // 2**i, 56 // 2**i) 11 | for i, c in enumerate(in_channels) 12 | ] 13 | 14 | fpn = FPN(in_channels, 64, len(in_channels)) 15 | outputs = fpn(inputs) 16 | assert outputs[0].shape == torch.Size([1, 64, 56, 56]) 17 | assert outputs[1].shape == torch.Size([1, 64, 28, 28]) 18 | assert outputs[2].shape == torch.Size([1, 64, 14, 14]) 19 | assert outputs[3].shape == torch.Size([1, 64, 7, 7]) 20 | 21 | fpn = FPN( 22 | in_channels, 23 | 64, 24 | len(in_channels), 25 | upsample_cfg=dict(mode='nearest', scale_factor=2.0)) 26 | outputs = fpn(inputs) 27 | assert outputs[0].shape == torch.Size([1, 64, 56, 56]) 28 | assert outputs[1].shape == torch.Size([1, 64, 28, 28]) 29 | assert outputs[2].shape == torch.Size([1, 64, 14, 14]) 30 | assert outputs[3].shape == torch.Size([1, 64, 7, 7]) 31 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_necks/test_ic_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.necks import ICNeck 6 | from mmseg.models.necks.ic_neck import CascadeFeatureFusion 7 | from ..test_heads.utils import _conv_has_norm, to_cuda 8 | 9 | 10 | def test_ic_neck(): 11 | # test with norm_cfg 12 | neck = ICNeck( 13 | in_channels=(4, 16, 16), 14 | out_channels=8, 15 | norm_cfg=dict(type='SyncBN'), 16 | align_corners=False) 17 | assert _conv_has_norm(neck, sync_bn=True) 18 | 19 | inputs = [ 20 | torch.randn(1, 4, 32, 64), 21 | torch.randn(1, 16, 16, 32), 22 | torch.randn(1, 16, 8, 16) 23 | ] 24 | neck = ICNeck( 25 | in_channels=(4, 16, 16), 26 | out_channels=4, 27 | norm_cfg=dict(type='BN', requires_grad=True), 28 | align_corners=False) 29 | if torch.cuda.is_available(): 30 | neck, inputs = to_cuda(neck, inputs) 31 | 32 | outputs = neck(inputs) 33 | assert outputs[0].shape == (1, 4, 16, 32) 34 | assert outputs[1].shape == (1, 4, 32, 64) 35 | assert outputs[2].shape == (1, 4, 32, 64) 36 | 37 | 38 | def test_ic_neck_cascade_feature_fusion(): 39 | cff = CascadeFeatureFusion(64, 64, 32) 40 | assert cff.conv_low.in_channels == 64 41 | assert cff.conv_low.out_channels == 32 42 | assert cff.conv_high.in_channels == 64 43 | assert cff.conv_high.out_channels == 32 44 | 45 | 46 | def test_ic_neck_input_channels(): 47 | with pytest.raises(AssertionError): 48 | # ICNet Neck input channel constraints. 49 | ICNeck( 50 | in_channels=(16, 64, 64, 64), 51 | out_channels=32, 52 | norm_cfg=dict(type='BN', requires_grad=True), 53 | align_corners=False) 54 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_necks/test_jpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmseg.models.necks import JPU 6 | 7 | 8 | def test_fastfcn_neck(): 9 | # Test FastFCN Standard Forward 10 | model = JPU( 11 | in_channels=(64, 128, 256), 12 | mid_channels=64, 13 | start_level=0, 14 | end_level=-1, 15 | dilations=(1, 2, 4, 8), 16 | ) 17 | model.init_weights() 18 | model.train() 19 | batch_size = 1 20 | input = [ 21 | torch.randn(batch_size, 64, 64, 128), 22 | torch.randn(batch_size, 128, 32, 64), 23 | torch.randn(batch_size, 256, 16, 32) 24 | ] 25 | feat = model(input) 26 | 27 | assert len(feat) == 3 28 | assert feat[0].shape == torch.Size([batch_size, 64, 64, 128]) 29 | assert feat[1].shape == torch.Size([batch_size, 128, 32, 64]) 30 | assert feat[2].shape == torch.Size([batch_size, 256, 64, 128]) 31 | 32 | with pytest.raises(AssertionError): 33 | # FastFCN input and in_channels constraints. 34 | JPU(in_channels=(256, 64, 128), start_level=0, end_level=5) 35 | 36 | # Test not default start_level 37 | model = JPU(in_channels=(64, 128, 256), start_level=1, end_level=-1) 38 | input = [ 39 | torch.randn(batch_size, 64, 64, 128), 40 | torch.randn(batch_size, 128, 32, 64), 41 | torch.randn(batch_size, 256, 16, 32) 42 | ] 43 | feat = model(input) 44 | assert len(feat) == 2 45 | assert feat[0].shape == torch.Size([batch_size, 128, 32, 64]) 46 | assert feat[1].shape == torch.Size([batch_size, 2048, 32, 64]) 47 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_necks/test_mla_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch 3 | 4 | from mmseg.models import MLANeck 5 | 6 | 7 | def test_mla(): 8 | in_channels = [4, 4, 4, 4] 9 | mla = MLANeck(in_channels, 32) 10 | 11 | inputs = [torch.randn(1, c, 12, 12) for i, c in enumerate(in_channels)] 12 | outputs = mla(inputs) 13 | assert outputs[0].shape == torch.Size([1, 32, 12, 12]) 14 | assert outputs[1].shape == torch.Size([1, 32, 12, 12]) 15 | assert outputs[2].shape == torch.Size([1, 32, 12, 12]) 16 | assert outputs[3].shape == torch.Size([1, 32, 12, 12]) 17 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_necks/test_multilevel_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmseg.models import MultiLevelNeck 5 | 6 | 7 | def test_multilevel_neck(): 8 | 9 | # Test init_weights 10 | MultiLevelNeck([266], 32).init_weights() 11 | 12 | # Test multi feature maps 13 | in_channels = [32, 64, 128, 256] 14 | inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)] 15 | 16 | neck = MultiLevelNeck(in_channels, 32) 17 | outputs = neck(inputs) 18 | assert outputs[0].shape == torch.Size([1, 32, 7, 7]) 19 | assert outputs[1].shape == torch.Size([1, 32, 14, 14]) 20 | assert outputs[2].shape == torch.Size([1, 32, 28, 28]) 21 | assert outputs[3].shape == torch.Size([1, 32, 56, 56]) 22 | 23 | # Test one feature map 24 | in_channels = [768] 25 | inputs = [torch.randn(1, 768, 14, 14)] 26 | 27 | neck = MultiLevelNeck(in_channels, 32) 28 | outputs = neck(inputs) 29 | assert outputs[0].shape == torch.Size([1, 32, 7, 7]) 30 | assert outputs[1].shape == torch.Size([1, 32, 14, 14]) 31 | assert outputs[2].shape == torch.Size([1, 32, 28, 28]) 32 | assert outputs[3].shape == torch.Size([1, 32, 56, 56]) 33 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmengine import ConfigDict 3 | 4 | from mmseg.models import build_segmentor 5 | from .utils import _segmentor_forward_train_test 6 | 7 | 8 | def test_cascade_encoder_decoder(): 9 | 10 | # test 1 decode head, w.o. 
aux head 11 | cfg = ConfigDict( 12 | type='CascadeEncoderDecoder', 13 | num_stages=2, 14 | backbone=dict(type='ExampleBackbone'), 15 | decode_head=[ 16 | dict(type='ExampleDecodeHead'), 17 | dict(type='ExampleCascadeDecodeHead') 18 | ]) 19 | cfg.test_cfg = ConfigDict(mode='whole') 20 | segmentor = build_segmentor(cfg) 21 | _segmentor_forward_train_test(segmentor) 22 | 23 | # test slide mode 24 | cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) 25 | segmentor = build_segmentor(cfg) 26 | _segmentor_forward_train_test(segmentor) 27 | 28 | # test 1 decode head, 1 aux head 29 | cfg = ConfigDict( 30 | type='CascadeEncoderDecoder', 31 | num_stages=2, 32 | backbone=dict(type='ExampleBackbone'), 33 | decode_head=[ 34 | dict(type='ExampleDecodeHead'), 35 | dict(type='ExampleCascadeDecodeHead') 36 | ], 37 | auxiliary_head=dict(type='ExampleDecodeHead')) 38 | cfg.test_cfg = ConfigDict(mode='whole') 39 | segmentor = build_segmentor(cfg) 40 | _segmentor_forward_train_test(segmentor) 41 | 42 | # test 1 decode head, 2 aux head 43 | cfg = ConfigDict( 44 | type='CascadeEncoderDecoder', 45 | num_stages=2, 46 | backbone=dict(type='ExampleBackbone'), 47 | decode_head=[ 48 | dict(type='ExampleDecodeHead'), 49 | dict(type='ExampleCascadeDecodeHead') 50 | ], 51 | auxiliary_head=[ 52 | dict(type='ExampleDecodeHead'), 53 | dict(type='ExampleDecodeHead') 54 | ]) 55 | cfg.test_cfg = ConfigDict(mode='whole') 56 | segmentor = build_segmentor(cfg) 57 | _segmentor_forward_train_test(segmentor) 58 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_segmentors/test_seg_tta_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import tempfile 3 | 4 | import torch 5 | from mmengine import ConfigDict 6 | from mmengine.model import BaseTTAModel 7 | from mmengine.registry import init_default_scope 8 | from mmengine.structures import PixelData 9 | 10 | from mmseg.registry import MODELS 11 | from mmseg.structures import SegDataSample 12 | from .utils import * # noqa: F401,F403 13 | 14 | init_default_scope('mmseg') 15 | 16 | 17 | def test_encoder_decoder_tta(): 18 | 19 | segmentor_cfg = ConfigDict( 20 | type='EncoderDecoder', 21 | backbone=dict(type='ExampleBackbone'), 22 | decode_head=dict(type='ExampleDecodeHead'), 23 | train_cfg=None, 24 | test_cfg=dict(mode='whole')) 25 | 26 | cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg) 27 | 28 | model: BaseTTAModel = MODELS.build(cfg) 29 | 30 | imgs = [] 31 | data_samples = [] 32 | directions = ['horizontal', 'vertical'] 33 | for i in range(12): 34 | flip_direction = directions[0] if i % 3 == 0 else directions[1] 35 | imgs.append(torch.randn(1, 3, 10 + i, 10 + i)) 36 | data_samples.append([ 37 | SegDataSample( 38 | metainfo=dict( 39 | ori_shape=(10, 10), 40 | img_shape=(10 + i, 10 + i), 41 | flip=(i % 2 == 0), 42 | flip_direction=flip_direction, 43 | img_path=tempfile.mktemp()), 44 | gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10)))) 45 | ]) 46 | 47 | model.test_step(dict(inputs=imgs, data_samples=data_samples)) 48 | 49 | # test out_channels == 1 50 | segmentor_cfg = ConfigDict( 51 | type='EncoderDecoder', 52 | backbone=dict(type='ExampleBackbone'), 53 | decode_head=dict( 54 | type='ExampleDecodeHead', 55 | num_classes=2, 56 | out_channels=1, 57 | threshold=0.4), 58 | train_cfg=None, 59 | test_cfg=dict(mode='whole')) 60 | model.module = MODELS.build(segmentor_cfg) 61 | for data_sample in data_samples: 62 | data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10)) 63 | model.test_step(dict(inputs=imgs, data_samples=data_samples)) 64 | -------------------------------------------------------------------------------- /Segmentation/tests/test_models/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /Segmentation/tests/test_utils/test_io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import os.path as osp 3 | 4 | import numpy as np 5 | import pytest 6 | from mmengine import FileClient 7 | 8 | from mmseg.utils import datafrombytes 9 | 10 | 11 | @pytest.mark.parametrize( 12 | ['backend', 'suffix'], 13 | [['nifti', '.nii.gz'], ['numpy', '.npy'], ['pickle', '.pkl']]) 14 | def test_datafrombytes(backend, suffix): 15 | 16 | file_client = FileClient('disk') 17 | file_path = osp.join(osp.dirname(__file__), '../data/biomedical' + suffix) 18 | bytes = file_client.get(file_path) 19 | data = datafrombytes(bytes, backend) 20 | 21 | if backend == 'pickle': 22 | # test pickle loading 23 | assert isinstance(data, dict) 24 | else: 25 | assert isinstance(data, np.ndarray) 26 | if backend == 'nifti': 27 | # test nifti file loading 28 | assert len(data.shape) == 3 29 | else: 30 | # test npy file loading 31 | # testing data biomedical.npy includes data and label 32 | assert len(data.shape) == 4 33 | assert data.shape[0] == 2 34 | -------------------------------------------------------------------------------- /Segmentation/tests/test_utils/test_set_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import datetime 3 | import sys 4 | from unittest import TestCase 5 | 6 | from mmengine import DefaultScope 7 | 8 | from mmseg.utils import register_all_modules 9 | 10 | 11 | class TestSetupEnv(TestCase): 12 | 13 | def test_register_all_modules(self): 14 | from mmseg.registry import DATASETS 15 | 16 | # not init default scope 17 | sys.modules.pop('mmseg.datasets', None) 18 | sys.modules.pop('mmseg.datasets.ade', None) 19 | DATASETS._module_dict.pop('ADE20KDataset', None) 20 | self.assertFalse('ADE20KDataset' in DATASETS.module_dict) 21 | register_all_modules(init_default_scope=False) 22 | self.assertTrue('ADE20KDataset' in DATASETS.module_dict) 23 | 24 | # init default scope 25 | sys.modules.pop('mmseg.datasets') 26 | sys.modules.pop('mmseg.datasets.ade') 27 | DATASETS._module_dict.pop('ADE20KDataset', None) 28 | self.assertFalse('ADE20KDataset' in DATASETS.module_dict) 29 | register_all_modules(init_default_scope=True) 30 | self.assertTrue('ADE20KDataset' in DATASETS.module_dict) 31 | self.assertEqual(DefaultScope.get_current_instance().scope_name, 32 | 'mmseg') 33 | 34 | # init default scope when another scope is init 35 | name = f'test-{datetime.datetime.now()}' 36 | DefaultScope.get_instance(name, scope_name='test') 37 | with self.assertWarnsRegex( 38 | Warning, 'The current default scope "test" is not "mmseg"'): 39 | register_all_modules(init_default_scope=True) 40 | -------------------------------------------------------------------------------- /Segmentation/tools/analysis_tools/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | COLOR_RED = "91m" 4 | COLOR_GREEN = "92m" 5 | COLOR_YELLOW = "93m" 6 | 7 | def colorful_print(fn_print, color=COLOR_RED): 8 | def actual_call(*args, **kwargs): 9 | print(f"\033[{color}", end="") 10 | fn_print(*args, **kwargs) 11 | print("\033[00m", end="") 12 | return actual_call 13 | 14 | prRed = colorful_print(print, color=COLOR_RED) 15 | prGreen = colorful_print(print, color=COLOR_GREEN) 16 | prYellow = colorful_print(print, color=COLOR_YELLOW) 17 | 18 | # def prRed(skk): 19 | # print("\033[91m{}\033[00m".format(skk)) 20 | 21 | # def prGreen(skk): 22 | # print("\033[92m{}\033[00m".format(skk)) 23 | 24 | # def prYellow(skk): 25 | # print("\033[93m{}\033[00m".format(skk)) 26 | 27 
| 28 | def clever_format(nums, format="%.2f"): 29 | if not isinstance(nums, Iterable): 30 | nums = [nums] 31 | clever_nums = [] 32 | 33 | for num in nums: 34 | if num > 1e12: 35 | clever_nums.append(format % (num / 1e12) + "T") 36 | elif num > 1e9: 37 | clever_nums.append(format % (num / 1e9) + "G") 38 | elif num > 1e6: 39 | clever_nums.append(format % (num / 1e6) + "M") 40 | elif num > 1e3: 41 | clever_nums.append(format % (num / 1e3) + "K") 42 | else: 43 | clever_nums.append(format % num + "B") 44 | 45 | clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,) 46 | 47 | return clever_nums 48 | 49 | 50 | if __name__ == "__main__": 51 | prRed("hello", "world") 52 | prGreen("hello", "world") 53 | prYellow("hello", "world") -------------------------------------------------------------------------------- /Segmentation/tools/dataset_converters/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | 5 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 6 | from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress, 7 | track_progress) 8 | 9 | 10 | def convert_json_to_label(json_file): 11 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') 12 | json2labelImg(json_file, label_file, 'trainIds') 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser( 17 | description='Convert Cityscapes annotations to TrainIds') 18 | parser.add_argument('cityscapes_path', help='cityscapes data path') 19 | parser.add_argument('--gt-dir', default='gtFine', type=str) 20 | parser.add_argument('-o', '--out-dir', help='output path') 21 | parser.add_argument( 22 | '--nproc', default=1, type=int, help='number of process') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def main(): 28 | args = parse_args() 29 | cityscapes_path = args.cityscapes_path 30 | out_dir = args.out_dir if args.out_dir else cityscapes_path 31 | mkdir_or_exist(out_dir) 32 | 33 | gt_dir = osp.join(cityscapes_path, args.gt_dir) 34 | 35 | poly_files = [] 36 | for poly in scandir(gt_dir, '_polygons.json', recursive=True): 37 | poly_file = osp.join(gt_dir, poly) 38 | poly_files.append(poly_file) 39 | if args.nproc > 1: 40 | track_parallel_progress(convert_json_to_label, poly_files, args.nproc) 41 | else: 42 | track_progress(convert_json_to_label, poly_files) 43 | 44 | split_names = ['train', 'val', 'test'] 45 | 46 | for split in split_names: 47 | filenames = [] 48 | for poly in scandir( 49 | osp.join(gt_dir, split), '_polygons.json', recursive=True): 50 | filenames.append(poly.replace('_gtFine_polygons.json', '')) 51 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: 52 | f.writelines(f + '\n' for f in filenames) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /Segmentation/tools/dataset_converters/prophesee/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/tools/dataset_converters/prophesee/__init__.py -------------------------------------------------------------------------------- /Segmentation/tools/dataset_converters/prophesee/io/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/tools/dataset_converters/prophesee/io/__init__.py -------------------------------------------------------------------------------- /Segmentation/tools/dataset_converters/prophesee/io/box_filtering.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Prophesee S.A. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed 7 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8 | # See the License for the specific language governing permissions and limitations under the License. 9 | 10 | """ 11 | Defines the same filtering that we apply in: 12 | "Learning to detect objects on a 1 Megapixel Event Camera" by Etienne Perot et al. 13 | 14 | Namely, we apply 2 different filters: 15 | 1. skip all boxes before 0.5s (before that, we assume it is unlikely that you have sufficient history) 16 | 2. filter all boxes whose squared diagonal is below min_box_diag**2 or whose side is below min_box_side 17 | """ 18 | 19 | from __future__ import print_function 20 | import numpy as np 21 | 22 | 23 | def filter_boxes(boxes, skip_ts=int(5e5), min_box_diag=60, min_box_side=20): 24 | """Filters boxes according to the paper rule. 25 | 26 | To note: the default represents our threshold when evaluating GEN4 resolution (1280x720) 27 | To note: we assume the initial time of the video is always 0 28 | 29 | Args: 30 | boxes (np.ndarray): structured box array with fields ['t','x','y','w','h','class_id','track_id','class_confidence'] 31 | (example BBOX_DTYPE is provided in src/box_loading.py) 32 | 33 | Returns: 34 | boxes: filtered boxes 35 | """ 36 | ts = boxes['t'] 37 | width = boxes['w'] 38 | height = boxes['h'] 39 | diag_square = width**2 + height**2 40 | mask = (ts > skip_ts) * (diag_square >= min_box_diag**2) * (width >= min_box_side) * (height >= min_box_side) 41 | return boxes[mask] 42 | 43 | -------------------------------------------------------------------------------- /Segmentation/tools/dataset_converters/prophesee/io/box_loading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Prophesee S.A. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed 7 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8 | # See the License for the specific language governing permissions and limitations under the License. 9 | 10 | """ 11 | Defines some tools to handle events.
12 | In particular: 13 | -> defines events' types 14 | -> defines functions to read events from binary .dat files using numpy 15 | -> defines functions to write events to binary .dat files using numpy 16 | """ 17 | 18 | from __future__ import print_function 19 | import numpy as np 20 | 21 | BBOX_DTYPE = np.dtype({'names': ['t', 'x', 'y', 'w', 'h', 'class_id', 'track_id', 'class_confidence'], 'formats': ['<i8', '<f4', '<f4', '<f4', '<f4', '<u4', '<u4', '<f4'], 'offsets': [0, 8, 12, 16, 20, 24, 28, 32], 'itemsize': 40}) -------------------------------------------------------------------------------- /Segmentation/tools/misc/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import subprocess 4 | from hashlib import sha256 5 | 6 | import torch 7 | 8 | BLOCK_SIZE = 128 * 1024 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description='Process a checkpoint to be published') 14 | parser.add_argument('in_file', help='input checkpoint filename') 15 | parser.add_argument('out_file', help='output checkpoint filename') 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | def sha256sum(filename: str) -> str: 21 | """Compute SHA256 message digest from a file.""" 22 | hash_func = sha256() 23 | byte_array = bytearray(BLOCK_SIZE) 24 | memory_view = memoryview(byte_array) 25 | with open(filename, 'rb', buffering=0) as file: 26 | for block in iter(lambda: file.readinto(memory_view), 0): 27 | hash_func.update(memory_view[:block]) 28 | return hash_func.hexdigest() 29 | 30 | 31 | def process_checkpoint(in_file, out_file): 32 | checkpoint = torch.load(in_file, map_location='cpu') 33 | # remove optimizer for smaller file size 34 | if 'optimizer' in checkpoint: 35 | del checkpoint['optimizer'] 36 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 37 | # add the code here. 38 | torch.save(checkpoint, out_file) 39 | sha = sha256sum(out_file) 40 | final_file = (out_file[:-4] if out_file.endswith('.pth') else out_file) + f'-{sha[:8]}.pth' 41 | subprocess.Popen(['mv', out_file, final_file]) 42 | 43 | 44 | def main(): 45 | args = parse_args() 46 | process_checkpoint(args.in_file, args.out_file) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /Segmentation/tools/model_converters/beit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmengine 7 | import torch 8 | from mmengine.runner import CheckpointLoader 9 | 10 | 11 | def convert_beit(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | for k, v in ckpt.items(): 15 | if k.startswith('patch_embed'): 16 | new_key = k.replace('patch_embed.proj', 'patch_embed.projection') 17 | new_ckpt[new_key] = v 18 | elif k.startswith('blocks'): 19 | new_key = k.replace('blocks', 'layers') 20 | if 'norm' in new_key: 21 | new_key = new_key.replace('norm', 'ln') 22 | elif 'mlp.fc1' in new_key: 23 | new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0') 24 | elif 'mlp.fc2' in new_key: 25 | new_key = new_key.replace('mlp.fc2', 'ffn.layers.1') 26 | new_ckpt[new_key] = v 27 | else: 28 | new_key = k 29 | new_ckpt[new_key] = v 30 | 31 | return new_ckpt 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser( 36 | description='Convert keys in official pretrained beit models to ' 37 | 'MMSegmentation style.') 38 | parser.add_argument('src', help='src model path or url') 39 | # The dst path must be a full path of the new checkpoint.
40 | parser.add_argument('dst', help='save path') 41 | args = parser.parse_args() 42 | 43 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 44 | if 'state_dict' in checkpoint: 45 | state_dict = checkpoint['state_dict'] 46 | elif 'model' in checkpoint: 47 | state_dict = checkpoint['model'] 48 | else: 49 | state_dict = checkpoint 50 | weight = convert_beit(state_dict) 51 | mmengine.mkdir_or_exist(osp.dirname(args.dst)) 52 | torch.save(weight, args.dst) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /Segmentation/tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-4} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /Segmentation/tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | GPUS=${GPUS:-4} 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | PY_ARGS=${@:4} 13 | 14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 15 | srun -p ${PARTITION} \ 16 | --job-name=${JOB_NAME} \ 17 | --gres=gpu:${GPUS_PER_NODE} \ 18 | --ntasks=${GPUS} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=${CPUS_PER_TASK} \ 21 | --kill-on-bad-exit=1 \ 22 | ${SRUN_ARGS} \ 23 | python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} 24 | -------------------------------------------------------------------------------- /Segmentation/tools/test.sh: -------------------------------------------------------------------------------- 1 | # CONFIG=../configs/Spikeformer/SDTv2_maskformer_DCNpixelDecoder_ade20k.py 2 | CONFIG=../configs/Spikeformer/SDTv2_maskformer_DCNPixelDecoder_CityScapes.py 3 | # CONFIG=../configs/Spikeformer/SDTv2_Spike2former_voc_512x512.py 4 | # CHECKPOINT='/public/liguoqi/lzx/code/mmseg/tools/work_dirs/SDTv2_maskformer_DCNpixelDecoder_ade20k/best_mIoU_iter_102500.pth' # ADE20k Train 46.3 Test 44.5 () 5 | # CHECKPOINT='/public/liguoqi/lzx/code/mmseg/tools/work_dirs/Ablation/v2_spike2former_voc2012_1x4/best_mIoU_iter_97500.pth' # VOC2012 Train 76.3 Test (Q_IFNode) 6 | # CHECKPOINT='/public/liguoqi/lzx/code/mmseg/tools/work_dirs/V2_Spike2former_withoutshortcut/best_mIoU_iter_152500.pth' # ADE20k 1x4 without shortcut 7 | CHECKPOINT='/public/liguoqi/lzx/code/mmseg/tools/work_dirs/SDTv2_maskformer_DCNPixelDecoder_CityScapes/best_mIoU_iter_47500.pth' # CityScapes 74.2 (Multi-Spikenorm) 8 | 9 | GPUS=1 10 | 11 | PORT=${PORT:-29500} 12 | 13 | 14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 15 | python -m torch.distributed.launch \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/test.py \ 19 | $CONFIG \ 20 | $CHECKPOINT \ 21 | 
--launcher pytorch \ 22 | ${@:4} \ 23 | --out ./work_dirs/vis_result/CityScapes \ 24 | --show-dir ./work_dirs/vis_result/CityScapes/1x4 --------------------------------------------------------------------------------
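A minimal usage sketch for the `clever_format` helper defined in Segmentation/tools/analysis_tools/utils.py above; the import path and the example numbers are assumptions for illustration, not values taken from the repository:

# Hypothetical usage of clever_format, run with Segmentation/ on PYTHONPATH.
from tools.analysis_tools.utils import clever_format

params, flops = 16_800_000, 3.4e9       # invented parameter/FLOP counts
print(clever_format([params, flops]))   # -> ('16.80M', '3.40G'), tuple for multiple inputs
print(clever_format(512))               # -> '512.00B', the bare-count branch below the 1e3 threshold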