├── Figure
├── Readme.md
└── img.png
├── README.md
└── Segmentation
├── .gitignore
├── 0.1.1
├── CITATION.cff
├── Figure
└── img.png
├── LICENSE
├── MANIFEST.in
├── Qtrick_architecture
├── __init__.py
└── clock_driven
│ ├── __init__.py
│ ├── ann2snn
│ ├── __init__.py
│ ├── converter.py
│ ├── examples
│ │ ├── __init__.py
│ │ ├── cnn_mnist.py
│ │ └── resnet18_cifar10.py
│ ├── modules.py
│ └── utils.py
│ ├── base.py
│ ├── configure.py
│ ├── cu_kernel_opt.py
│ ├── encoding.py
│ ├── examples
│ ├── A2C.py
│ ├── DQN_state.py
│ ├── E_Spikingformer.py
│ ├── PPO.py
│ ├── Spiking_A2C.py
│ ├── Spiking_DQN_state.py
│ ├── Spiking_PPO.py
│ ├── __init__.py
│ ├── cifar10_r11_enabling_spikebased_backpropagation.py
│ ├── classify_dvsg.py
│ ├── common
│ │ ├── __init__.py
│ │ └── multiprocessing_env.py
│ ├── conv_fashion_mnist.py
│ ├── dqn_cart_pole.py
│ ├── lif_fc_mnist.py
│ ├── rsnn_sequential_fmnist.py
│ ├── speechcommands.py
│ ├── spiking_lstm_sequential_mnist.py
│ └── spiking_lstm_text.py
│ ├── functional.py
│ ├── lava_exchange.py
│ ├── layer.py
│ ├── model
│ ├── __init__.py
│ ├── parametric_lif_net.py
│ ├── sew_resnet.py
│ ├── spiking_resnet.py
│ ├── spiking_vgg.py
│ ├── train_classify.py
│ └── train_imagenet.py
│ ├── monitor.py
│ ├── neuron.py
│ ├── neuron_kernel.py
│ ├── rnn.py
│ ├── spike_op.py
│ ├── surrogate.py
│ └── tensor_cache.py
├── README.md
├── configs
├── FPN
│ ├── fpn_sdtv2_512x512_ade20k.py
│ └── fpn_sdtv3_512x512_ade20k.py
├── Spike2Former
│ ├── SDTv2_Spike2former_voc_512x512.py
│ ├── SDTv2_maskformer_DCNPixelDecoder_CityScapes.py
│ ├── SDTv2_maskformer_DCNpixelDecoder_ade20k.py
│ ├── SDTv2_maskformer_cocostuff10k_512x512.py
│ ├── SDTv2_maskformer_cocostuff164k_512x512.py
│ ├── SDTv3_b_Spike2former_Cityscapes_512x1024.py
│ ├── SDTv3_b_Spike2former_ade20k_512x512.py
│ ├── SDTv3_b_Spike2former_voc_512x512.py
│ ├── fpn_sdtv3_512x512_10M_ade20k.py
│ └── fpn_sdtv3_512x512_19M_ade20k.py
└── _base_
│ ├── datasets
│ ├── ade20k.py
│ ├── ade20k_640x640.py
│ ├── chase_db1.py
│ ├── cityscapes.py
│ ├── cityscapes_1024x1024.py
│ ├── cityscapes_768x768.py
│ ├── cityscapes_769x769.py
│ ├── cityscapes_832x832.py
│ ├── coco-stuff10k.py
│ ├── coco-stuff164k.py
│ ├── ddd17.py
│ ├── pascal_context.py
│ ├── pascal_context_59.py
│ ├── pascal_voc12.py
│ ├── pascal_voc12_aug.py
│ └── synapse.py
│ ├── default_runtime.py
│ ├── models
│ ├── fpn_snn_r50.py
│ └── snn_sdtv2_fpn.py
│ └── schedules
│ ├── schedule_160k.py
│ ├── schedule_320k.py
│ └── schedule_80k.py
├── dataset-index.yml
├── docs
├── en
│ ├── Makefile
│ ├── _static
│ │ ├── css
│ │ │ └── readthedocs.css
│ │ └── images
│ │ │ └── mmsegmentation.png
│ ├── advanced_guides
│ │ ├── add_datasets.md
│ │ ├── add_metrics.md
│ │ ├── add_models.md
│ │ ├── add_transforms.md
│ │ ├── customize_runtime.md
│ │ ├── data_flow.md
│ │ ├── datasets.md
│ │ ├── engine.md
│ │ ├── evaluation.md
│ │ ├── index.rst
│ │ ├── models.md
│ │ ├── structures.md
│ │ ├── training_tricks.md
│ │ └── transforms.md
│ ├── api.rst
│ ├── conf.py
│ ├── device
│ │ └── npu.md
│ ├── get_started.md
│ ├── index.rst
│ ├── make.bat
│ ├── migration
│ │ ├── index.rst
│ │ ├── interface.md
│ │ └── package.md
│ ├── model_zoo.md
│ ├── modelzoo_statistics.md
│ ├── notes
│ │ ├── changelog.md
│ │ ├── changelog_v0.x.md
│ │ └── faq.md
│ ├── overview.md
│ ├── stat.py
│ ├── switch_language.md
│ └── user_guides
│ │ ├── 1_config.md
│ │ ├── 2_dataset_prepare.md
│ │ ├── 3_inference.md
│ │ ├── 4_train_test.md
│ │ ├── 5_deployment.md
│ │ ├── index.rst
│ │ ├── useful_tools.md
│ │ ├── visualization.md
│ │ └── visualization_feature_map.md
└── zh_cn
│ ├── Makefile
│ ├── _static
│ ├── css
│ │ └── readthedocs.css
│ └── images
│ │ └── mmsegmentation.png
│ ├── advanced_guides
│ ├── add_datasets.md
│ ├── add_metrics.md
│ ├── add_models.md
│ ├── add_transforms.md
│ ├── contribute_dataset.md
│ ├── customize_runtime.md
│ ├── data_flow.md
│ ├── datasets.md
│ ├── engine.md
│ ├── evaluation.md
│ ├── index.rst
│ ├── models.md
│ ├── structures.md
│ ├── training_tricks.md
│ └── transforms.md
│ ├── api.rst
│ ├── conf.py
│ ├── device
│ └── npu.md
│ ├── get_started.md
│ ├── imgs
│ ├── qq_group_qrcode.jpg
│ ├── seggroup_qrcode.jpg
│ └── zhihu_qrcode.jpg
│ ├── index.rst
│ ├── make.bat
│ ├── migration
│ ├── index.rst
│ ├── interface.md
│ └── package.md
│ ├── model_zoo.md
│ ├── modelzoo_statistics.md
│ ├── notes
│ └── faq.md
│ ├── overview.md
│ ├── stat.py
│ ├── switch_language.md
│ └── user_guides
│ ├── 1_config.md
│ ├── 2_dataset_prepare.md
│ ├── 3_inference.md
│ ├── 4_train_test.md
│ ├── 5_deployment.md
│ ├── index.rst
│ ├── useful_tools.md
│ ├── visualization.md
│ └── visualization_feature_map.md
├── mmdet
├── __init__.py
├── models
│ ├── __init__.py
│ ├── dense_heads
│ │ ├── __init__.py
│ │ ├── anchor_free_head.py
│ │ ├── base_dense_head.py
│ │ ├── base_mask_head.py
│ │ └── maskformer_head.py
│ ├── layers
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── brick_wrappers.py
│ │ ├── conv_upsample.py
│ │ ├── dropblock.py
│ │ ├── matrix_nms.py
│ │ ├── normed_predictor.py
│ │ ├── pixel_decoder.py
│ │ ├── positional_encoding.py
│ │ └── transformer
│ │ │ ├── Spike2former_layers.py
│ │ │ ├── __init__.py
│ │ │ ├── base_blocks.py
│ │ │ ├── dab_detr_layers.py
│ │ │ ├── deformable_detr_layers.py
│ │ │ ├── detr_layers.py
│ │ │ ├── mask2former_layers.py
│ │ │ ├── mmcv_spike
│ │ │ ├── BASE_Transformer.py
│ │ │ ├── CycleMLP.py
│ │ │ ├── SNN_core.py
│ │ │ ├── __init__.py
│ │ │ ├── base_blocks.py
│ │ │ ├── drop.py
│ │ │ ├── ext_loader.py
│ │ │ ├── multi_scale_deform_attn.py
│ │ │ ├── scale.py
│ │ │ ├── spikeformer.py
│ │ │ └── transformer.py
│ │ │ ├── ops_dcnv3
│ │ │ ├── __init__.py
│ │ │ ├── functions
│ │ │ │ ├── __init__.py
│ │ │ │ └── dcnv3_func.py
│ │ │ ├── make.sh
│ │ │ ├── modules
│ │ │ │ ├── __init__.py
│ │ │ │ └── dcnv3.py
│ │ │ ├── setup.py
│ │ │ ├── src
│ │ │ │ ├── cpu
│ │ │ │ │ ├── dcnv3_cpu.cpp
│ │ │ │ │ └── dcnv3_cpu.h
│ │ │ │ ├── cuda
│ │ │ │ │ ├── dcnv3_cuda.cu
│ │ │ │ │ ├── dcnv3_cuda.h
│ │ │ │ │ └── dcnv3_im2col_cuda.cuh
│ │ │ │ ├── dcnv3.h
│ │ │ │ └── vision.cpp
│ │ │ └── test.py
│ │ │ └── utils.py
│ ├── losses
│ │ ├── __init__.py
│ │ ├── accuracy.py
│ │ ├── cross_entropy_loss.py
│ │ ├── dice_loss.py
│ │ ├── focal_loss.py
│ │ ├── iou_loss.py
│ │ ├── mse_loss.py
│ │ ├── multipos_cross_entropy_loss.py
│ │ └── utils.py
│ ├── task_modules
│ │ ├── __init__.py
│ │ ├── assigners
│ │ │ ├── __init__.py
│ │ │ ├── approx_max_iou_assigner.py
│ │ │ ├── assign_result.py
│ │ │ ├── atss_assigner.py
│ │ │ ├── base_assigner.py
│ │ │ ├── center_region_assigner.py
│ │ │ ├── dynamic_soft_label_assigner.py
│ │ │ ├── grid_assigner.py
│ │ │ ├── hungarian_assigner.py
│ │ │ ├── iou2d_calculator.py
│ │ │ ├── match_cost.py
│ │ │ ├── max_iou_assigner.py
│ │ │ ├── multi_instance_assigner.py
│ │ │ ├── point_assigner.py
│ │ │ ├── region_assigner.py
│ │ │ ├── sim_ota_assigner.py
│ │ │ ├── task_aligned_assigner.py
│ │ │ └── uniform_assigner.py
│ │ ├── builder.py
│ │ ├── prior_generators
│ │ │ ├── __init__.py
│ │ │ ├── anchor_generator.py
│ │ │ ├── point_generator.py
│ │ │ └── utils.py
│ │ └── samplers
│ │ │ ├── __init__.py
│ │ │ ├── base_sampler.py
│ │ │ ├── combined_sampler.py
│ │ │ ├── instance_balanced_pos_sampler.py
│ │ │ ├── iou_balanced_neg_sampler.py
│ │ │ ├── mask_pseudo_sampler.py
│ │ │ ├── mask_sampling_result.py
│ │ │ ├── multi_instance_random_sampler.py
│ │ │ ├── multi_instance_sampling_result.py
│ │ │ ├── ohem_sampler.py
│ │ │ ├── pseudo_sampler.py
│ │ │ ├── random_sampler.py
│ │ │ ├── sampling_result.py
│ │ │ └── score_hlr_sampler.py
│ ├── test_time_augs
│ │ ├── __init__.py
│ │ ├── det_tta.py
│ │ └── merge_augs.py
│ ├── tracking_heads
│ │ ├── __init__.py
│ │ └── mask2former_track_head.py
│ └── utils
│ │ ├── Qtrick.py
│ │ ├── __init__.py
│ │ ├── gaussian_target.py
│ │ ├── image.py
│ │ ├── make_divisible.py
│ │ ├── misc.py
│ │ ├── panoptic_gt_processing.py
│ │ ├── point_sample.py
│ │ └── vlfuse_helper.py
├── registry.py
├── structures
│ ├── __init__.py
│ ├── bbox
│ │ ├── __init__.py
│ │ ├── base_boxes.py
│ │ ├── bbox_overlaps.py
│ │ ├── box_type.py
│ │ ├── horizontal_boxes.py
│ │ └── transforms.py
│ ├── det_data_sample.py
│ ├── mask
│ │ ├── __init__.py
│ │ ├── mask_target.py
│ │ ├── structures.py
│ │ └── utils.py
│ ├── reid_data_sample.py
│ └── track_data_sample.py
├── utils
│ ├── __init__.py
│ ├── benchmark.py
│ ├── collect_env.py
│ ├── compat_config.py
│ ├── contextmanagers.py
│ ├── dist_utils.py
│ ├── logger.py
│ ├── memory.py
│ ├── misc.py
│ ├── mot_error_visualize.py
│ ├── profiling.py
│ ├── replace_cfg_vals.py
│ ├── setup_env.py
│ ├── split_batch.py
│ ├── typing_utils.py
│ ├── util_mixins.py
│ └── util_random.py
└── version.py
├── mmseg
├── __init__.py
├── apis
│ ├── __init__.py
│ ├── inference.py
│ └── mmseg_inferencer.py
├── datasets
│ ├── __init__.py
│ ├── ade.py
│ ├── basesegdataset.py
│ ├── chase_db1.py
│ ├── cityscapes.py
│ ├── coco_stuff.py
│ ├── dataset_wrappers.py
│ ├── ddd17.py
│ ├── drive.py
│ ├── pascal_context.py
│ ├── synapse.py
│ ├── transforms
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ ├── loading.py
│ │ └── transforms.py
│ └── voc.py
├── engine
│ ├── __init__.py
│ ├── hooks
│ │ ├── __init__.py
│ │ ├── cal_firing_rate.py
│ │ ├── resetmodel_hook.py
│ │ └── visualization_hook.py
│ └── optimizers
│ │ ├── __init__.py
│ │ └── layer_decay_optimizer_constructor.py
├── evaluation
│ ├── __init__.py
│ └── metrics
│ │ ├── __init__.py
│ │ ├── citys_metric.py
│ │ └── iou_metric.py
├── models
│ ├── __init__.py
│ ├── backbones
│ │ ├── MSResnet.py
│ │ ├── __init__.py
│ │ ├── sdtv2.py
│ │ ├── sdtv3.py
│ │ ├── sdtv3MAE.py
│ │ └── spike-sdtv3.py
│ ├── builder.py
│ ├── data_preprocessor.py
│ ├── decode_heads
│ │ ├── __init__.py
│ │ ├── anchor_free_head.py
│ │ ├── base_dense_head.py
│ │ ├── cascade_decode_head.py
│ │ ├── decode_head.py
│ │ ├── fpn_head.py
│ │ ├── mask2former_head.py
│ │ └── maskformer_head.py
│ ├── losses
│ │ ├── __init__.py
│ │ ├── accuracy.py
│ │ ├── boundary_loss.py
│ │ ├── cross_entropy_loss.py
│ │ ├── dice_loss.py
│ │ ├── focal_loss.py
│ │ ├── huasdorff_distance_loss.py
│ │ ├── lovasz_loss.py
│ │ ├── ohem_cross_entropy_loss.py
│ │ ├── tversky_loss.py
│ │ └── utils.py
│ ├── necks
│ │ ├── __init__.py
│ │ └── fpn.py
│ ├── segmentors
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── cascade_encoder_decoder.py
│ │ ├── encoder_decoder.py
│ │ └── seg_tta.py
│ └── utils
│ │ ├── Qtrick.py
│ │ ├── __init__.py
│ │ ├── basic_block.py
│ │ ├── embed.py
│ │ ├── encoding.py
│ │ ├── inverted_residual.py
│ │ ├── make_divisible.py
│ │ ├── misc.py
│ │ ├── panoptic_gt_processing.py
│ │ ├── point_sample.py
│ │ ├── ppm.py
│ │ ├── res_layer.py
│ │ ├── se_layer.py
│ │ ├── self_attention_block.py
│ │ ├── shape_convert.py
│ │ ├── up_conv_block.py
│ │ └── wrappers.py
├── registry
│ ├── __init__.py
│ └── registry.py
├── structures
│ ├── __init__.py
│ ├── sampler
│ │ ├── __init__.py
│ │ ├── base_pixel_sampler.py
│ │ ├── builder.py
│ │ └── ohem_pixel_sampler.py
│ └── seg_data_sample.py
├── utils
│ ├── __init__.py
│ ├── class_names.py
│ ├── collect_env.py
│ ├── io.py
│ ├── misc.py
│ ├── panoptic_gt_processing.py
│ ├── set_env.py
│ └── typing_utils.py
├── version.py
└── visualization
│ ├── __init__.py
│ └── local_visualizer.py
├── model-index.yml
├── requirements.txt
├── requirements
├── albu.txt
├── docs.txt
├── mminstall.txt
├── optional.txt
├── readthedocs.txt
├── runtime.txt
└── tests.txt
├── setup.cfg
├── setup.py
├── tests
├── __init__.py
├── test_apis
│ └── test_inferencer.py
├── test_config.py
├── test_datasets
│ ├── test_dataset.py
│ ├── test_dataset_builder.py
│ ├── test_formatting.py
│ ├── test_loading.py
│ ├── test_transform.py
│ └── test_tta.py
├── test_digit_version.py
├── test_engine
│ ├── test_layer_decay_optimizer_constructor.py
│ ├── test_optimizer.py
│ └── test_visualization_hook.py
├── test_evaluation
│ └── test_metrics
│ │ ├── test_citys_metric.py
│ │ └── test_iou_metric.py
├── test_models
│ ├── __init__.py
│ ├── test_backbones
│ │ ├── __init__.py
│ │ ├── test_beit.py
│ │ ├── test_bisenetv1.py
│ │ ├── test_bisenetv2.py
│ │ ├── test_blocks.py
│ │ ├── test_cgnet.py
│ │ ├── test_erfnet.py
│ │ ├── test_fast_scnn.py
│ │ ├── test_hrnet.py
│ │ ├── test_icnet.py
│ │ ├── test_mae.py
│ │ ├── test_mit.py
│ │ ├── test_mobilenet_v3.py
│ │ ├── test_mscan.py
│ │ ├── test_pidnet.py
│ │ ├── test_resnest.py
│ │ ├── test_resnet.py
│ │ ├── test_resnext.py
│ │ ├── test_stdc.py
│ │ ├── test_swin.py
│ │ ├── test_timm_backbone.py
│ │ ├── test_twins.py
│ │ ├── test_unet.py
│ │ ├── test_vit.py
│ │ └── utils.py
│ ├── test_data_preprocessor.py
│ ├── test_forward.py
│ ├── test_heads
│ │ ├── __init__.py
│ │ ├── test_ann_head.py
│ │ ├── test_apc_head.py
│ │ ├── test_aspp_head.py
│ │ ├── test_cc_head.py
│ │ ├── test_decode_head.py
│ │ ├── test_dm_head.py
│ │ ├── test_dnl_head.py
│ │ ├── test_dpt_head.py
│ │ ├── test_ema_head.py
│ │ ├── test_fcn_head.py
│ │ ├── test_gc_head.py
│ │ ├── test_ham_head.py
│ │ ├── test_isa_head.py
│ │ ├── test_lraspp_head.py
│ │ ├── test_mask2former_head.py
│ │ ├── test_maskformer_head.py
│ │ ├── test_nl_head.py
│ │ ├── test_ocr_head.py
│ │ ├── test_pidnet_head.py
│ │ ├── test_psa_head.py
│ │ ├── test_psp_head.py
│ │ ├── test_segformer_head.py
│ │ ├── test_segmenter_mask_head.py
│ │ ├── test_setr_mla_head.py
│ │ ├── test_setr_up_head.py
│ │ ├── test_uper_head.py
│ │ └── utils.py
│ ├── test_losses
│ │ ├── test_dice_loss.py
│ │ ├── test_huasdorff_distance_loss.py
│ │ └── test_tversky_loss.py
│ ├── test_necks
│ │ ├── __init__.py
│ │ ├── test_feature2pyramid.py
│ │ ├── test_fpn.py
│ │ ├── test_ic_neck.py
│ │ ├── test_jpu.py
│ │ ├── test_mla_neck.py
│ │ └── test_multilevel_neck.py
│ ├── test_segmentors
│ │ ├── __init__.py
│ │ ├── test_cascade_encoder_decoder.py
│ │ ├── test_encoder_decoder.py
│ │ ├── test_seg_tta_model.py
│ │ └── utils.py
│ └── test_utils
│ │ ├── __init__.py
│ │ ├── test_embed.py
│ │ └── test_shape_convert.py
├── test_sampler.py
├── test_structures
│ └── test_seg_data_sample.py
├── test_utils
│ ├── test_io.py
│ └── test_set_env.py
└── test_visualization
│ └── test_local_visualizer.py
└── tools
├── Calculation_tools.py
├── analysis_tools
├── analyze_logs.py
├── benchmark.py
├── browse_dataset.py
├── confusion_matrix.py
├── get_flops.py
├── profile.py
└── utils.py
├── cal_firing_num.py
├── dataset_converters
├── chase_db1.py
├── cityscapes.py
├── coco_stuff10k.py
├── coco_stuff164k.py
├── drive.py
├── hrf.py
├── isaid.py
├── levircd.py
├── loveda.py
├── pascal_context.py
├── potsdam.py
├── pro_gen1.py
├── prophesee
│ ├── __init__.py
│ ├── io
│ │ ├── __init__.py
│ │ ├── box_filtering.py
│ │ ├── box_loading.py
│ │ ├── dat_events_tools.py
│ │ ├── npy_events_tools.py
│ │ └── psee_loader.py
│ ├── metrics
│ │ ├── __init__.py
│ │ └── coco_eval.py
│ ├── psee_evaluator.py
│ └── visualize
│ │ ├── __init__.py
│ │ └── vis_utils.py
├── refuge.py
├── stare.py
├── synapse.py
├── vaihingen.py
└── voc_aug.py
├── deployment
└── pytorch2torchscript.py
├── dist_test.sh
├── dist_train.sh
├── firing_utils
├── dataset.py
└── misc.py
├── misc
├── browse_dataset.py
├── print_config.py
└── publish_model.py
├── model_converters
├── beit2mmseg.py
├── mit2mmseg.py
├── stdc2mmseg.py
├── swin2mmseg.py
├── twins2mmseg.py
├── vit2mmseg.py
└── vitjax2mmseg.py
├── slurm_test.sh
├── slurm_train.sh
├── test.py
├── test.sh
└── train.py
/Figure/Readme.md:
--------------------------------------------------------------------------------
1 | Put Figure Here
2 |
--------------------------------------------------------------------------------
/Figure/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Figure/img.png
--------------------------------------------------------------------------------
/Segmentation/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/en/_build/
68 | docs/zh_cn/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | .DS_Store
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
108 | data
109 | .vscode
110 | .idea
111 |
112 | # custom
113 | *.pkl
114 | *.pkl.json
115 | *.log.json
116 | work_dirs/
117 | mmseg/.mim
118 |
119 | # Pytorch
120 | *.pth
121 |
--------------------------------------------------------------------------------
/Segmentation/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - name: "MMSegmentation Contributors"
5 | title: "OpenMMLab Semantic Segmentation Toolbox and Benchmark"
6 | date-released: 2020-07-10
7 | url: "https://github.com/open-mmlab/mmsegmentation"
8 | license: Apache-2.0
9 |
--------------------------------------------------------------------------------
/Segmentation/Figure/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Figure/img.png
--------------------------------------------------------------------------------
/Segmentation/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements/*.txt
2 | include mmseg/.mim/model-index.yml
3 | include mmseg/.mim/dataset-index.yml
4 | recursive-include mmseg/.mim/configs *.py *.yaml
5 | recursive-include mmseg/.mim/tools *.py *.sh
6 |
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/__init__.py:
--------------------------------------------------------------------------------
1 | from . import clock_driven
2 |
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/ann2snn/__init__.py:
--------------------------------------------------------------------------------
1 | from spikingjelly.clock_driven.ann2snn.converter import Converter
2 | from spikingjelly.clock_driven.ann2snn.utils import download_url
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/ann2snn/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/ann2snn/examples/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/ann2snn/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import os
3 | from tqdm import tqdm
4 |
def download_url(url, dst):
    """Download ``url`` into the local file ``dst``, resuming if possible.

    The remote size is first read from the ``Content-Length`` header of a
    probe request. If ``dst`` already exists, its current size is taken as
    the number of bytes already downloaded and only the remainder is
    requested via an HTTP ``Range`` header; the new bytes are appended to
    ``dst``. If the server ignores the ``Range`` header, the file is
    rewritten from scratch instead of being corrupted by a duplicate prefix.

    Args:
        url (str): URL of the file to download.
        dst (str): Destination path on disk.

    Returns:
        int: Size of the remote file in bytes (from ``Content-Length``).

    Raises:
        requests.HTTPError: If either request returns an error status.
        KeyError: If the server does not report ``Content-Length``.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:67.0) Gecko/20100101 Firefox/67.0'
    }

    # Probe request: only the headers are needed. Using the streamed response
    # as a context manager releases the connection without reading the body.
    with requests.get(url, headers=headers, stream=True) as response:
        response.raise_for_status()
        content_length = response.headers.get('content-length')
    if content_length is None:
        raise KeyError(f'content-length header missing in response from {url}')
    file_size = int(content_length)

    first_byte = os.path.getsize(dst) if os.path.exists(dst) else 0
    if first_byte >= file_size:
        # Nothing left to fetch.
        return file_size

    # HTTP byte ranges are inclusive, so the last byte index is size - 1.
    # Keep the same User-Agent as the probe request.
    range_headers = dict(headers)
    range_headers['Range'] = f'bytes={first_byte}-{file_size - 1}'

    with requests.get(url, headers=range_headers, stream=True) as req:
        req.raise_for_status()
        mode = 'ab'
        if first_byte > 0 and req.status_code != 206:
            # The server ignored the Range header (206 = Partial Content) and
            # is sending the whole file; appending would duplicate the bytes
            # already on disk, so restart the file instead.
            first_byte = 0
            mode = 'wb'
        with open(dst, mode) as f, tqdm(total=file_size, initial=first_byte,
                                         unit='B', unit_scale=True,
                                         desc=dst) as pbar:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
                    # Count the bytes actually received; the final chunk is
                    # usually shorter than chunk_size.
                    pbar.update(len(chunk))
    return file_size
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/examples/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/examples/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/examples/common/__init__.py
--------------------------------------------------------------------------------
/Segmentation/Qtrick_architecture/clock_driven/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/Qtrick_architecture/clock_driven/model/__init__.py
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/datasets/cityscapes_1024x1024.py:
--------------------------------------------------------------------------------
1 | _base_ = './cityscapes.py'  # Cityscapes settings with 1024x1024 crops; values below override/merge into the base config (mmengine `_base_` inheritance).
2 | crop_size = (1024, 1024)  # Crop size used by RandomCrop in train_pipeline.
3 | train_pipeline = [  # Training transforms, applied in order.
4 | dict(type='LoadImageFromFile'),
5 | dict(type='LoadAnnotations'),
6 | dict(
7 | type='RandomResize',  # Resize to `scale` times a ratio sampled from `ratio_range`.
8 | scale=(2048, 1024),
9 | ratio_range=(0.5, 2.0),
10 | keep_ratio=True),  # Keep the aspect ratio while resizing.
11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12 | dict(type='RandomFlip', prob=0.5),
13 | dict(type='PhotoMetricDistortion'),
14 | dict(type='PackSegInputs')
15 | ]
16 | test_pipeline = [  # Evaluation transforms: a fixed resize, no random augmentation.
17 | dict(type='LoadImageFromFile'),
18 | dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
19 | # Load annotations after ``Resize`` so the ground-truth map is
20 | # not resized by the data transforms.
21 | dict(type='LoadAnnotations'),
22 | dict(type='PackSegInputs')
23 | ]
24 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline))  # Only the pipeline is replaced; other loader fields come from the base config.
25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
26 | test_dataloader = val_dataloader  # Testing reuses the validation loader.
27 |
28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])  # Report mean IoU.
29 | test_evaluator = val_evaluator
30 |
30 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/datasets/cityscapes_768x768.py:
--------------------------------------------------------------------------------
1 | _base_ = './cityscapes.py'  # Cityscapes settings with 768x768 crops; values below override/merge into the base config (mmengine `_base_` inheritance).
2 | crop_size = (768, 768)  # Crop size used by RandomCrop in train_pipeline.
3 | train_pipeline = [  # Training transforms, applied in order.
4 | dict(type='LoadImageFromFile'),
5 | dict(type='LoadAnnotations'),
6 | dict(
7 | type='RandomResize',  # Resize to `scale` times a ratio sampled from `ratio_range`.
8 | scale=(2049, 1025),  # Note: 2049x1025 here, while the 1024x1024 / 832x832 variants use 2048x1024.
9 | ratio_range=(0.5, 2.0),
10 | keep_ratio=True),  # Keep the aspect ratio while resizing.
11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12 | dict(type='RandomFlip', prob=0.5),
13 | dict(type='PhotoMetricDistortion'),
14 | dict(type='PackSegInputs')
15 | ]
16 | test_pipeline = [  # Evaluation transforms: a fixed resize, no random augmentation.
17 | dict(type='LoadImageFromFile'),
18 | dict(type='Resize', scale=(2049, 1025), keep_ratio=True),
19 | # Load annotations after ``Resize`` so the ground-truth map is
20 | # not resized by the data transforms.
21 | dict(type='LoadAnnotations'),
22 | dict(type='PackSegInputs')
23 | ]
24 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline))  # Only the pipeline is replaced; other loader fields come from the base config.
25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
26 | test_dataloader = val_dataloader  # Testing reuses the validation loader.
27 |
28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])  # Report mean IoU.
29 | test_evaluator = val_evaluator
30 |
30 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/datasets/cityscapes_769x769.py:
--------------------------------------------------------------------------------
1 | _base_ = './cityscapes.py'  # Cityscapes settings with 769x769 crops; values below override/merge into the base config (mmengine `_base_` inheritance).
2 | crop_size = (769, 769)  # Crop size used by RandomCrop in train_pipeline.
3 | train_pipeline = [  # Training transforms, applied in order.
4 | dict(type='LoadImageFromFile'),
5 | dict(type='LoadAnnotations'),
6 | dict(
7 | type='RandomResize',  # Resize to `scale` times a ratio sampled from `ratio_range`.
8 | scale=(2049, 1025),  # Note: 2049x1025 here, while the 1024x1024 / 832x832 variants use 2048x1024.
9 | ratio_range=(0.5, 2.0),
10 | keep_ratio=True),  # Keep the aspect ratio while resizing.
11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12 | dict(type='RandomFlip', prob=0.5),
13 | dict(type='PhotoMetricDistortion'),
14 | dict(type='PackSegInputs')
15 | ]
16 | test_pipeline = [  # Evaluation transforms: a fixed resize, no random augmentation.
17 | dict(type='LoadImageFromFile'),
18 | dict(type='Resize', scale=(2049, 1025), keep_ratio=True),
19 | # Load annotations after ``Resize`` so the ground-truth map is
20 | # not resized by the data transforms.
21 | dict(type='LoadAnnotations'),
22 | dict(type='PackSegInputs')
23 | ]
24 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline))  # Only the pipeline is replaced; other loader fields come from the base config.
25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
26 | test_dataloader = val_dataloader  # Testing reuses the validation loader.
27 |
28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])  # Report mean IoU.
29 | test_evaluator = val_evaluator
30 |
30 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/datasets/cityscapes_832x832.py:
--------------------------------------------------------------------------------
1 | _base_ = './cityscapes.py'  # Cityscapes settings with 832x832 crops; values below override/merge into the base config (mmengine `_base_` inheritance).
2 | crop_size = (832, 832)  # Crop size used by RandomCrop in train_pipeline.
3 | train_pipeline = [  # Training transforms, applied in order.
4 | dict(type='LoadImageFromFile'),
5 | dict(type='LoadAnnotations'),
6 | dict(
7 | type='RandomResize',  # Resize to `scale` times a ratio sampled from `ratio_range`.
8 | scale=(2048, 1024),
9 | ratio_range=(0.5, 2.0),
10 | keep_ratio=True),  # Keep the aspect ratio while resizing.
11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12 | dict(type='RandomFlip', prob=0.5),
13 | dict(type='PhotoMetricDistortion'),
14 | dict(type='PackSegInputs')
15 | ]
16 | test_pipeline = [  # Evaluation transforms: a fixed resize, no random augmentation.
17 | dict(type='LoadImageFromFile'),
18 | dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
19 | # Load annotations after ``Resize`` so the ground-truth map is
20 | # not resized by the data transforms.
21 | dict(type='LoadAnnotations'),
22 | dict(type='PackSegInputs')
23 | ]
24 | train_dataloader = dict(dataset=dict(pipeline=train_pipeline))  # Only the pipeline is replaced; other loader fields come from the base config.
25 | val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
26 | test_dataloader = val_dataloader  # Testing reuses the validation loader.
27 |
28 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])  # Report mean IoU.
29 | test_evaluator = val_evaluator
30 |
30 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/datasets/ddd17.py:
--------------------------------------------------------------------------------
# DDD17 dataset settings.
# NOTE(review): samples are read with ``LoadImageFromNpyFile``, so inputs are
# presumably stored as .npy arrays (event frames) — confirm with the dataset.
dataset_type = 'DDD17Dataset'

# Relative to the working directory, like the other dataset configs
# (``data/`` is git-ignored). This previously pointed at a machine-specific
# absolute path; override ``data_root`` in a derived config if your copy of
# the data lives elsewhere.
data_root = 'data/ddd17_seg/T4'

# Crop size for the currently disabled ``RandomCrop`` steps below; kept in
# case derived configs reference it.
crop_size = (200, 352)
train_pipeline = [
    dict(type='LoadImageFromNpyFile'),
    dict(type='LoadAnnotations'),
    # dict(type='RandomCrop', crop_size=crop_size),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromNpyFile'),
    # No resize is applied here, so annotations are loaded at their
    # stored resolution.
    dict(type='LoadAnnotations'),
    # dict(type='RandomCrop', crop_size=crop_size),
    dict(type='PackSegInputs')
]
# Scale ratios for the disabled test-time-augmentation pipeline below.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
# NOTE: if re-enabled, this TTA pipeline must use ``LoadImageFromNpyFile``
# like the pipelines above; ``LoadImageFromFile`` does not match this data.
# tta_pipeline = [
#     dict(type='LoadImageFromFile', backend_args=None),
#     dict(
#         type='TestTimeAug',
#         transforms=[
#             [
#                 dict(type='Resize', scale_factor=r, keep_ratio=True)
#                 for r in img_ratios
#             ],
#             [
#                 dict(type='RandomFlip', prob=0., direction='horizontal'),
#                 dict(type='RandomFlip', prob=1., direction='horizontal')
#             ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
#         ])
# ]
train_dataloader = dict(
    batch_size=6,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/training', seg_map_path='annotations/training'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/validation',
            seg_map_path='annotations/validation'),
        pipeline=test_pipeline))
# Testing reuses the validation loader.
test_dataloader = val_dataloader

# Report mean IoU.
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
64 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/datasets/pascal_context.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'PascalContextDataset'
3 | data_root = 'data/VOCdevkit/VOC2010/'  # Relative to the working directory.
4 |
5 | img_scale = (520, 520)  # Base scale for RandomResize (train) and Resize (test).
6 | crop_size = (480, 480)  # Crop size used by RandomCrop in train_pipeline.
7 |
8 | train_pipeline = [  # Training transforms, applied in order.
9 | dict(type='LoadImageFromFile'),
10 | dict(type='LoadAnnotations'),
11 | dict(
12 | type='RandomResize',  # Resize to `scale` times a ratio sampled from `ratio_range`.
13 | scale=img_scale,
14 | ratio_range=(0.5, 2.0),
15 | keep_ratio=True),  # Keep the aspect ratio while resizing.
16 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
17 | dict(type='RandomFlip', prob=0.5),
18 | dict(type='PhotoMetricDistortion'),
19 | dict(type='PackSegInputs')
20 | ]
21 | test_pipeline = [  # Evaluation transforms: a fixed resize, no random augmentation.
22 | dict(type='LoadImageFromFile'),
23 | dict(type='Resize', scale=img_scale, keep_ratio=True),
24 | # Load annotations after ``Resize`` so the ground-truth map is
25 | # not resized by the data transforms.
26 | dict(type='LoadAnnotations'),
27 | dict(type='PackSegInputs')
28 | ]
29 | train_dataloader = dict(
30 | batch_size=4,
31 | num_workers=4,
32 | persistent_workers=True,
33 | sampler=dict(type='InfiniteSampler', shuffle=True),  # Shuffled sampling for training.
34 | dataset=dict(
35 | type=dataset_type,
36 | data_root=data_root,
37 | data_prefix=dict(
38 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
39 | ann_file='ImageSets/SegmentationContext/train.txt',  # Training split list (presumably relative to data_root).
40 | pipeline=train_pipeline))
41 | val_dataloader = dict(
42 | batch_size=1,
43 | num_workers=4,
44 | persistent_workers=True,
45 | sampler=dict(type='DefaultSampler', shuffle=False),  # Deterministic order for evaluation.
46 | dataset=dict(
47 | type=dataset_type,
48 | data_root=data_root,
49 | data_prefix=dict(
50 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
51 | ann_file='ImageSets/SegmentationContext/val.txt',  # Validation split list.
52 | pipeline=test_pipeline))
53 | test_dataloader = val_dataloader  # Testing reuses the validation loader.
54 |
55 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])  # Report mean IoU.
56 | test_evaluator = val_evaluator
57 |
57 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/datasets/synapse.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'SynapseDataset'
2 | data_root = 'data/synapse/'
3 | img_scale = (224, 224)
4 | train_pipeline = [
5 | dict(type='LoadImageFromFile'),
6 | dict(type='LoadAnnotations'),
7 | dict(type='Resize', scale=img_scale, keep_ratio=True),
8 | dict(type='RandomRotFlip', rotate_prob=0.5, flip_prob=0.5, degree=20),
9 | dict(type='PackSegInputs')
10 | ]
11 | test_pipeline = [
12 | dict(type='LoadImageFromFile'),
13 | dict(type='Resize', scale=img_scale, keep_ratio=True),
14 | dict(type='LoadAnnotations'),
15 | dict(type='PackSegInputs')
16 | ]
17 | train_dataloader = dict(
18 | batch_size=6,
19 | num_workers=2,
20 | persistent_workers=True,
21 | sampler=dict(type='InfiniteSampler', shuffle=True),
22 | dataset=dict(
23 | type=dataset_type,
24 | data_root=data_root,
25 | data_prefix=dict(
26 | img_path='img_dir/train', seg_map_path='ann_dir/train'),
27 | pipeline=train_pipeline))
28 | val_dataloader = dict(
29 | batch_size=1,
30 | num_workers=4,
31 | persistent_workers=True,
32 | sampler=dict(type='DefaultSampler', shuffle=False),
33 | dataset=dict(
34 | type=dataset_type,
35 | data_root=data_root,
36 | data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
37 | pipeline=test_pipeline))
38 | test_dataloader = val_dataloader
39 |
40 | val_evaluator = dict(type='IoUMetric', iou_metrics=['mDice'])
41 | test_evaluator = val_evaluator
42 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | default_scope = 'mmseg'
2 | env_cfg = dict(
3 | cudnn_benchmark=True,
4 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
5 | dist_cfg=dict(backend='nccl'),
6 | )
7 | vis_backends = [dict(type='LocalVisBackend')]
8 |
9 | # vis_backends = []
10 | visualizer = dict(
11 | type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
12 |
13 |
14 | log_processor = dict(by_epoch=False)
15 | log_level = 'INFO'
16 | load_from = None
17 | resume = False
18 | custom_imports = dict(imports=['mmseg.engine.hooks.resetmodel_hook'], allow_failed_imports=False)
19 | custom_hooks = [
20 | dict(type='ResetModelHook')
21 | ]
22 |
23 | tta_model = dict(type='SegTTAModel')
24 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/models/fpn_snn_r50.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | norm_cfg = dict(type='SyncBN', requires_grad=True)
3 | data_preprocessor = dict(
4 | type='SegDataPreProcessor',
5 | mean=[123.675, 116.28, 103.53],
6 | std=[58.395, 57.12, 57.375],
7 | bgr_to_rgb=True,
8 | pad_val=0,
9 | seg_pad_val=255)
10 | model = dict(
11 | type='EncoderDecoder',
12 | data_preprocessor=data_preprocessor,
13 | # pretrained='open-mmlab://resnet50_v1c',
14 | backbone=dict(
15 | type='ResNetV1c',
16 | ),
17 | neck=dict(
18 | type='QFPN',
19 | in_channels=[256, 512, 1024, 2048],
20 | out_channels=256,
21 | num_outs=4,
22 | T=4,
23 | ),
24 | decode_head=dict(
25 | type='QFPNHead',
26 | in_channels=[256, 256, 256, 256],
27 | in_index=[0, 1, 2, 3],
28 | feature_strides=[4, 8, 16, 32],
29 | channels=128,
30 | dropout_ratio=0.1,
31 | num_classes=150,
32 | T=4,
33 | norm_cfg=norm_cfg,
34 | align_corners=False,
35 | loss_decode=dict(
36 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
37 | # model training and testing settings
38 | train_cfg=dict(),
39 | test_cfg=dict(mode='whole'))
40 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/models/snn_sdtv2_fpn.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | norm_cfg = dict(type='SyncBN', requires_grad=True)
3 | data_preprocessor = dict(
4 | type='SegDataPreProcessor',
5 | mean=[123.675, 116.28, 103.53],
6 | std=[58.395, 57.12, 57.375],
7 | bgr_to_rgb=True,
8 | pad_val=0,
9 | seg_pad_val=255)
10 | model = dict(
11 | type='EncoderDecoder',
12 | data_preprocessor=data_preprocessor,
13 | pretrained='/raid/ligq/lzx/spikeformerv2/seg/checkpoint/checkpoint-199.pth',
14 | backbone=dict(
15 | # init_cfg=dict(type='Pretrained', checkpoint="/raid/ligq/lzx/spikeformerv2/seg/checkpoint/checkpoint-199.pth"),
16 | type='Sdtv2',
17 | img_size_h=512,
18 | img_size_w=512,
19 | embed_dim=[128, 256, 512, 640],
20 | num_classes=150,
21 | T=1,
22 | qkv_bias=False,
23 | decode_mode='snn',
24 | ),
25 | neck=dict(
26 | type='FPN_SNN',
27 | in_channels=[256, 512, 1024, 2048],
28 | out_channels=256,
29 | num_outs=4),
30 | decode_head=dict(
31 | type='FPNHead_SNN',
32 | in_channels=[256, 256, 256, 256],
33 | in_index=[0, 1, 2, 3],
34 | feature_strides=[4, 8, 16, 32],
35 | channels=128,
36 | dropout_ratio=0.1,
37 | num_classes=150,
38 | norm_cfg=norm_cfg,
39 | align_corners=False,
40 | loss_decode=dict(
41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
42 | # model training and testing settings
43 | train_cfg=dict(),
44 | test_cfg=dict(mode='whole'))
45 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/schedules/schedule_160k.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
3 | optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
4 | # learning policy
5 | param_scheduler = [
6 | dict(
7 | type='PolyLR',
8 | eta_min=1e-5,
9 | power=0.9,
10 | begin=0,
11 | end=160000,
12 | by_epoch=False)
13 | ]
14 | # training schedule for 160k
15 | train_cfg = dict(
16 | type='IterBasedTrainLoop', max_iters=160000, val_interval=16000)
17 | val_cfg = dict(type='ValLoop')
18 | test_cfg = dict(type='TestLoop')
19 | default_hooks = dict(
20 | timer=dict(type='IterTimerHook'),
21 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
22 | param_scheduler=dict(type='ParamSchedulerHook'),
23 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=16000),
24 | sampler_seed=dict(type='DistSamplerSeedHook'),
25 | visualization=dict(type='SegVisualizationHook', draw=True, interval=1))
26 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/schedules/schedule_320k.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
3 | optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
4 | # learning policy
5 | param_scheduler = [
6 | dict(
7 | type='PolyLR',
8 | eta_min=1e-4,
9 | power=0.9,
10 | begin=0,
11 | end=320000,
12 | by_epoch=False)
13 | ]
14 | # training schedule for 320k
15 | train_cfg = dict(
16 | type='IterBasedTrainLoop', max_iters=320000, val_interval=32000)
17 | val_cfg = dict(type='ValLoop')
18 | test_cfg = dict(type='TestLoop')
19 | default_hooks = dict(
20 | timer=dict(type='IterTimerHook'),
21 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
22 | param_scheduler=dict(type='ParamSchedulerHook'),
23 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=32000),
24 | sampler_seed=dict(type='DistSamplerSeedHook'),
25 | visualization=dict(type='SegVisualizationHook'))
26 |
--------------------------------------------------------------------------------
/Segmentation/configs/_base_/schedules/schedule_80k.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
3 | optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
4 | # learning policy
5 | param_scheduler = [
6 | dict(
7 | type='PolyLR',
8 | eta_min=1e-4,
9 | power=0.9,
10 | begin=0,
11 | end=80000,
12 | by_epoch=False)
13 | ]
14 | # training schedule for 80k
15 | train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
16 | val_cfg = dict(type='ValLoop')
17 | test_cfg = dict(type='TestLoop')
18 | default_hooks = dict(
19 | timer=dict(type='IterTimerHook'),
20 | logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
21 | param_scheduler=dict(type='ParamSchedulerHook'),
22 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
23 | sampler_seed=dict(type='DistSamplerSeedHook'),
24 | visualization=dict(type='SegVisualizationHook'))
25 |
--------------------------------------------------------------------------------
/Segmentation/dataset-index.yml:
--------------------------------------------------------------------------------
1 | ade20k:
2 | dataset: ADE20K_2016
3 | download_root: data
4 | data_root: data/ade
5 |
6 | cityscapes:
7 | dataset: CityScapes
8 | download_root: data
9 | data_root: data/cityscapes
10 |
11 | voc2012:
12 | dataset: PASCAL_VOC2012
13 | download_root: data
14 | data_root: data/VOCdevkit/VOC2012
15 |
16 | cocostuff:
17 | dataset: COCO-Stuff
18 | download_root: data
19 | data_root: data/coco_stuff164k
20 |
21 | mapillary:
22 | dataset: Mapillary
23 | download_root: data
24 | data_root: data/mapillary
25 |
26 | pascal_context:
27 | dataset: VOC2010
28 | download_root: data
29 | data_root: data/VOCdevkit/VOC2010
30 |
31 | isaid:
32 | dataset: iSAID
33 | download_root: data
34 | data_root: data/iSAID
35 |
36 | isprs_potsdam:
37 | dataset: ISPRS_Potsdam
38 | download_root: data
39 | data_root: data/potsdam
40 |
41 | loveda:
42 | dataset: LoveDA
43 | download_root: data
44 | data_root: data/loveDA
45 |
46 | chase_db1:
47 | dataset: CHASE_DB1
48 | download_root: data
49 | data_root: data/CHASE_DB1
50 |
51 | drive:
52 | dataset: DRIVE
53 | download_root: data
54 | data_root: data/DRIVE
55 |
56 | hrf:
57 | dataset: HRF
58 | download_root: data
59 | data_root: data/HRF
60 |
61 | stare:
62 | dataset: STARE
63 | download_root: data
64 | data_root: data/STARE
65 |
66 | synapse:
67 | dataset: SurgVisDom
68 | download_root: data
69 | data_root: data/synapse
70 |
71 | refuge:
72 | dataset: REFUGE_Challenge
73 | download_root: data
74 | data_root: data/REFUGE
75 |
76 | lip:
77 | dataset: LIP
78 | download_root: data
79 | data_root: data/LIP
80 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 | background-image: url("../images/mmsegmentation.png");
3 | background-size: 201px 40px;
4 | height: 40px;
5 | width: 201px;
6 | }
7 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/_static/images/mmsegmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/en/_static/images/mmsegmentation.png
--------------------------------------------------------------------------------
/Segmentation/docs/en/advanced_guides/add_transforms.md:
--------------------------------------------------------------------------------
1 | # Adding New Data Transforms
2 |
3 | ## Customization data transformation
4 |
5 | The customized data transformation must inherited from `BaseTransform` and implement `transform` function.
6 | Here we use a simple flipping transformation as example:
7 |
8 | ```python
9 | import random
10 | import mmcv
11 | from mmcv.transforms import BaseTransform, TRANSFORMS
12 |
13 | @TRANSFORMS.register_module()
14 | class MyFlip(BaseTransform):
15 | def __init__(self, direction: str):
16 | super().__init__()
17 | self.direction = direction
18 |
19 | def transform(self, results: dict) -> dict:
20 | img = results['img']
21 | results['img'] = mmcv.imflip(img, direction=self.direction)
22 | return results
23 | ```
24 |
25 | Moreover, import the new class.
26 |
27 | ```python
28 | from .my_pipeline import MyFlip
29 | ```
30 |
31 | Thus, we can instantiate a `MyFlip` object and use it to process the data dict.
32 |
33 | ```python
34 | import numpy as np
35 |
36 | transform = MyFlip(direction='horizontal')
37 | data_dict = {'img': np.random.rand(224, 224, 3)}
38 | data_dict = transform(data_dict)
39 | processed_img = data_dict['img']
40 | ```
41 |
42 | Or, we can use `MyFlip` transformation in data pipeline in our config file.
43 |
44 | ```python
45 | pipeline = [
46 | ...
47 | dict(type='MyFlip', direction='horizontal'),
48 | ...
49 | ]
50 | ```
51 |
52 | Note that if you want to use `MyFlip` in config, you must ensure the file containing `MyFlip` is imported during runtime.
53 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/advanced_guides/index.rst:
--------------------------------------------------------------------------------
1 | Basic Concepts
2 | ***************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | data_flow.md
8 | structures.md
9 | models.md
10 | datasets.md
11 | transforms.md
12 | evaluation.md
13 | engine.md
14 | training_tricks.md
15 |
16 | Component Customization
17 | ************************
18 |
19 | .. toctree::
20 | :maxdepth: 1
21 |
22 | add_models.md
23 | add_datasets.md
24 | add_transforms.md
25 | add_metrics.md
26 | customize_runtime.md
27 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/api.rst:
--------------------------------------------------------------------------------
1 | mmseg.apis
2 | --------------
3 | .. automodule:: mmseg.apis
4 | :members:
5 |
6 | mmseg.datasets
7 | --------------
8 |
9 | datasets
10 | ^^^^^^^^^^
11 | .. automodule:: mmseg.datasets
12 | :members:
13 |
14 | transforms
15 | ^^^^^^^^^^^^
16 | .. automodule:: mmseg.datasets.transforms
17 | :members:
18 |
19 | mmseg.engine
20 | --------------
21 |
22 | hooks
23 | ^^^^^^^^^^
24 | .. automodule:: mmseg.engine.hooks
25 | :members:
26 |
27 | optimizers
28 | ^^^^^^^^^^^^^^^
29 | .. automodule:: mmseg.engine.optimizers
30 | :members:
31 |
32 | mmseg.evaluation
33 | -----------------
34 |
35 | metrics
36 | ^^^^^^^^^^
37 | .. automodule:: mmseg.evaluation.metrics
38 | :members:
39 |
40 | mmseg.models
41 | --------------
42 |
43 | backbones
44 | ^^^^^^^^^^^^^^^^^^
45 | .. automodule:: mmseg.models.backbones
46 | :members:
47 |
48 | decode_heads
49 | ^^^^^^^^^^^^^^^
50 | .. automodule:: mmseg.models.decode_heads
51 | :members:
52 |
53 | segmentors
54 | ^^^^^^^^^^
55 | .. automodule:: mmseg.models.segmentors
56 | :members:
57 |
58 | losses
59 | ^^^^^^^^^^
60 | .. automodule:: mmseg.models.losses
61 | :members:
62 |
63 | necks
64 | ^^^^^^^^^^^^
65 | .. automodule:: mmseg.models.necks
66 | :members:
67 |
68 | utils
69 | ^^^^^^^^^^
70 | .. automodule:: mmseg.models.utils
71 | :members:
72 |
73 |
74 | mmseg.structures
75 | --------------------
76 |
77 | structures
78 | ^^^^^^^^^^^^^^^^^
79 | .. automodule:: mmseg.structures
80 | :members:
81 |
82 | sampler
83 | ^^^^^^^^^^
84 | .. automodule:: mmseg.structures.sampler
85 | :members:
86 |
87 | mmseg.visualization
88 | --------------------
89 | .. automodule:: mmseg.visualization
90 | :members:
91 |
92 | mmseg.utils
93 | --------------
94 | .. automodule:: mmseg.utils
95 | :members:
96 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to MMSegmentation's documentation!
2 | ===========================================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: Get Started
7 |
8 | overview.md
9 | get_started.md
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 | :caption: User Guides
14 |
15 | user_guides/index.rst
16 |
17 | .. toctree::
18 | :maxdepth: 2
19 | :caption: Advanced Guides
20 |
21 | advanced_guides/index.rst
22 |
23 | .. toctree::
24 | :maxdepth: 1
25 | :caption: Migration
26 |
27 | migration/index.rst
28 |
29 | .. toctree::
30 | :caption: API Reference
31 |
32 | api.rst
33 |
34 | .. toctree::
35 | :maxdepth: 1
36 | :caption: Model Zoo
37 |
38 | model_zoo.md
39 | modelzoo_statistics.md
40 |
41 | .. toctree::
42 | :maxdepth: 1
43 | :caption: Notes
44 |
45 | notes/changelog.md
46 | notes/faq.md
47 |
48 | .. toctree::
49 | :caption: Device Support
50 |
51 | device/npu.md
52 |
53 | .. toctree::
54 | :caption: Switch Language
55 |
56 | switch_language.md
57 |
58 |
59 | Indices and tables
60 | ==================
61 |
62 | * :ref:`genindex`
63 | * :ref:`search`
64 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/migration/index.rst:
--------------------------------------------------------------------------------
1 | Migration
2 | ***************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | interface.md
8 | package.md
9 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | import functools as func
4 | import glob
5 | import os.path as osp
6 | import re
7 |
8 | import numpy as np
9 |
10 | url_prefix = 'https://github.com/open-mmlab/mmsegmentation/blob/master/'
11 |
12 | files = sorted(glob.glob('../../configs/*/README.md'))
13 |
14 | stats = []
15 | titles = []
16 | num_ckpts = 0
17 |
18 | for f in files:
19 | url = osp.dirname(f.replace('../../', url_prefix))
20 |
21 | with open(f) as content_file:
22 | content = content_file.read()
23 |
24 | title = content.split('\n')[0].replace('#', '').strip()
25 | ckpts = {
26 | x.lower().strip()
27 | for x in re.findall(r'https?://download.*\.pth', content)
28 | if 'mmsegmentation' in x
29 | }
30 | if len(ckpts) == 0:
31 | continue
32 |
33 | _papertype = [
34 | x for x in re.findall(r'', content)
35 | ]
36 | assert len(_papertype) > 0
37 | papertype = _papertype[0]
38 |
39 | paper = {(papertype, title)}
40 |
41 | titles.append(title)
42 | num_ckpts += len(ckpts)
43 | statsmsg = f"""
44 | \t* [{papertype}] [{title}]({url}) ({len(ckpts)} ckpts)
45 | """
46 | stats.append((paper, ckpts, statsmsg))
47 |
48 | allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _ in stats])
49 | msglist = '\n'.join(x for _, _, x in stats)
50 |
51 | papertypes, papercounts = np.unique([t for t, _ in allpapers],
52 | return_counts=True)
53 | countstr = '\n'.join(
54 | [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)])
55 |
56 | modelzoo = f"""
57 | # Model Zoo Statistics
58 |
59 | * Number of papers: {len(set(titles))}
60 | {countstr}
61 |
62 | * Number of checkpoints: {num_ckpts}
63 | {msglist}
64 | """
65 |
66 | with open('modelzoo_statistics.md', 'w') as f:
67 | f.write(modelzoo)
68 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/switch_language.md:
--------------------------------------------------------------------------------
1 | ## English
2 |
3 | ## 简体中文
4 |
--------------------------------------------------------------------------------
/Segmentation/docs/en/user_guides/index.rst:
--------------------------------------------------------------------------------
1 | Train & Test
2 | **************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | 1_config.md
8 | 2_dataset_prepare.md
9 | 3_inference.md
10 | 4_train_test.md
11 |
12 | Useful Tools
13 | *************
14 |
15 | .. toctree::
16 | :maxdepth: 2
17 |
18 | visualization.md
19 | useful_tools.md
20 | deployment.md
21 | visualization_feature_map.md
22 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 | background-image: url("../images/mmsegmentation.png");
3 | background-size: 201px 40px;
4 | height: 40px;
5 | width: 201px;
6 | }
7 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/_static/images/mmsegmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/_static/images/mmsegmentation.png
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/advanced_guides/add_transforms.md:
--------------------------------------------------------------------------------
1 | # 新增数据增强
2 |
3 | ## 自定义数据增强
4 |
5 | 自定义数据增强必须继承 `BaseTransform` 并实现 `transform` 函数。这里我们使用一个简单的翻转变换作为示例:
6 |
7 | ```python
8 | import random
9 | import mmcv
10 | from mmcv.transforms import BaseTransform, TRANSFORMS
11 |
12 | @TRANSFORMS.register_module()
13 | class MyFlip(BaseTransform):
14 | def __init__(self, direction: str):
15 | super().__init__()
16 | self.direction = direction
17 |
18 | def transform(self, results: dict) -> dict:
19 | img = results['img']
20 | results['img'] = mmcv.imflip(img, direction=self.direction)
21 | return results
22 | ```
23 |
24 | 此外,新的类需要被导入。
25 |
26 | ```python
27 | from .my_pipeline import MyFlip
28 | ```
29 |
30 | 这样,我们就可以实例化一个 `MyFlip` 对象并使用它来处理数据字典。
31 |
32 | ```python
33 | import numpy as np
34 |
35 | transform = MyFlip(direction='horizontal')
36 | data_dict = {'img': np.random.rand(224, 224, 3)}
37 | data_dict = transform(data_dict)
38 | processed_img = data_dict['img']
39 | ```
40 |
41 | 或者,我们可以在配置文件中的数据流程中使用 `MyFlip` 变换。
42 |
43 | ```python
44 | pipeline = [
45 | ...
46 | dict(type='MyFlip', direction='horizontal'),
47 | ...
48 | ]
49 | ```
50 |
51 | 需要注意,如果要在配置文件中使用 `MyFlip`,必须确保在运行时导入了包含 `MyFlip` 的文件。
52 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/advanced_guides/index.rst:
--------------------------------------------------------------------------------
1 | 基本概念
2 | ***************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | data_flow.md
8 | structures.md
9 | models.md
10 | datasets.md
11 | transforms.md
12 | evaluation.md
13 | engine.md
14 | training_tricks.md
15 |
16 | 自定义组件
17 | ************************
18 |
19 | .. toctree::
20 | :maxdepth: 1
21 |
22 | add_models.md
23 | add_datasets.md
24 | add_transforms.md
25 | add_metrics.md
26 | customize_runtime.md
27 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/api.rst:
--------------------------------------------------------------------------------
1 | mmseg.apis
2 | --------------
3 | .. automodule:: mmseg.apis
4 | :members:
5 |
6 | mmseg.datasets
7 | --------------
8 |
9 | datasets
10 | ^^^^^^^^^^
11 | .. automodule:: mmseg.datasets
12 | :members:
13 |
14 | transforms
15 | ^^^^^^^^^^^^
16 | .. automodule:: mmseg.datasets.transforms
17 | :members:
18 |
19 | mmseg.engine
20 | --------------
21 |
22 | hooks
23 | ^^^^^^^^^^
24 | .. automodule:: mmseg.engine.hooks
25 | :members:
26 |
27 | optimizers
28 | ^^^^^^^^^^^^^^^
29 | .. automodule:: mmseg.engine.optimizers
30 | :members:
31 |
32 | mmseg.evaluation
33 | --------------
34 |
35 | metrics
36 | ^^^^^^^^^^
37 | .. automodule:: mmseg.evaluation.metrics
38 | :members:
39 |
40 | mmseg.models
41 | --------------
42 |
43 | backbones
44 | ^^^^^^^^^^^^^^^^^^
45 | .. automodule:: mmseg.models.backbones
46 | :members:
47 |
48 | decode_heads
49 | ^^^^^^^^^^^^^^^
50 | .. automodule:: mmseg.models.decode_heads
51 | :members:
52 |
53 | segmentors
54 | ^^^^^^^^^^
55 | .. automodule:: mmseg.models.segmentors
56 | :members:
57 |
58 | losses
59 | ^^^^^^^^^^
60 | .. automodule:: mmseg.models.losses
61 | :members:
62 |
63 | necks
64 | ^^^^^^^^^^^^
65 | .. automodule:: mmseg.models.necks
66 | :members:
67 |
68 | utils
69 | ^^^^^^^^^^
70 | .. automodule:: mmseg.models.utils
71 | :members:
72 |
73 |
74 | mmseg.structures
75 | --------------------
76 |
77 | structures
78 | ^^^^^^^^^^^^^^^^^
79 | .. automodule:: mmseg.structures
80 | :members:
81 |
82 | sampler
83 | ^^^^^^^^^^
84 | .. automodule:: mmseg.structures.sampler
85 | :members:
86 |
87 | mmseg.visualization
88 | --------------------
89 | .. automodule:: mmseg.visualization
90 | :members:
91 |
92 | mmseg.utils
93 | --------------
94 | .. automodule:: mmseg.utils
95 | :members:
96 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/imgs/qq_group_qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/imgs/qq_group_qrcode.jpg
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/imgs/seggroup_qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/imgs/seggroup_qrcode.jpg
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/imgs/zhihu_qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/docs/zh_cn/imgs/zhihu_qrcode.jpg
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/index.rst:
--------------------------------------------------------------------------------
1 | 欢迎来到 MMSegmentation 的文档!
2 | =======================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: 开始你的第一步
7 |
8 | get_started.md
9 |
10 | .. toctree::
11 | :maxdepth: 2
12 | :caption: 用户指南
13 |
14 | user_guides/index.rst
15 |
16 | .. toctree::
17 | :maxdepth: 2
18 | :caption: 进阶指南
19 |
20 | advanced_guides/index.rst
21 |
22 | .. toctree::
23 | :maxdepth: 1
24 | :caption: 迁移指引
25 |
26 | migration/index.rst
27 |
28 | .. toctree::
29 | :caption: 接口文档(英文)
30 |
31 | api.rst
32 |
33 | .. toctree::
34 | :maxdepth: 1
35 | :caption: 模型库
36 |
37 | model_zoo.md
38 | modelzoo_statistics.md
39 |
40 | .. toctree::
41 | :maxdepth: 2
42 | :caption: 说明
43 |
44 | changelog.md
45 | faq.md
46 |
47 | .. toctree::
48 | :caption: 语言切换
49 |
50 | switch_language.md
51 |
52 |
53 | Indices and tables
54 | ==================
55 |
56 | * :ref:`genindex`
57 | * :ref:`search`
58 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/migration/index.rst:
--------------------------------------------------------------------------------
1 | 迁移
2 | ***************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | interface.md
8 | package.md
9 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | import functools as func
4 | import glob
5 | import os.path as osp
6 | import re
7 |
8 | import numpy as np
9 |
10 | url_prefix = 'https://github.com/open-mmlab/mmsegmentation/blob/master/'
11 |
12 | files = sorted(glob.glob('../../configs/*/README.md'))
13 |
14 | stats = []
15 | titles = []
16 | num_ckpts = 0
17 |
18 | for f in files:
19 | url = osp.dirname(f.replace('../../', url_prefix))
20 |
21 | with open(f) as content_file:
22 | content = content_file.read()
23 |
24 | title = content.split('\n')[0].replace('#', '').strip()
25 | ckpts = {
26 | x.lower().strip()
27 | for x in re.findall(r'https?://download.*\.pth', content)
28 | if 'mmsegmentation' in x
29 | }
30 | if len(ckpts) == 0:
31 | continue
32 |
33 | _papertype = [
34 | x for x in re.findall(r'', content)
35 | ]
36 | assert len(_papertype) > 0
37 | papertype = _papertype[0]
38 |
39 | paper = {(papertype, title)}
40 |
41 | titles.append(title)
42 | num_ckpts += len(ckpts)
43 | statsmsg = f"""
44 | \t* [{papertype}] [{title}]({url}) ({len(ckpts)} ckpts)
45 | """
46 | stats.append((paper, ckpts, statsmsg))
47 |
48 | allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _ in stats])
49 | msglist = '\n'.join(x for _, _, x in stats)
50 |
51 | papertypes, papercounts = np.unique([t for t, _ in allpapers],
52 | return_counts=True)
53 | countstr = '\n'.join(
54 | [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)])
55 |
56 | modelzoo = f"""
57 | # 模型库统计数据
58 |
59 | * 论文数量: {len(set(titles))}
60 | {countstr}
61 |
62 | * 模型数量: {num_ckpts}
63 | {msglist}
64 | """
65 |
66 | with open('modelzoo_statistics.md', 'w') as f:
67 | f.write(modelzoo)
68 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/switch_language.md:
--------------------------------------------------------------------------------
1 | ## English
2 |
3 | ## 简体中文
4 |
--------------------------------------------------------------------------------
/Segmentation/docs/zh_cn/user_guides/index.rst:
--------------------------------------------------------------------------------
1 | 训练 & 测试
2 | **************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | 1_config.md
8 | 2_dataset_prepare.md
9 | 3_inference.md
10 | 4_train_test.md
11 |
12 | 实用工具
13 | *************
14 |
15 | .. toctree::
16 | :maxdepth: 2
17 |
18 | visualization.md
19 | useful_tools.md
20 | deployment.md
21 | visualization_feature_map.md
22 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 | import mmengine
4 | from mmengine.utils import digit_version
5 |
6 | from .version import __version__, version_info
7 |
8 | mmcv_minimum_version = '2.0.0rc4'
9 | mmcv_maximum_version = '2.1.1'
10 | mmcv_version = digit_version(mmcv.__version__)
11 |
12 | mmengine_minimum_version = '0.7.1'
13 | mmengine_maximum_version = '1.0.0'
14 | mmengine_version = digit_version(mmengine.__version__)
15 |
16 | assert (mmcv_version >= digit_version(mmcv_minimum_version)
17 | and mmcv_version < digit_version(mmcv_maximum_version)), \
18 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \
19 | f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
20 |
21 | assert (mmengine_version >= digit_version(mmengine_minimum_version)
22 | and mmengine_version < digit_version(mmengine_maximum_version)), \
23 | f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
24 | f'Please install mmengine>={mmengine_minimum_version}, ' \
25 | f'<{mmengine_maximum_version}.'
26 |
27 | __all__ = ['__version__', 'version_info', 'digit_version']
28 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # from .data_preprocessors import * # noqa: F401,F403
3 | from .dense_heads import * # noqa: F401,F403
4 | from .layers import * # noqa: F401,F403
5 | from .losses import * # noqa: F401,F403
6 | from .task_modules import *
7 | from .test_time_augs import *
8 | from .tracking_heads import * # noqa: F401,F403
9 | # from .vis import * # noqa: F401,F403
10 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/dense_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .anchor_free_head import AnchorFreeHead
3 | from .maskformer_head import MaskFormerHead
4 |
5 |
6 |
7 | __all__ = [
8 | 'AnchorFreeHead',
9 | 'MaskFormerHead',
10 | ]
11 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/activations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmengine.utils import digit_version
5 |
6 | from mmdet.registry import MODELS
7 |
8 | if digit_version(torch.__version__) >= digit_version('1.7.0'):
9 | from torch.nn import SiLU
10 | else:
11 |
12 | class SiLU(nn.Module):
13 | """Sigmoid Weighted Liner Unit."""
14 |
15 | def __init__(self, inplace=True):
16 | super().__init__()
17 |
18 | def forward(self, inputs) -> torch.Tensor:
19 | return inputs * torch.sigmoid(inputs)
20 |
21 |
22 | MODELS.register_module(module=SiLU, name='SiLU')
23 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/brick_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version
6 |
7 | if torch.__version__ == 'parrots':
8 | TORCH_VERSION = torch.__version__
9 | else:
10 | # torch.__version__ could be 1.3.1+cu92, we only need the first two
11 | # for comparison
12 | TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
13 |
14 |
15 | def adaptive_avg_pool2d(input, output_size):
16 | """Handle empty batch dimension to adaptive_avg_pool2d.
17 |
18 | Args:
19 | input (tensor): 4D tensor.
20 | output_size (int, tuple[int,int]): the target output size.
21 | """
22 | if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
23 | if isinstance(output_size, int):
24 | output_size = [output_size, output_size]
25 | output_size = [*input.shape[:2], *output_size]
26 | empty = NewEmptyTensorOp.apply(input, output_size)
27 | return empty
28 | else:
29 | return F.adaptive_avg_pool2d(input, output_size)
30 |
31 |
32 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
33 | """Handle empty batch dimension to AdaptiveAvgPool2d."""
34 |
35 | def forward(self, x):
36 | # PyTorch 1.9 does not support empty tensor inference yet
37 | if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
38 | output_size = self.output_size
39 | if isinstance(output_size, int):
40 | output_size = [output_size, output_size]
41 | else:
42 | output_size = [
43 | v if v is not None else d
44 | for v, d in zip(output_size,
45 | x.size()[-2:])
46 | ]
47 | output_size = [*x.shape[:2], *output_size]
48 | empty = NewEmptyTensorOp.apply(x, output_size)
49 | return empty
50 |
51 | return super().forward(x)
52 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/mmcv_spike/__init__.py:
--------------------------------------------------------------------------------
1 | from .multi_scale_deform_attn import SpikeMultiScaleDeformableAttention
2 | from .transformer import MultiheadAttention, FFN, MSDA_FFN, MS_MLP
3 | from .spikeformer import MSTransformerDecoder, CrossAttention, SelfAttention, MLP
4 | from .BASE_Transformer import Transformer
5 | # NOTE: Move the mmcv function here to change the basic version
6 | __all__ = [
7 | "SpikeMultiScaleDeformableAttention", "MultiheadAttention",
8 | "MSTransformerDecoder", "CrossAttention", "SelfAttention", "MLP", "MSDA_FFN",
9 | "Transformer", "MS_MLP"
10 | ]
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/mmcv_spike/scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class Scale(nn.Module):
7 | """A learnable scale parameter.
8 |
9 | This layer scales the input by a learnable factor. It multiplies a
10 | learnable scale parameter of shape (1,) with input of any shape.
11 |
12 | Args:
13 | scale (float): Initial value of scale factor. Default: 1.0
14 | """
15 |
16 | def __init__(self, scale: float = 1.0):
17 | super().__init__()
18 | self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
19 |
20 | def forward(self, x: torch.Tensor) -> torch.Tensor:
21 | return x * self.scale
22 |
23 |
24 | class LayerScale(nn.Module):
25 | """LayerScale layer.
26 |
27 | Args:
28 | dim (int): Dimension of input features.
29 | inplace (bool): Whether performs operation in-place.
30 | Default: `False`.
31 | data_format (str): The input data format, could be 'channels_last'
32 | or 'channels_first', representing (B, C, H, W) and
33 | (B, N, C) format data respectively. Default: 'channels_last'.
34 | scale (float): Initial value of scale factor. Default: 1.0
35 | """
36 |
37 | def __init__(self,
38 | dim: int,
39 | inplace: bool = False,
40 | data_format: str = 'channels_last',
41 | scale: float = 1e-5):
42 | super().__init__()
43 | assert data_format in ('channels_last', 'channels_first'), \
44 | "'data_format' could only be channels_last or channels_first."
45 | self.inplace = inplace
46 | self.data_format = data_format
47 | self.weight = nn.Parameter(torch.ones(dim) * scale)
48 |
49 | def forward(self, x) -> torch.Tensor:
50 | if self.data_format == 'channels_first':
51 | shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
52 | else:
53 | shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
54 | if self.inplace:
55 | return x.mul_(self.weight.view(*shape))
56 | else:
57 | return x * self.weight.view(*shape)
58 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules.dcnv3 import DCNv3_pytorch
2 | __all__ = [
3 | "DCNv3_pytorch"
4 | ]
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternImage
3 | # Copyright (c) 2022 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | from .dcnv3_func import dcnv3_core_pytorch
8 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # --------------------------------------------------------
3 | # InternImage
4 | # Copyright (c) 2022 OpenGVLab
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # --------------------------------------------------------
7 |
8 | python setup.py build install
9 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternImage
3 | # Copyright (c) 2022 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | from .dcnv3 import DCNv3_pytorch
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/cpu/dcnv3_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 |
12 | #include
13 |
14 | #include
15 | #include
16 |
17 | at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
18 | const at::Tensor &mask, const int kernel_h,
19 | const int kernel_w, const int stride_h,
20 | const int stride_w, const int pad_h,
21 | const int pad_w, const int dilation_h,
22 | const int dilation_w, const int group,
23 | const int group_channels, const float offset_scale,
24 | const int im2col_step) {
25 | AT_ERROR("Not implement on cpu");
26 | }
27 |
28 | std::vector
29 | dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
30 | const at::Tensor &mask, const int kernel_h,
31 | const int kernel_w, const int stride_h, const int stride_w,
32 | const int pad_h, const int pad_w, const int dilation_h,
33 | const int dilation_w, const int group,
34 | const int group_channels, const float offset_scale,
35 | const at::Tensor &grad_output, const int im2col_step) {
36 | AT_ERROR("Not implement on cpu");
37 | }
38 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/cpu/dcnv3_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 |
12 | #pragma once
13 | #include
14 |
15 | at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
16 | const at::Tensor &mask, const int kernel_h,
17 | const int kernel_w, const int stride_h,
18 | const int stride_w, const int pad_h,
19 | const int pad_w, const int dilation_h,
20 | const int dilation_w, const int group,
21 | const int group_channels, const float offset_scale,
22 | const int im2col_step);
23 |
24 | std::vector
25 | dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
26 | const at::Tensor &mask, const int kernel_h,
27 | const int kernel_w, const int stride_h, const int stride_w,
28 | const int pad_h, const int pad_w, const int dilation_h,
29 | const int dilation_w, const int group,
30 | const int group_channels, const float offset_scale,
31 | const at::Tensor &grad_output, const int im2col_step);
32 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/cuda/dcnv3_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 |
12 | #pragma once
13 | #include
14 |
15 | at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
16 | const at::Tensor &mask, const int kernel_h,
17 | const int kernel_w, const int stride_h,
18 | const int stride_w, const int pad_h,
19 | const int pad_w, const int dilation_h,
20 | const int dilation_w, const int group,
21 | const int group_channels,
22 | const float offset_scale, const int im2col_step);
23 |
24 | std::vector
25 | dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
26 | const at::Tensor &mask, const int kernel_h,
27 | const int kernel_w, const int stride_h, const int stride_w,
28 | const int pad_h, const int pad_w, const int dilation_h,
29 | const int dilation_w, const int group,
30 | const int group_channels, const float offset_scale,
31 | const at::Tensor &grad_output, const int im2col_step);
32 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/layers/transformer/ops_dcnv3/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * InternImage
4 | * Copyright (c) 2022 OpenGVLab
5 | * Licensed under The MIT License [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from
8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
9 | **************************************************************************************************
10 | */
11 |
12 | #include "dcnv3.h"
13 |
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 | m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward");
16 | m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward");
17 | }
18 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .accuracy import Accuracy, accuracy
3 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
4 | cross_entropy, mask_cross_entropy)
5 | from .dice_loss import DiceLoss
6 | from .focal_loss import FocalLoss, sigmoid_focal_loss
7 | from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, EIoULoss, GIoULoss,
8 | IoULoss, SIoULoss, bounded_iou_loss, iou_loss)
9 | from .mse_loss import MSELoss, mse_loss
10 | from .multipos_cross_entropy_loss import MultiPosCrossEntropyLoss
11 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss
12 |
13 |
14 | __all__ = [
15 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
16 | 'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
17 | 'FocalLoss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
18 | 'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss',
19 | 'EIoULoss', 'SIoULoss', 'reduce_loss',
20 | 'weight_reduce_loss', 'weighted_loss',
21 | 'DiceLoss', 'MultiPosCrossEntropyLoss',
22 | ]
23 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
3 | from .builder import (ANCHOR_GENERATORS, BBOX_ASSIGNERS, BBOX_CODERS,
4 | BBOX_SAMPLERS, IOU_CALCULATORS, MATCH_COSTS,
5 | PRIOR_GENERATORS, build_anchor_generator, build_assigner,
6 | build_bbox_coder, build_iou_calculator, build_match_cost,
7 | build_prior_generator, build_sampler)
8 |
9 | from .prior_generators import * # noqa: F401,F403
10 | from .assigners import *
11 | from .samplers import *
12 |
13 |
14 | __all__ = [
15 | 'ANCHOR_GENERATORS', 'PRIOR_GENERATORS', 'BBOX_ASSIGNERS', 'BBOX_SAMPLERS',
16 | 'MATCH_COSTS', 'BBOX_CODERS', 'IOU_CALCULATORS', 'build_anchor_generator',
17 | 'build_prior_generator', 'build_assigner', 'build_sampler',
18 | 'build_iou_calculator', 'build_match_cost', 'build_bbox_coder'
19 | ]
20 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/assigners/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .approx_max_iou_assigner import ApproxMaxIoUAssigner
3 | from .assign_result import AssignResult
4 | from .atss_assigner import ATSSAssigner
5 | from .base_assigner import BaseAssigner
6 | from .center_region_assigner import CenterRegionAssigner
7 | from .dynamic_soft_label_assigner import DynamicSoftLabelAssigner
8 | from .grid_assigner import GridAssigner
9 | from .hungarian_assigner import HungarianAssigner
10 | from .iou2d_calculator import BboxOverlaps2D
11 | from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost,
12 | DiceCost, FocalLossCost, IoUCost)
13 | from .max_iou_assigner import MaxIoUAssigner
14 | from .multi_instance_assigner import MultiInstanceAssigner
15 | from .point_assigner import PointAssigner
16 | from .region_assigner import RegionAssigner
17 | from .sim_ota_assigner import SimOTAAssigner
18 | from .task_aligned_assigner import TaskAlignedAssigner
19 | from .uniform_assigner import UniformAssigner
20 |
21 | __all__ = [
22 | 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
23 | 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
24 | 'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
25 | 'TaskAlignedAssigner', 'BBoxL1Cost', 'ClassificationCost',
26 | 'CrossEntropyLossCost', 'DiceCost', 'FocalLossCost', 'IoUCost',
27 | 'BboxOverlaps2D', 'DynamicSoftLabelAssigner', 'MultiInstanceAssigner'
28 | ]
29 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/assigners/base_assigner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Optional
4 |
5 | from mmengine.structures import InstanceData
6 |
7 |
8 | class BaseAssigner(metaclass=ABCMeta):
9 | """Base assigner that assigns boxes to ground truth boxes."""
10 |
11 | @abstractmethod
12 | def assign(self,
13 | pred_instances: InstanceData,
14 | gt_instances: InstanceData,
15 | gt_instances_ignore: Optional[InstanceData] = None,
16 | **kwargs):
17 | """Assign boxes to either a ground truth boxes or a negative boxes."""
18 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/prior_generators/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
3 | SSDAnchorGenerator, YOLOAnchorGenerator)
4 | from .point_generator import MlvlPointGenerator, PointGenerator
5 | from .utils import anchor_inside_flags, calc_region
6 |
7 | __all__ = [
8 | 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
9 | 'PointGenerator', 'calc_region', 'YOLOAnchorGenerator',
10 | 'MlvlPointGenerator', 'SSDAnchorGenerator'
11 | ]
12 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_sampler import BaseSampler
3 | from .combined_sampler import CombinedSampler
4 | from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
5 | from .iou_balanced_neg_sampler import IoUBalancedNegSampler
6 | from .mask_pseudo_sampler import MaskPseudoSampler
7 | from .mask_sampling_result import MaskSamplingResult
8 | from .multi_instance_random_sampler import MultiInsRandomSampler
9 | from .multi_instance_sampling_result import MultiInstanceSamplingResult
10 | from .ohem_sampler import OHEMSampler
11 | from .pseudo_sampler import PseudoSampler
12 | from .random_sampler import RandomSampler
13 | from .sampling_result import SamplingResult
14 | from .score_hlr_sampler import ScoreHLRSampler
15 |
16 | __all__ = [
17 | 'BaseSampler', 'PseudoSampler', 'RandomSampler',
18 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
19 | 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'MaskPseudoSampler',
20 | 'MaskSamplingResult', 'MultiInstanceSamplingResult',
21 | 'MultiInsRandomSampler'
22 | ]
23 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/task_modules/samplers/combined_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmdet.registry import TASK_UTILS
3 | from .base_sampler import BaseSampler
4 |
5 |
6 | @TASK_UTILS.register_module()
7 | class CombinedSampler(BaseSampler):
8 | """A sampler that combines positive sampler and negative sampler."""
9 |
10 | def __init__(self, pos_sampler, neg_sampler, **kwargs):
11 | super(CombinedSampler, self).__init__(**kwargs)
12 | self.pos_sampler = TASK_UTILS.build(pos_sampler, default_args=kwargs)
13 | self.neg_sampler = TASK_UTILS.build(neg_sampler, default_args=kwargs)
14 |
15 | def _sample_pos(self, **kwargs):
16 | """Sample positive samples."""
17 | raise NotImplementedError
18 |
19 | def _sample_neg(self, **kwargs):
20 | """Sample negative samples."""
21 | raise NotImplementedError
22 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/test_time_augs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .det_tta import DetTTAModel
3 | from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
4 | merge_aug_proposals, merge_aug_results,
5 | merge_aug_scores)
6 |
7 | __all__ = [
8 | 'merge_aug_bboxes', 'merge_aug_masks', 'merge_aug_proposals',
9 | 'merge_aug_scores', 'merge_aug_results', 'DetTTAModel'
10 | ]
11 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/tracking_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .mask2former_track_head import Mask2FormerTrackHead
3 | __all__ = [
4 | 'Mask2FormerTrackHead',
5 | ]
6 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/utils/Qtrick.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | quant = 4
5 | T = quant
6 |
7 |
8 | class Quant(torch.autograd.Function):
9 | @staticmethod
10 | @torch.cuda.amp.custom_fwd
11 | def forward(ctx, i, min_value=0, max_value=quant):
12 | ctx.min = min_value
13 | ctx.max = max_value
14 | ctx.save_for_backward(i)
15 | return torch.round(torch.clamp(i, min=min_value, max=max_value))
16 |
17 | @staticmethod
18 | @torch.cuda.amp.custom_fwd
19 | def backward(ctx, grad_output):
20 | grad_input = grad_output.clone()
21 | i, = ctx.saved_tensors
22 | grad_input[i < ctx.min] = 0
23 | grad_input[i > ctx.max] = 0
24 | return grad_input, None, None
25 |
26 |
27 | class MultiSpike_norm4(nn.Module):
28 | def __init__(
29 | self,
30 | Vth=1.0,
31 | T=T, # 在T上进行Norm
32 | ):
33 | super().__init__()
34 | self.spike = Quant()
35 | self.Vth = Vth
36 | self.T = T
37 |
38 | def forward(self, x):
39 | return self.spike.apply(x) / self.T
40 |
41 |
42 | #
43 | class MultiSpike_4(nn.Module):
44 | def __init__(
45 | self,
46 | T=T, # 在T上进行Norm
47 | ):
48 | super().__init__()
49 | self.spike = Quant()
50 | self.T = T
51 |
52 | def forward(self, x):
53 | return self.spike.apply(x)
54 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .gaussian_target import (gather_feat, gaussian_radius,
3 | gen_gaussian_target, get_local_maximum,
4 | get_topk_from_heatmap, transpose_and_gather_feat)
5 | from .image import imrenormalize
6 | from .make_divisible import make_divisible
7 | from .misc import (aligned_bilinear, center_of_mass, empty_instances,
8 | filter_gt_instances, filter_scores_and_topk, flip_tensor,
9 | generate_coordinate, images_to_levels, interpolate_as,
10 | levels_to_images, mask2ndarray, multi_apply,
11 | relative_coordinate_maps, rename_loss_dict,
12 | reweight_loss_dict, samplelist_boxtype2tensor,
13 | select_single_mlvl, sigmoid_geometric_mean,
14 | unfold_wo_center, unmap, unpack_gt_instances)
15 | from .panoptic_gt_processing import preprocess_panoptic_gt
16 | from .point_sample import (get_uncertain_point_coords_with_randomness,
17 | get_uncertainty)
18 | from .vlfuse_helper import BertEncoderLayer, VLFuse, permute_and_flatten
19 | from .Qtrick import MultiSpike_norm4
20 |
21 | __all__ = [
22 | 'gaussian_radius', 'gen_gaussian_target', 'make_divisible',
23 | 'get_local_maximum', 'get_topk_from_heatmap', 'transpose_and_gather_feat',
24 | 'interpolate_as', 'sigmoid_geometric_mean', 'gather_feat',
25 | 'preprocess_panoptic_gt', 'get_uncertain_point_coords_with_randomness',
26 | 'get_uncertainty', 'unpack_gt_instances', 'empty_instances',
27 | 'center_of_mass', 'filter_scores_and_topk', 'flip_tensor',
28 | 'generate_coordinate', 'levels_to_images', 'mask2ndarray', 'multi_apply',
29 | 'select_single_mlvl', 'unmap', 'images_to_levels',
30 | 'samplelist_boxtype2tensor', 'filter_gt_instances', 'rename_loss_dict',
31 | 'reweight_loss_dict', 'relative_coordinate_maps', 'aligned_bilinear',
32 | 'unfold_wo_center', 'imrenormalize', 'VLFuse', 'permute_and_flatten',
33 | 'BertEncoderLayer', "MultiSpike_norm4"
34 | ]
35 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/utils/image.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Union
3 |
4 | import mmcv
5 | import numpy as np
6 | import torch
7 | from torch import Tensor
8 |
9 |
10 | def imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict,
11 | new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]:
12 | """Re-normalize the image.
13 |
14 | Args:
15 | img (Tensor | ndarray): Input image. If the input is a Tensor, the
16 | shape is (1, C, H, W). If the input is a ndarray, the shape
17 | is (H, W, C).
18 | img_norm_cfg (dict): Original configuration for the normalization.
19 | new_img_norm_cfg (dict): New configuration for the normalization.
20 |
21 | Returns:
22 | Tensor | ndarray: Output image with the same type and shape of
23 | the input.
24 | """
25 | if isinstance(img, torch.Tensor):
26 | assert img.ndim == 4 and img.shape[0] == 1
27 | new_img = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
28 | new_img = _imrenormalize(new_img, img_norm_cfg, new_img_norm_cfg)
29 | new_img = new_img.transpose(2, 0, 1)[None]
30 | return torch.from_numpy(new_img).to(img)
31 | else:
32 | return _imrenormalize(img, img_norm_cfg, new_img_norm_cfg)
33 |
34 |
def _imrenormalize(img: Union[Tensor, np.ndarray], img_norm_cfg: dict,
                   new_img_norm_cfg: dict) -> Union[Tensor, np.ndarray]:
    """Denormalize ``img`` with ``img_norm_cfg``, then normalize it with
    ``new_img_norm_cfg``.

    Neither config dict is modified. ``mean``/``std`` entries that are not
    already ndarrays are converted to ndarrays of ``img.dtype`` first.
    """

    def _with_array_stats(cfg: dict) -> dict:
        # Shallow copy of ``cfg`` with mean/std coerced to ndarrays.
        out = cfg.copy()
        for key in ('mean', 'std'):
            if key in out and not isinstance(out[key], np.ndarray):
                out[key] = np.array(out[key], dtype=img.dtype)
        return out

    denorm_cfg = _with_array_stats(img_norm_cfg)
    norm_cfg = _with_array_stats(new_img_norm_cfg)
    # Denormalizing runs the colour swap in the opposite direction.
    # NOTE(review): mmcv.imdenormalize's colour flag is named ``to_bgr`` in
    # the mmcv versions I know of; confirm ``rgb_to_bgr`` is accepted by the
    # mmcv version in use.
    if 'bgr_to_rgb' in denorm_cfg:
        denorm_cfg['rgb_to_bgr'] = denorm_cfg.pop('bgr_to_rgb')
    img = mmcv.imdenormalize(img, **denorm_cfg)
    return mmcv.imnormalize(img, **norm_cfg)
53 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/models/utils/make_divisible.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Round a channel count to the nearest multiple of ``divisor``.

    Ported from the TensorFlow MobileNet reference implementation:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa

    Args:
        value (int): Channel count to round.
        divisor (int): Multiple that the result must be divisible by.
        min_value (int, optional): Lower bound on the result. ``None``
            (the default) uses ``divisor`` as the bound.
        min_ratio (float): If rounding would leave the result below
            ``min_ratio * value``, one extra ``divisor`` is added.
            Default: 0.9.

    Returns:
        int: The rounded channel count.
    """
    floor = divisor if min_value is None else min_value
    rounded = int(value + divisor / 2) // divisor * divisor
    result = max(floor, rounded)
    # Rounding down must not lose more than (1 - min_ratio) of ``value``.
    return result + divisor if result < min_ratio * value else result
29 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/structures/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .det_data_sample import DetDataSample, OptSampleList, SampleList
3 | from .reid_data_sample import ReIDDataSample
4 | from .track_data_sample import (OptTrackSampleList, TrackDataSample,
5 | TrackSampleList)
6 |
__all__ = [
    # detection samples
    'DetDataSample',
    'SampleList',
    'OptSampleList',
    # tracking samples
    'TrackDataSample',
    'TrackSampleList',
    'OptTrackSampleList',
    # re-identification samples
    'ReIDDataSample',
]
11 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/structures/bbox/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_boxes import BaseBoxes
3 | from .bbox_overlaps import bbox_overlaps
4 | from .box_type import (autocast_box_type, convert_box_type, get_box_type,
5 | register_box, register_box_converter)
6 | from .horizontal_boxes import HorizontalBoxes
7 | from .transforms import bbox_cxcyah_to_xyxy # noqa: E501
8 | from .transforms import (bbox2corner, bbox2distance, bbox2result, bbox2roi,
9 | bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
10 | bbox_mapping_back, bbox_project, bbox_rescale,
11 | bbox_xyxy_to_cxcyah, bbox_xyxy_to_cxcywh, cat_boxes,
12 | corner2bbox, distance2bbox, empty_box_as,
13 | find_inside_bboxes, get_box_tensor, get_box_wh,
14 | roi2bbox, scale_boxes, stack_boxes)
15 |
# Public box helpers and box-type classes re-exported by this package.
__all__ = [
    'bbox_overlaps',
    'bbox_flip',
    'bbox_mapping',
    'bbox_mapping_back',
    'bbox2roi',
    'roi2bbox',
    'bbox2result',
    'distance2bbox',
    'bbox2distance',
    'bbox_rescale',
    'bbox_cxcywh_to_xyxy',
    'bbox_xyxy_to_cxcywh',
    'find_inside_bboxes',
    'bbox2corner',
    'corner2bbox',
    'bbox_project',
    'BaseBoxes',
    'convert_box_type',
    'get_box_type',
    'register_box',
    'register_box_converter',
    'HorizontalBoxes',
    'autocast_box_type',
    'cat_boxes',
    'stack_boxes',
    'scale_boxes',
    'get_box_wh',
    'get_box_tensor',
    'empty_box_as',
    'bbox_xyxy_to_cxcyah',
    'bbox_cxcyah_to_xyxy',
]
26 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/structures/mask/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .mask_target import mask_target
3 | from .structures import (BaseInstanceMasks, BitmapMasks, PolygonMasks,
4 | bitmap_to_polygon, polygon_to_bitmap)
5 | from .utils import encode_mask_results, mask2bbox, split_combined_polys
6 |
# Mask structures and mask helpers re-exported by this package.
__all__ = [
    'split_combined_polys',
    'mask_target',
    'BaseInstanceMasks',
    'BitmapMasks',
    'PolygonMasks',
    'encode_mask_results',
    'mask2bbox',
    'polygon_to_bitmap',
    'bitmap_to_polygon',
]
12 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .collect_env import collect_env
3 | from .compat_config import compat_cfg
4 | from .dist_utils import (all_reduce_dict, allreduce_grads, reduce_mean,
5 | sync_random_seed)
6 | from .logger import get_caller_name, log_img_scale
7 | from .memory import AvoidCUDAOOM, AvoidOOM
8 | from .misc import (find_latest_checkpoint, get_test_pipeline_cfg,
9 | update_data_root)
10 | from .mot_error_visualize import imshow_mot_errors
11 | from .replace_cfg_vals import replace_cfg_vals
12 | from .setup_env import (register_all_modules, setup_cache_size_limit_of_dynamo,
13 | setup_multi_processes)
14 | from .split_batch import split_batch
15 | from .typing_utils import (ConfigType, InstanceList, MultiConfig,
16 | OptConfigType, OptInstanceList, OptMultiConfig,
17 | OptPixelList, PixelList, RangeType)
18 |
# Public utilities re-exported by ``mmdet.utils``.
__all__ = [
    'collect_env',
    'find_latest_checkpoint',
    'update_data_root',
    'setup_multi_processes',
    'get_caller_name',
    'log_img_scale',
    'compat_cfg',
    'split_batch',
    'register_all_modules',
    'replace_cfg_vals',
    'AvoidOOM',
    'AvoidCUDAOOM',
    'all_reduce_dict',
    'allreduce_grads',
    'reduce_mean',
    'sync_random_seed',
    'ConfigType',
    'InstanceList',
    'MultiConfig',
    'OptConfigType',
    'OptInstanceList',
    'OptMultiConfig',
    'OptPixelList',
    'PixelList',
    'RangeType',
    'get_test_pipeline_cfg',
    'setup_cache_size_limit_of_dynamo',
    'imshow_mot_errors',
]
29 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmengine.utils import get_git_hash
3 | from mmengine.utils.dl_utils import collect_env as collect_base_env
4 |
5 | import mmdet
6 |
7 |
def collect_env():
    """Return a dict describing the running environment.

    Extends mmengine's base environment report with an ``MMDetection``
    entry of the form ``<version>+<first 7 chars of the git hash>``.
    """
    info = collect_base_env()
    short_hash = get_git_hash()[:7]
    info['MMDetection'] = mmdet.__version__ + '+' + short_hash
    return info


if __name__ == '__main__':
    for key, value in collect_env().items():
        print(f'{key}: {value}')
18 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import inspect
3 |
4 | from mmengine.logging import print_log
5 |
6 |
def get_caller_name():
    """Return the name of the caller of this function's caller.

    For example, when ``log_img_scale`` calls ``get_caller_name``, the name
    of whatever called ``log_img_scale`` is returned. If that frame has a
    local named ``self``, the result is ``'<ClassName>.<method>'``;
    otherwise it is just the function name.

    Returns:
        str: The caller name as described above.
    """
    # Walk exactly two frames up: get_caller_name -> its caller -> target.
    # ``inspect.stack()`` would build FrameInfo records (including
    # source-context lines read from disk) for every frame on the stack,
    # which is far more work than reading a single frame.
    frame = inspect.currentframe()
    try:
        caller_frame = frame.f_back.f_back
        caller_method = caller_frame.f_code.co_name
        try:
            caller_class = caller_frame.f_locals['self'].__class__.__name__
        except KeyError:  # caller is a plain function
            return caller_method
        return f'{caller_class}.{caller_method}'
    finally:
        # Holding our own frame in a local creates a reference cycle;
        # drop it promptly as recommended by the inspect docs.
        del frame
18 |
19 |
def log_img_scale(img_scale, shape_order='hw', skip_square=False):
    """Log an image size together with the name of the calling method.

    Args:
        img_scale (tuple): Two-element image size to log.
        shape_order (str, optional): Layout of ``img_scale``: ``'hw'`` for
            (height, width) or ``'wh'`` for (width, height).
            Defaults to 'hw'.
        skip_square (bool, optional): If True, nothing is logged when the
            height equals the width. Defaults to False.

    Returns:
        bool: True if a message was logged, False if it was skipped.

    Raises:
        ValueError: If ``shape_order`` is neither 'hw' nor 'wh'.
    """
    if shape_order not in ('hw', 'wh'):
        raise ValueError(f'Invalid shape_order {shape_order}.')
    first, second = img_scale
    if shape_order == 'hw':
        height, width = first, second
    else:
        height, width = second, first

    if skip_square and height == width:
        return False

    # get_caller_name looks a fixed number of frames up the stack, so it
    # must be called directly from this function (not from a helper).
    caller = get_caller_name()
    print_log(
        f'image shape: height={height}, width={width} in {caller}',
        logger='current')
    return True
50 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/utils/profiling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import contextlib
3 | import sys
4 | import time
5 |
6 | import torch
7 |
if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print the CPU and GPU time spent inside the ``with`` body.

        Meant as a temporary profiling aid for locating code that could
        benefit from asynchronous execution. It is a no-op when ``enabled``
        is falsy or CUDA is unavailable.

        Args:
            trace_name (str): Prefix of the printed message.
            name (str): Label of the profiled region.
            enabled (bool): Falsy turns the profiler into a no-op.
            stream (torch.cuda.Stream, optional): Stream on which the start
                event is recorded. Defaults to the current CUDA stream.
            end_stream (torch.cuda.Stream, optional): Stream on which the
                end event is recorded. Defaults to ``stream``.
        """
        active = enabled and torch.cuda.is_available()
        if not active:
            yield
            return

        stream = stream or torch.cuda.current_stream()
        end_stream = end_stream or stream
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        stream.record_event(start_evt)
        try:
            cpu_start = time.monotonic()
            yield
        finally:
            cpu_end = time.monotonic()
            end_stream.record_event(end_evt)
            # Wait for the end event so elapsed_time() is valid.
            end_evt.synchronize()
            cpu_ms = (cpu_end - cpu_start) * 1000
            gpu_ms = start_evt.elapsed_time(end_evt)
            # end_stream is printed as a second positional argument.
            print(
                f'{trace_name} {name} cpu_time {cpu_ms:.2f} ms '
                f'gpu_time {gpu_ms:.2f} ms stream {stream}', end_stream)
41 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/utils/split_batch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 |
def split_batch(img, img_metas, kwargs):
    """Split a data batch into per-tag groups.

    Code is modified from an external implementation (the reference link
    is missing in this copy).

    The tag of sample ``i`` is read from ``img_metas[i]['tag']``. ``img``,
    ``img_metas`` and every value in ``kwargs`` are then indexed along their
    first dimension and partitioned by tag.

    Args:
        img (Tensor): of shape (N, C, H, W) encoding input images.
            Typically these should be mean centered and std scaled.
        img_metas (list[dict]): List of image info dicts. Each dict must
            contain a 'tag' key; it usually also has 'img_shape',
            'scale_factor', 'flip', and may contain 'filename', 'ori_shape',
            'pad_shape' and 'img_norm_cfg'.
        kwargs (dict): Extra per-sample fields, each indexable along the
            batch dimension. This dict is not modified.

    Returns:
        data_groups (dict): Maps each tag (in order of first appearance,
        e.g. 'sup', 'unsup_teacher', 'unsup_student') to a dict holding
        ``img``, ``img_metas`` and the ``kwargs`` keys restricted to that
        tag's samples. Tensor fields are re-stacked into a tensor; all
        other fields become lists.
    """

    def fuse_list(obj_list, obj):
        # Re-stack tensors; keep everything else as a plain list.
        if isinstance(obj, torch.Tensor):
            return torch.stack(obj_list)
        return obj_list

    def select_group(data_batch, current_tag):
        # Keep only the samples whose tag matches; drop the tag field.
        group_flag = [tag == current_tag for tag in data_batch['tag']]
        return {
            k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v)
            for k, v in data_batch.items() if k != 'tag'
        }

    # Work on a copy so the caller's ``kwargs`` dict is left untouched.
    data_batch = dict(kwargs)
    data_batch.update({'img': img, 'img_metas': img_metas})
    data_batch['tag'] = [meta['tag'] for meta in img_metas]
    # dict.fromkeys de-duplicates while keeping first-appearance order; a
    # set's iteration order depends on string hashing and varies per run.
    tags = dict.fromkeys(data_batch['tag'])
    return {tag: select_group(data_batch, tag) for tag in tags}
46 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/utils/typing_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """Collecting some commonly used type hint in mmdetection."""
3 | from typing import List, Optional, Sequence, Tuple, Union
4 |
5 | from mmengine.config import ConfigDict
6 | from mmengine.structures import InstanceData, PixelData
7 |
# TODO: Need to avoid circular import with assigner and sampler

# Config hints: a single config or a list of configs, each also in an
# optional (may be None) variant.
ConfigType = Union[ConfigDict, dict]
OptConfigType = Union[ConfigType, None]
MultiConfig = Union[ConfigType, List[ConfigType]]
OptMultiConfig = Union[MultiConfig, None]

# Lists of mmengine data containers, plus optional variants.
InstanceList = List[InstanceData]
OptInstanceList = Union[InstanceList, None]
PixelList = List[PixelData]
OptPixelList = Union[PixelList, None]

# A sequence of integer pairs.
RangeType = Sequence[Tuple[int, int]]
23 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/utils/util_random.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """Helpers for random number generators."""
3 | import numpy as np
4 |
5 |
def ensure_rng(rng=None):
    """Coerce input into a NumPy random number generator.

    - ``None``: NumPy's global random state is returned.
    - An integer (Python ``int`` or a NumPy integer scalar such as
      ``np.int64``): used as the seed of a new ``numpy.random.RandomState``.
    - Anything else is returned as-is.

    Adapted from [1]_.

    Args:
        rng (int | numpy.integer | numpy.random.RandomState | None):
            Seed, generator, or None for the global generator.

    Returns:
        (numpy.random.RandomState) : rng -
            a numpy random number generator (or ``rng`` itself when it was
            neither None nor an integer)

    References:
        .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270  # noqa: E501
    """
    if rng is None:
        return np.random.mtrand._rand
    if isinstance(rng, (int, np.integer)):
        # NumPy integer scalars (e.g. a seed taken from an array) are valid
        # seeds too; previously they were returned unchanged and failed as
        # soon as they were used like a generator.
        return np.random.RandomState(rng)
    return rng
35 |
--------------------------------------------------------------------------------
/Segmentation/mmdet/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
__version__ = '3.1.0'
short_version = __version__


def parse_version_info(version_str):
    """Convert a dotted version string into a tuple.

    Purely numeric components become ints. A component containing ``rc``
    contributes two entries: the int before ``rc`` and the string
    ``'rc<N>'``. Any other component is silently dropped.

    Args:
        version_str (str): The version string.

    Returns:
        tuple[int | str]: e.g. ``'1.3.0'`` -> ``(1, 3, 0)`` and
        ``'2.0.0rc1'`` -> ``(2, 0, 0, 'rc1')``.
    """
    parts = []
    for token in version_str.split('.'):
        if token.isdigit():
            parts.append(int(token))
        elif 'rc' in token:
            pieces = token.split('rc')
            parts.extend([int(pieces[0]), 'rc' + pieces[1]])
    return tuple(parts)


version_info = parse_version_info(__version__)
28 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/apis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .inference import inference_model, init_model, show_result_pyplot
3 | from .mmseg_inferencer import MMSegInferencer
4 |
# Public inference entry points of ``mmseg.apis``.
__all__ = [
    'init_model',
    'inference_model',
    'show_result_pyplot',
    'MMSegInferencer',
]
8 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/datasets/chase_db1.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmengine.fileio as fileio
3 |
4 | from mmseg.registry import DATASETS
5 | from .basesegdataset import BaseSegDataset
6 |
7 |
@DATASETS.register_module()
class ChaseDB1Dataset(BaseSegDataset):
    """CHASE_DB1 vessel segmentation dataset.

    Two classes, ``background`` (label 0) and ``vessel``. Label 0 counts as
    a real class, so ``reduce_zero_label`` defaults to False. Images default
    to the ``.png`` suffix and annotations to ``_1stHO.png``.
    """
    METAINFO = dict(
        classes=('background', 'vessel'),
        palette=[[120, 120, 120], [6, 230, 230]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='_1stHO.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            reduce_zero_label=reduce_zero_label,
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            **kwargs)
        # Fail early if the image directory does not exist.
        assert fileio.exists(
            self.data_prefix['img_path'],
            backend_args=self.backend_args)
33 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmseg.registry import DATASETS
3 | from .basesegdataset import BaseSegDataset
4 |
5 |
@DATASETS.register_module()
class CityscapesDataset(BaseSegDataset):
    """Cityscapes dataset (19 evaluation classes).

    Images default to the ``_leftImg8bit.png`` suffix and annotations to
    ``_gtFine_labelTrainIds.png``.
    """
    METAINFO = dict(
        classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
                 'traffic light', 'traffic sign', 'vegetation', 'terrain',
                 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                 'motorcycle', 'bicycle'),
        palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70],
                 [102, 102, 156], [190, 153, 153], [153, 153, 153],
                 [250, 170, 30], [220, 220, 0], [107, 142, 35],
                 [152, 251, 152], [70, 130, 180], [220, 20, 60],
                 [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
                 [0, 80, 100], [0, 0, 230], [119, 11, 32]])

    def __init__(self,
                 img_suffix='_leftImg8bit.png',
                 seg_map_suffix='_gtFine_labelTrainIds.png',
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            img_suffix=img_suffix,
            **kwargs)
31 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/datasets/ddd17.py:
--------------------------------------------------------------------------------
1 | # Rough conversion approach: first convert the dataset into the ADE20K storage format, then lay it out on disk the same way ADE20K is stored.
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | from mmseg.registry import DATASETS
4 | from .basesegdataset import BaseSegDataset
5 | import copy
6 | import os.path as osp
7 | from typing import Callable, Dict, List, Optional, Sequence, Union
8 |
9 | import mmengine
10 | import mmengine.fileio as fileio
11 | import numpy as np
12 | from mmengine.dataset import BaseDataset, Compose
13 |
14 | from mmseg.registry import DATASETS
15 |
@DATASETS.register_module()
class DDD17Dataset(BaseSegDataset):
    """DDD17 semantic segmentation dataset.

    Six classes: ``flat``, ``construction+sky``, ``object``, ``nature``,
    ``human`` and ``vehicle``. Inputs default to the ``.npy`` suffix and
    annotations to ``.png``. ``reduce_zero_label`` is not overridden here,
    so the ``BaseSegDataset`` default (or a value passed via ``kwargs``)
    applies.

    NOTE(review): ``.npy`` inputs presumably need an ndarray-aware loading
    transform (e.g. ``LoadImageFromNpyFile``) in the pipeline; confirm
    against the training configs.
    """
    METAINFO = dict(
        classes=('flat', 'construction+sky', 'object', 'nature', 'human',
                 'vehicle'),
        palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230],
                 [80, 50, 50], [4, 200, 3], [120, 120, 80]])

    def __init__(self,
                 img_suffix='.npy',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            **kwargs)
38 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/datasets/drive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmengine.fileio as fileio
3 |
4 | from mmseg.registry import DATASETS
5 | from .basesegdataset import BaseSegDataset
6 |
7 |
@DATASETS.register_module()
class DRIVEDataset(BaseSegDataset):
    """DRIVE vessel segmentation dataset.

    Two classes, ``background`` (label 0) and ``vessel``. Label 0 counts as
    a real class, so ``reduce_zero_label`` defaults to False. Images default
    to the ``.png`` suffix and annotations to ``_manual1.png``.
    """
    METAINFO = dict(
        classes=('background', 'vessel'),
        palette=[[120, 120, 120], [6, 230, 230]])

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='_manual1.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            reduce_zero_label=reduce_zero_label,
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            **kwargs)
        # Fail early if the image directory does not exist.
        assert fileio.exists(
            self.data_prefix['img_path'],
            backend_args=self.backend_args)
33 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/datasets/synapse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmseg.registry import DATASETS
3 | from .basesegdataset import BaseSegDataset
4 |
5 |
@DATASETS.register_module()
class SynapseDataset(BaseSegDataset):
    """Synapse dataset.

    The raw data has 13 foreground categories (background not included).
    Preprocessing keeps 8 of them and folds the other 5 into background,
    giving the 9 classes below. Images default to the ``.jpg`` suffix and
    annotations to ``.png``.
    """
    METAINFO = dict(
        classes=('background', 'aorta', 'gallbladder', 'left_kidney',
                 'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach'),
        palette=[[0, 0, 0], [0, 0, 255], [0, 255, 0],
                 [255, 0, 0], [0, 255, 255], [255, 0, 255],
                 [255, 255, 0], [60, 255, 255], [240, 240, 240]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            img_suffix=img_suffix,
            **kwargs)
29 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/datasets/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .formatting import PackSegInputs
3 | from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
4 | LoadBiomedicalData, LoadBiomedicalImageFromFile,
5 | LoadImageFromNDArray, LoadMultipleRSImageFromFile,
6 | LoadSingleRSImageFromFile, LoadImageFromNpyFile)
7 | # yapf: disable
8 | from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
9 | BioMedical3DRandomCrop, BioMedical3DRandomFlip,
10 | BioMedicalGaussianBlur, BioMedicalGaussianNoise,
11 | BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
12 | PhotoMetricDistortion, RandomCrop, RandomCutOut,
13 | RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
14 | ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
15 | SegRescale)
16 |
17 | # yapf: enable
# Data transforms re-exported by this package.
__all__ = [
    'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
    'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE',
    'Rerange', 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
    'ResizeToMultiple', 'LoadImageFromNDArray',
    'LoadBiomedicalImageFromFile', 'LoadBiomedicalAnnotation',
    'LoadBiomedicalData', 'LoadImageFromNpyFile', 'GenerateEdge',
    'ResizeShortestEdge', 'BioMedicalGaussianNoise',
    'BioMedicalGaussianBlur', 'BioMedical3DRandomFlip',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', 'Albu',
    'LoadSingleRSImageFromFile', 'ConcatCDInput',
    'LoadMultipleRSImageFromFile',
]
29 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/datasets/voc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | import mmengine.fileio as fileio
5 |
6 | from mmseg.registry import DATASETS
7 | from .basesegdataset import BaseSegDataset
8 |
9 |
@DATASETS.register_module()
class PascalVOCDataset(BaseSegDataset):
    """Pascal VOC dataset (20 object classes plus background).

    Args:
        ann_file (str): Split txt file for Pascal VOC.
        img_suffix (str): Suffix of image files. Default: '.jpg'.
        seg_map_suffix (str): Suffix of annotation files. Default: '.png'.
        **kwargs: Forwarded to :class:`BaseSegDataset`.
    """
    METAINFO = dict(
        classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat',
                 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
                 'sofa', 'train', 'tvmonitor'),
        palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                 [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                 [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                 [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                 [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                 [0, 64, 128]])

    def __init__(self,
                 ann_file,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            ann_file=ann_file,
            **kwargs)
        # Pass backend_args by keyword, as the other dataset classes do.
        assert fileio.exists(
            self.data_prefix['img_path'],
            backend_args=self.backend_args), \
            f'image directory not found: {self.data_prefix["img_path"]}'
        # NOTE(review): osp.isfile only checks the local filesystem and
        # ignores backend_args; confirm ann_file is always local.
        assert osp.isfile(self.ann_file), \
            f'annotation file not found: {self.ann_file}'
41 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/engine/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .hooks import SegVisualizationHook
3 | from .optimizers import (LayerDecayOptimizerConstructor,
4 | LearningRateDecayOptimizerConstructor)
5 |
# Public engine components: optimizer constructors and hooks.
__all__ = [
    'LearningRateDecayOptimizerConstructor',
    'LayerDecayOptimizerConstructor',
    'SegVisualizationHook',
]
10 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/engine/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .visualization_hook import SegVisualizationHook
3 | from .resetmodel_hook import ResetModelHook
4 |
# NOTE(review): cal_firing_rate.Get_lif_firing_num is not imported here, so
# it is only registered in HOOKS if that module is imported elsewhere
# (e.g. via ``custom_imports``); confirm this is intended.
__all__ = ['SegVisualizationHook', 'ResetModelHook']
--------------------------------------------------------------------------------
/Segmentation/mmseg/engine/hooks/cal_firing_rate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from spikingjelly.clock_driven import functional
3 | from typing import Optional, Sequence
4 | from mmengine.hooks import Hook
5 |
6 | from mmseg.registry import HOOKS
7 |
8 |
@HOOKS.register_module()
class Get_lif_firing_num(Hook):
    """Debugging hook that drops into the debugger before each test
    iteration.

    NOTE(review): the class name suggests it was meant to collect LIF
    firing counts; only the breakpoint scaffold is implemented.

    Args:
        interactive (bool): When False the hook does nothing. Defaults to
            True, which keeps the original always-break behaviour.
        **kwargs: Forwarded to :class:`mmengine.hooks.Hook`.
    """

    def __init__(self, interactive: bool = True, **kwargs):
        super(Get_lif_firing_num, self).__init__(**kwargs)
        self.interactive = interactive

    def before_test_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: Optional[Sequence[dict]] = None) -> None:
        if self.interactive:
            # breakpoint() defaults to pdb.set_trace() but honours the
            # PYTHONBREAKPOINT environment variable, so non-interactive
            # runs can disable it (PYTHONBREAKPOINT=0) without code changes.
            breakpoint()
36 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/engine/hooks/resetmodel_hook.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from spikingjelly.clock_driven import functional
3 | from typing import Optional, Sequence
4 | from mmengine.hooks import Hook
5 |
6 | from mmseg.registry import HOOKS
7 |
8 |
@HOOKS.register_module()
class ResetModelHook(Hook):
    """Reset the model's network state before every iteration.

    Calls spikingjelly's ``functional.reset_net`` on ``runner.model`` before
    each train, val and test iteration (presumably so that neuron state
    from the previous batch does not carry over).
    """

    def __init__(self, **kwargs):
        super(ResetModelHook, self).__init__(**kwargs)

    @staticmethod
    def _reset(runner) -> None:
        """Wait for pending CUDA work, then reset ``runner.model``."""
        # torch.cuda.synchronize() raises when CUDA is unavailable, which
        # used to crash CPU-only runs; only synchronize when there is a GPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        functional.reset_net(runner.model)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        self._reset(runner)

    def before_val_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: Optional[Sequence[dict]] = None) -> None:
        self._reset(runner)

    def before_test_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: Optional[Sequence[dict]] = None) -> None:
        self._reset(runner)
38 |
39 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/engine/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .layer_decay_optimizer_constructor import (
3 | LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
4 |
# Layer-wise / learning-rate-decay optimizer constructors.
__all__ = [
    'LearningRateDecayOptimizerConstructor',
    'LayerDecayOptimizerConstructor',
]
8 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .metrics import CityscapesMetric, IoUMetric
3 |
# Evaluation metrics re-exported from ``.metrics``.
__all__ = [
    'IoUMetric',
    'CityscapesMetric',
]
5 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/evaluation/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .citys_metric import CityscapesMetric
3 | from .iou_metric import IoUMetric
4 |
# Segmentation metrics defined in this package.
__all__ = [
    'IoUMetric',
    'CityscapesMetric',
]
6 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .backbones import * # noqa: F401,F403
3 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
4 | build_head, build_loss, build_segmentor)
5 | from .data_preprocessor import SegDataPreProcessor
6 | from .decode_heads import * # noqa: F401,F403
7 | from .losses import * # noqa: F401,F403
8 | from .necks import * # noqa: F401,F403
9 | from .segmentors import * # noqa: F401,F403
10 |
# Registry aliases, legacy builder functions and the data preprocessor.
# Submodule contents are additionally re-exported via the star imports.
__all__ = [
    'BACKBONES',
    'HEADS',
    'LOSSES',
    'SEGMENTORS',
    'build_backbone',
    'build_head',
    'build_loss',
    'build_segmentor',
    'SegDataPreProcessor',
]
15 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
3 | from .sdtv2 import Spiking_vit_MetaFormer
4 | from .sdtv3 import Spiking_vit_MetaFormerv2
5 | from .sdtv3MAE import Spiking_vit_MetaFormerv3
6 |
7 | __all__ = [
8 | 'Spiking_vit_MetaFormer', 'Spiking_vit_MetaFormerv2', 'Spiking_vit_MetaFormerv3'
9 | ]
10 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | from mmseg.registry import MODELS
5 |
# Every model family shares the single ``MODELS`` registry: these names are
# all aliases of the very same registry object.
BACKBONES = NECKS = HEADS = LOSSES = SEGMENTORS = MODELS
11 |
12 |
def build_backbone(cfg):
    """Build a backbone from ``cfg`` (deprecated helper).

    Emits a deprecation warning, then delegates to the ``BACKBONES`` alias
    of the shared ``MODELS`` registry.

    Args:
        cfg: Config describing the backbone to build.

    Returns:
        The object built by the registry.
    """
    deprecation_msg = ('``build_backbone`` would be deprecated soon, '
                       'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_msg)
    return BACKBONES.build(cfg)
18 |
19 |
def build_neck(cfg):
    """Build a neck from ``cfg`` (deprecated helper).

    Emits a deprecation warning, then delegates to the ``NECKS`` alias of
    the shared ``MODELS`` registry.

    Args:
        cfg: Config describing the neck to build.

    Returns:
        The object built by the registry.
    """
    deprecation_msg = ('``build_neck`` would be deprecated soon, '
                       'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_msg)
    return NECKS.build(cfg)
25 |
26 |
def build_head(cfg):
    """Build a head from ``cfg`` (deprecated helper).

    Emits a deprecation warning, then delegates to the ``HEADS`` alias of
    the shared ``MODELS`` registry.

    Args:
        cfg: Config describing the head to build.

    Returns:
        The object built by the registry.
    """
    deprecation_msg = ('``build_head`` would be deprecated soon, '
                       'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_msg)
    return HEADS.build(cfg)
32 |
33 |
def build_loss(cfg):
    """Build a loss from ``cfg`` (deprecated helper).

    Emits a deprecation warning, then delegates to the ``LOSSES`` alias of
    the shared ``MODELS`` registry.

    Args:
        cfg: Config describing the loss to build.

    Returns:
        The object built by the registry.
    """
    deprecation_msg = ('``build_loss`` would be deprecated soon, '
                       'please use ``mmseg.registry.MODELS.build()`` ')
    warnings.warn(deprecation_msg)
    return LOSSES.build(cfg)
39 |
40 |
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build a segmentor from ``cfg``.

    Passing ``train_cfg``/``test_cfg`` as separate arguments is deprecated
    (a ``UserWarning`` is emitted); each may be given either here or inside
    ``cfg``, but not in both places.

    Args:
        cfg: Model config; queried with ``cfg.get``.
        train_cfg: Deprecated outer training config. Defaults to None.
        test_cfg: Deprecated outer testing config. Defaults to None.

    Returns:
        The segmentor built by the ``SEGMENTORS`` registry.

    Raises:
        AssertionError: If the same config field is given in both places.
    """
    given_outside_model = not (train_cfg is None and test_cfg is None)
    if given_outside_model:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    # ``cfg.get`` is evaluated first, exactly as before.
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    extra_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return SEGMENTORS.build(cfg, default_args=extra_args)
53 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/decode_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .fpn_head import *
3 | from .mask2former_head import Mask2FormerHead
4 | from .maskformer_head import MaskFormerHead
5 | from mmdet.models import *
6 |
7 |
8 | __all__ = [
9 | 'FPNHead',
10 | 'MaskFormerHead', 'Mask2FormerHead',
11 | 'FPNHead_SNN', 'QFPNHead'
12 | ]
13 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .accuracy import Accuracy, accuracy
3 | from .boundary_loss import BoundaryLoss
4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
5 | cross_entropy, mask_cross_entropy)
6 | from .dice_loss import DiceLoss
7 | from .focal_loss import FocalLoss
8 | from .huasdorff_distance_loss import HuasdorffDisstanceLoss
9 | from .lovasz_loss import LovaszLoss
10 | from .ohem_cross_entropy_loss import OhemCrossEntropy
11 | from .tversky_loss import TverskyLoss
12 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss
13 |
14 | __all__ = [
15 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
16 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
17 | 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
18 | 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
19 | 'HuasdorffDisstanceLoss'
20 | ]
21 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/losses/boundary_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 |
7 | from mmseg.registry import MODELS
8 |
9 |
@MODELS.register_module()
class BoundaryLoss(nn.Module):
    """Boundary loss.

    This function is modified from the PIDNet implementation and is
    licensed under the MIT License.

    It is a binary cross-entropy (on logits) over the flattened boundary
    map with class-balancing weights: pixels whose target is 1 are weighted
    by the fraction of 0-target pixels, pixels whose target is 0 by the
    fraction of 1-target pixels, and any other target value gets weight 0.

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_boundary'.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 loss_name: str = 'loss_boundary'):
        super().__init__()
        self.loss_weight = loss_weight
        self.loss_name_ = loss_name

    def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor:
        """Forward function.

        Args:
            bd_pre (Tensor): Predictions (logits) of the boundary head; must
                be 4-D (it is permuted with ``(0, 2, 3, 1)``). Presumably
                (N, 1, H, W) -- TODO confirm against the head.
            bd_gt (Tensor): Ground truth of the boundary, with as many
                elements as ``bd_pre``.

        Returns:
            Tensor: Scalar loss scaled by ``loss_weight``.
        """
        # Channels-last before flattening so the logit order matches the
        # flattened target.
        log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
        # ``reshape`` (not ``view``) so a non-contiguous target is accepted.
        target_t = bd_gt.reshape(1, -1).float()

        pos_index = target_t == 1
        neg_index = target_t == 0

        # Pixels that are neither 0 nor 1 keep weight 0 and do not
        # contribute to the loss.
        weight = torch.zeros_like(log_p)
        pos_num = pos_index.sum()
        neg_num = neg_index.sum()
        sum_num = pos_num + neg_num
        weight[pos_index] = neg_num * 1.0 / sum_num
        weight[neg_index] = pos_num * 1.0 / sum_num

        loss = F.binary_cross_entropy_with_logits(
            log_p, target_t, weight, reduction='mean')

        return self.loss_weight * loss

    @property
    def loss_name(self):
        """str: Name of the loss item, as given at construction."""
        return self.loss_name_
63 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/necks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .fpn import *
3 |
4 | __all__ = [
5 | 'FPN', 'FPN_SNN', 'QFPN'
6 | ]
7 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/segmentors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base import BaseSegmentor
3 | from .cascade_encoder_decoder import CascadeEncoderDecoder
4 | from .encoder_decoder import EncoderDecoder
5 | from .seg_tta import SegTTAModel
6 |
7 | __all__ = [
8 | 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel'
9 | ]
10 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/segmentors/seg_tta.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import List
3 |
4 | import torch
5 | from mmengine.model import BaseTTAModel
6 | from mmengine.structures import PixelData
7 |
8 | from mmseg.registry import MODELS
9 | from mmseg.utils import SampleList
10 |
11 |
@MODELS.register_module()
class SegTTAModel(BaseTTAModel):
    """Test-time-augmentation wrapper that merges segmentation predictions.

    ``self.module`` is the wrapped segmentor; its ``out_channels`` selects
    between multi-class (softmax / argmax) and single-channel
    (sigmoid / threshold) merging.
    """

    def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
        """Merge predictions of enhanced data to one prediction.

        For every image, the ``seg_logits`` of all its augmented views are
        converted to probabilities (softmax over dim 0 when
        ``out_channels > 1``, sigmoid otherwise) and averaged. The averaged
        map becomes a label map via ``argmax`` over dim 0 or, when
        ``out_channels == 1``, by comparison with ``decode_head.threshold``.
        Both branches keep dim 0 (the channel axis) with size 1.

        Args:
            data_samples_list (List[SampleList]): List of predictions
                of all enhanced data, one inner list per image.

        Returns:
            SampleList: Merged prediction, one sample per image. The last
            augmented sample of each image is reused to carry the result:
            it receives ``pred_sem_seg``, the first view's ``gt_sem_seg``
            (if present) and the first view's ``img_path``.
        """
        predictions = []
        for data_samples in data_samples_list:
            first = data_samples[0]
            logits = torch.zeros_like(first.seg_logits.data)
            for data_sample in data_samples:
                seg_logit = data_sample.seg_logits.data
                if self.module.out_channels > 1:
                    logits += seg_logit.softmax(dim=0)
                else:
                    logits += seg_logit.sigmoid()
            logits /= len(data_samples)
            if self.module.out_channels == 1:
                # Dim 0 is the channel axis; do not squeeze dim 1, which is
                # the height axis.
                seg_pred = (logits >
                            self.module.decode_head.threshold).to(logits)
            else:
                # keepdim gives the same single-channel layout as above.
                seg_pred = logits.argmax(dim=0, keepdim=True)
            # ``data_sample`` is the last view from the loop above; it is
            # reused as the container of the merged result.
            data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)})
            if hasattr(first, 'gt_sem_seg'):
                data_sample.set_data({'gt_sem_seg': first.gt_sem_seg})
            data_sample.set_metainfo({'img_path': first.img_path})
            predictions.append(data_sample)
        return predictions
48 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/utils/Qtrick.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
# Upper clamp bound used by ``Quant`` by default.
quant = 4
# Default divisor of ``Multispike_norm``; equal to ``quant`` so that the
# default output lies in [0, 1].
T = quant


class Quant(torch.autograd.Function):
    """Round-and-clamp quantiser with a clipped straight-through gradient.

    Forward: ``round(clamp(i, min_value, max_value))``.
    Backward: the incoming gradient passes through unchanged where
    ``min_value <= i <= max_value`` and is zeroed elsewhere (rounding is
    treated as the identity).
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, i, min_value=0, max_value=quant):
        ctx.min = min_value
        ctx.max = max_value
        ctx.save_for_backward(i)
        return torch.round(torch.clamp(i, min=min_value, max=max_value))

    @staticmethod
    # backward must use custom_bwd (it was wrongly decorated with
    # custom_fwd) so it runs with the autocast state recorded in forward.
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        i, = ctx.saved_tensors
        grad_input[i < ctx.min] = 0
        grad_input[i > ctx.max] = 0
        # No gradients for the min_value / max_value arguments.
        return grad_input, None, None


class Multispike_norm(nn.Module):
    """Quantise with ``Quant`` (default range [0, quant]) and divide by T.

    Args:
        T (int | float): Normalisation divisor (normalize over T). Defaults
            to the module-level ``T``.
    """

    def __init__(
        self,
        T=T,  # normalize over T
    ):
        super().__init__()
        # Keep the Function class itself: ``apply`` is a classmethod, and
        # instantiating an autograd Function is deprecated in PyTorch.
        self.spike = Quant
        self.T = T

    def forward(self, x):
        return self.spike.apply(x) / self.T
39 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .basic_block import BasicBlock, Bottleneck
3 | from .embed import PatchEmbed
4 | from .encoding import Encoding
5 | from .inverted_residual import InvertedResidual, InvertedResidualV3
6 | from .make_divisible import make_divisible
7 | from .ppm import DAPPM, PAPPM
8 | from .res_layer import ResLayer
9 | from .se_layer import SELayer
10 | from .self_attention_block import SelfAttentionBlock
11 | from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
12 | nlc_to_nchw)
13 | from .up_conv_block import UpConvBlock
14 | from .wrappers import Upsample, resize
15 | from .point_sample import (get_uncertain_point_coords_with_randomness,
16 | get_uncertainty)
17 | from .Qtrick import Multispike_norm
18 |
19 | # NOTE: from mmdet
20 |
21 | from .panoptic_gt_processing import preprocess_panoptic_gt
22 | from .misc import (aligned_bilinear, center_of_mass, empty_instances,
23 | filter_gt_instances, filter_scores_and_topk, flip_tensor,
24 | generate_coordinate, images_to_levels, interpolate_as,
25 | levels_to_images, mask2ndarray, multi_apply,
26 | relative_coordinate_maps, rename_loss_dict,
27 | reweight_loss_dict, samplelist_boxtype2tensor,
28 | select_single_mlvl, sigmoid_geometric_mean,
29 | unfold_wo_center, unmap, unpack_gt_instances)
30 |
31 | __all__ = [
32 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
33 | 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
34 | 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
35 | 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck',
36 | 'get_uncertain_point_coords_with_randomness', 'get_uncertainty',
37 | 'multi_apply', 'preprocess_panoptic_gt', "Multispike_norm"
38 | ]
39 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/utils/make_divisible.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Round a channel count to a nearby multiple of ``divisor``.

    ``value`` is rounded to the nearest multiple of ``divisor`` (halfway
    cases round up) and floored at ``min_value``. If that result is smaller
    than ``min_ratio * value``, one extra ``divisor`` is added. Taken from
    the original TensorFlow MobileNet code:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int): The minimum value of the output channel.
            Default: None, meaning the minimum equals ``divisor``.
        min_ratio (float): The minimum ratio of the rounded channel number
            to the original channel number. Default: 0.9.

    Returns:
        int: The modified output channel number.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = int(value + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    # Do not let rounding down shrink the value by more than (1 - min_ratio).
    shrank_too_much = result < min_ratio * value
    return result + divisor if shrank_too_much else result
29 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/models/utils/wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Resize ``input`` with :func:`torch.nn.functional.interpolate`.

    When ``warning`` is true, ``size`` is given and ``align_corners`` is
    truthy, a warning is emitted if the output is larger than the input
    along H or W and neither ``out_h - 1`` is a multiple of ``in_h - 1`` nor
    ``out_w - 1`` a multiple of ``in_w - 1``.

    Args:
        input (Tensor): Tensor to resize; its last two dims are read as
            (H, W) for the alignment check.
        size: Output spatial size, forwarded to ``interpolate``.
        scale_factor: Forwarded to ``interpolate``.
        mode (str): Interpolation mode. Defaults to 'nearest'.
        align_corners: Forwarded to ``interpolate``.
        warning (bool): Whether to run the alignment check.

    Returns:
        Tensor: The resized tensor.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Only upsampling can be misaligned; compare each output dim with
        # its own input dim (W used to be compared against output H).
        if output_h > input_h or output_w > input_w:
            if ((output_h > 1 and output_w > 1 and input_h > 1
                 and input_w > 1) and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)):
                warnings.warn(
                    f'When align_corners={align_corners}, '
                    'the output would more aligned if '
                    f'input size {(input_h, input_w)} is `x+1` and '
                    f'out size {(output_h, output_w)} is `nx+1`')

    return F.interpolate(input, size, scale_factor, mode, align_corners)


class Upsample(nn.Module):
    """Module wrapper around :func:`resize`.

    Args:
        size: Fixed output size. When falsy, the size is derived from
            ``scale_factor`` and the input's last two dims.
        scale_factor (float | tuple): Scale for the last two dims; a tuple
            gives separate (H, W) factors.
        mode (str): Interpolation mode. Defaults to 'nearest'.
        align_corners: Forwarded to :func:`resize`.
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super().__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        if self.size:
            size = self.size
        else:
            # Tuple factors were accepted in __init__ but used to crash
            # here; apply them per axis.
            factors = self.scale_factor
            if not isinstance(factors, tuple):
                factors = (factors, factors)
            size = [int(t * f) for t, f in zip(x.shape[-2:], factors)]
        return resize(x, size, None, self.mode, self.align_corners)
53 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/registry/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS,
3 | LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
4 | OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
5 | PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
6 | TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
7 | WEIGHT_INITIALIZERS)
8 |
9 | __all__ = [
10 | 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS',
11 | 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS',
12 | 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS',
13 | 'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS',
14 | 'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS'
15 | ]
16 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/structures/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
3 | from .seg_data_sample import SegDataSample
4 |
5 | __all__ = [
6 | 'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler',
7 | 'build_pixel_sampler'
8 | ]
9 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/structures/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_pixel_sampler import BasePixelSampler
3 | from .builder import build_pixel_sampler
4 | from .ohem_pixel_sampler import OHEMPixelSampler
5 |
6 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
7 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/structures/sampler/base_pixel_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
class BasePixelSampler(metaclass=ABCMeta):
    """Abstract interface shared by all pixel samplers.

    Concrete samplers must override :meth:`sample`; the base constructor
    accepts and ignores arbitrary keyword arguments.
    """

    def __init__(self, **kwargs):
        """Accept and discard any keyword arguments."""

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
14 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/structures/sampler/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | from mmseg.registry import TASK_UTILS
5 |
# Alias of ``TASK_UTILS``: both names refer to the same registry object.
PIXEL_SAMPLERS = TASK_UTILS


def build_pixel_sampler(cfg, **default_args):
    """Build a pixel sampler for segmentation maps (deprecated helper).

    Emits a deprecation warning, then calls
    ``TASK_UTILS.build(cfg, default_args=default_args)``.

    Args:
        cfg: Config describing the sampler to build.
        **default_args: Forwarded to the registry as ``default_args``.

    Returns:
        The sampler built by the registry.
    """
    notice = ('``build_pixel_sampler`` would be deprecated soon, please use '
              '``mmseg.registry.TASK_UTILS.build()`` ')
    warnings.warn(notice)
    return TASK_UTILS.build(cfg, default_args=default_args)
15 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # yapf: disable
3 | from .class_names import (ade_classes, ade_palette, cityscapes_classes,
4 | cityscapes_palette, cocostuff_classes,
5 | cocostuff_palette, dataset_aliases, get_classes,
6 | get_palette, isaid_classes, isaid_palette,
7 | loveda_classes, loveda_palette, potsdam_classes,
8 | potsdam_palette, stare_classes, stare_palette,
9 | synapse_classes, synapse_palette, vaihingen_classes,
10 | vaihingen_palette, voc_classes, voc_palette)
11 | # yapf: enable
12 | from .collect_env import collect_env
13 | from .io import datafrombytes
14 | from .misc import add_prefix, stack_batch, multi_apply
15 | from .set_env import register_all_modules
16 | from .typing_utils import (ConfigType, ForwardResults, MultiConfig,
17 | OptConfigType, OptMultiConfig, OptSampleList,
18 | SampleList, TensorDict, TensorList)
19 |
20 | __all__ = [
21 | 'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix',
22 | 'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
23 | 'SampleList', 'OptSampleList', 'TensorDict', 'TensorList',
24 | 'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes',
25 | 'cocostuff_classes', 'loveda_classes', 'potsdam_classes',
26 | 'vaihingen_classes', 'isaid_classes', 'stare_classes',
27 | 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
28 | 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
29 | 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette',
30 | 'datafrombytes', 'synapse_palette', 'synapse_classes', 'multi_apply',
31 | ]
32 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmengine.utils import get_git_hash
3 | from mmengine.utils.dl_utils import collect_env as collect_base_env
4 |
5 | import mmseg
6 |
7 |
def collect_env():
    """Collect information about the running environment.

    Starts from mmengine's base environment report and adds an
    ``'MMSegmentation'`` entry of the form ``<version>+<7-char git hash>``.

    Returns:
        dict: The environment information.
    """
    env_info = collect_base_env()
    short_hash = get_git_hash()[:7]
    env_info['MMSegmentation'] = f'{mmseg.__version__}+{short_hash}'
    return env_info


if __name__ == '__main__':
    for key, value in collect_env().items():
        print(f'{key}: {value}')
19 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/utils/io.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import gzip
3 | import io
4 | import pickle
5 |
6 | import numpy as np
7 |
8 |
def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
    """Data decoding from bytes.

    Args:
        content (bytes): The data bytes got from files or other streams.
        backend (str): The data decoding backend type. Options are 'numpy',
            'nifti' and 'pickle'. Defaults to 'numpy'.

    Returns:
        numpy.ndarray: Loaded data array. For the 'pickle' backend this is
        whatever object was pickled.

    Raises:
        ImportError: If ``backend='nifti'`` and nibabel is not installed.
        ValueError: If ``backend`` is not one of the supported options.
    """
    if backend == 'pickle':
        # SECURITY: pickle.loads can execute arbitrary code; only use this
        # backend on trusted data.
        return pickle.loads(content)

    with io.BytesIO(content) as f:
        if backend == 'nifti':
            try:
                from nibabel import FileHolder, Nifti1Image
            except ImportError as err:
                # Previously only printed, which then crashed with NameError.
                raise ImportError(
                    'nifti files io depends on nibabel, please run '
                    '`pip install nibabel` to install it') from err
            # Close the gzip wrapper too; the image is read fully inside.
            with gzip.open(f) as gz_file:
                fh = FileHolder(fileobj=gz_file)
                img = Nifti1Image.from_file_map({'header': fh, 'image': fh})
                data = Nifti1Image.from_bytes(img.to_bytes()).get_fdata()
        elif backend == 'numpy':
            data = np.load(f)
        else:
            raise ValueError(
                f"Unsupported backend '{backend}'; expected 'numpy', "
                "'nifti' or 'pickle'.")
    return data
39 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/utils/set_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import datetime
3 | import warnings
4 |
5 | from mmengine import DefaultScope
6 |
7 |
def register_all_modules(init_default_scope: bool = True) -> None:
    """Register all modules in mmseg into the registries.

    Registration happens by importing the mmseg sub-packages below.

    Args:
        init_default_scope (bool): Whether initialize the mmseg default scope.
            When `init_default_scope=True`, the global default scope will be
            set to `mmseg`, and all registries will build modules from mmseg's
            registry node. To understand more about the registry, please refer
            to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
            Defaults to True.
    """  # noqa
    import mmseg.datasets  # noqa: F401,F403
    import mmseg.engine  # noqa: F401,F403
    import mmseg.evaluation  # noqa: F401,F403
    import mmseg.models  # noqa: F401,F403
    import mmseg.structures  # noqa: F401,F403

    if init_default_scope:
        never_created = DefaultScope.get_current_instance() is None \
            or not DefaultScope.check_instance_created('mmseg')
        if never_created:
            DefaultScope.get_instance('mmseg', scope_name='mmseg')
            return
        current_scope = DefaultScope.get_current_instance()
        if current_scope.scope_name != 'mmseg':
            # The trailing space after "current" was missing, producing
            # "currentdefault" in the message.
            warnings.warn('The current default scope '
                          f'"{current_scope.scope_name}" is not "mmseg", '
                          '`register_all_modules` will force the current '
                          'default scope to be "mmseg". If this is not '
                          'expected, please set `init_default_scope=False`.')
            # A timestamped name avoids clashing with the existing 'mmseg'
            # instance.
            new_instance_name = f'mmseg-{datetime.datetime.now()}'
            DefaultScope.get_instance(new_instance_name, scope_name='mmseg')
41 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/utils/typing_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """Collecting some commonly used type hint in mmflow."""
3 | from typing import Dict, List, Optional, Sequence, Tuple, Union
4 |
5 | import torch
6 | from mmengine.config import ConfigDict
7 |
8 | from mmseg.structures import SegDataSample
9 |
# Config type hints: a single config is a plain ``dict`` or a ``ConfigDict``.
ConfigType = Union[ConfigDict, dict]
OptConfigType = Union[ConfigDict, dict, None]
# One config, or a sequence of configs.
MultiConfig = Union[ConfigType, Sequence[ConfigType]]
OptMultiConfig = Optional[MultiConfig]

# Data-sample type hints.
SampleList = Sequence[SegDataSample]
OptSampleList = Optional[SampleList]

# Tensor container type hints.
TensorDict = Dict[str, torch.Tensor]
TensorList = Sequence[torch.Tensor]

# Union of result types: a dict of tensors, a list of data samples, a tuple
# of tensors, or a single tensor (presumably what a model forward returns).
ForwardResults = Union[TensorDict, List[SegDataSample],
                       Tuple[torch.Tensor], torch.Tensor]
26 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 |
__version__ = '1.1.1'


def parse_version_info(version_str):
    """Split a dotted version string into a tuple of components.

    Purely numeric parts become ints. A part containing ``'rc'`` (such as
    ``'0rc1'``) contributes two items: the int before ``'rc'`` and the
    string ``'rc<N>'``. Any other part is ignored.

    Example:
        >>> parse_version_info('1.0.0rc2')
        (1, 0, 0, 'rc2')
    """

    def _components(part):
        if part.isdigit():
            return [int(part)]
        if 'rc' in part:
            pieces = part.split('rc')
            return [int(pieces[0]), f'rc{pieces[1]}']
        return []

    return tuple(item for part in version_str.split('.')
                 for item in _components(part))


version_info = parse_version_info(__version__)
19 |
--------------------------------------------------------------------------------
/Segmentation/mmseg/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .local_visualizer import SegLocalVisualizer
3 |
4 | __all__ = ['SegLocalVisualizer']
5 |
--------------------------------------------------------------------------------
/Segmentation/model-index.yml:
--------------------------------------------------------------------------------
1 | Import:
2 | - configs/ann/metafile.yaml
3 | - configs/apcnet/metafile.yaml
4 | - configs/beit/metafile.yaml
5 | - configs/bisenetv1/metafile.yaml
6 | - configs/bisenetv2/metafile.yaml
7 | - configs/ccnet/metafile.yaml
8 | - configs/cgnet/metafile.yaml
9 | - configs/convnext/metafile.yaml
10 | - configs/danet/metafile.yaml
11 | - configs/ddrnet/metafile.yaml
12 | - configs/deeplabv3/metafile.yaml
13 | - configs/deeplabv3plus/metafile.yaml
14 | - configs/dmnet/metafile.yaml
15 | - configs/dnlnet/metafile.yaml
16 | - configs/dpt/metafile.yaml
17 | - configs/emanet/metafile.yaml
18 | - configs/encnet/metafile.yaml
19 | - configs/erfnet/metafile.yaml
20 | - configs/fastfcn/metafile.yaml
21 | - configs/fastscnn/metafile.yaml
22 | - configs/fcn/metafile.yaml
23 | - configs/gcnet/metafile.yaml
24 | - configs/hrnet/metafile.yaml
25 | - configs/icnet/metafile.yaml
26 | - configs/isanet/metafile.yaml
27 | - configs/knet/metafile.yaml
28 | - configs/mae/metafile.yaml
29 | - configs/mask2former/metafile.yaml
30 | - configs/maskformer/metafile.yaml
31 | - configs/mobilenet_v2/metafile.yaml
32 | - configs/mobilenet_v3/metafile.yaml
33 | - configs/nonlocal_net/metafile.yaml
34 | - configs/ocrnet/metafile.yaml
35 | - configs/pidnet/metafile.yaml
36 | - configs/point_rend/metafile.yaml
37 | - configs/poolformer/metafile.yaml
38 | - configs/psanet/metafile.yaml
39 | - configs/pspnet/metafile.yaml
40 | - configs/resnest/metafile.yaml
41 | - configs/segformer/metafile.yaml
42 | - configs/segmenter/metafile.yaml
43 | - configs/segnext/metafile.yaml
44 | - configs/sem_fpn/metafile.yaml
45 | - configs/setr/metafile.yaml
46 | - configs/stdc/metafile.yaml
47 | - configs/swin/metafile.yaml
48 | - configs/twins/metafile.yaml
49 | - configs/unet/metafile.yaml
50 | - configs/upernet/metafile.yaml
51 | - configs/vit/metafile.yaml
52 |
--------------------------------------------------------------------------------
/Segmentation/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/optional.txt
2 | -r requirements/runtime.txt
3 | -r requirements/tests.txt
4 |
--------------------------------------------------------------------------------
/Segmentation/requirements/albu.txt:
--------------------------------------------------------------------------------
1 | albumentations>=0.3.2 --no-binary qudida,albumentations
2 |
--------------------------------------------------------------------------------
/Segmentation/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | docutils==0.16.0
2 | myst-parser
3 | -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
4 | sphinx==4.0.2
5 | sphinx_copybutton
6 | sphinx_markdown_tables
7 | urllib3<2.0.0
8 |
--------------------------------------------------------------------------------
/Segmentation/requirements/mminstall.txt:
--------------------------------------------------------------------------------
1 | mmcv>=2.0.0rc4
2 | mmengine>=0.5.0,<1.0.0
3 |
--------------------------------------------------------------------------------
/Segmentation/requirements/optional.txt:
--------------------------------------------------------------------------------
1 | cityscapesscripts
2 | nibabel
3 |
--------------------------------------------------------------------------------
/Segmentation/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
1 | mmcv>=2.0.0rc1,<2.1.0
2 | mmengine>=0.4.0,<1.0.0
3 | prettytable
4 | scipy
5 | torch
6 | torchvision
7 |
--------------------------------------------------------------------------------
/Segmentation/requirements/runtime.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | packaging
4 | prettytable
5 | scipy
6 |
--------------------------------------------------------------------------------
/Segmentation/requirements/tests.txt:
--------------------------------------------------------------------------------
1 | codecov
2 | flake8
3 | interrogate
4 | pytest
5 | xdoctest>=0.10.0
6 | yapf
7 |
--------------------------------------------------------------------------------
/Segmentation/setup.cfg:
--------------------------------------------------------------------------------
1 | [yapf]
2 | based_on_style = pep8
3 | blank_line_before_nested_class_or_def = true
4 | split_before_expression_after_opening_paren = true
5 |
6 | [isort]
7 | line_length = 79
8 | multi_line_output = 0
9 | extra_standard_library = setuptools
10 | known_first_party = mmseg
11 | known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,packaging,prettytable,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,torch,ts
12 | no_lines_before = STDLIB,LOCALFOLDER
13 | default_section = THIRDPARTY
14 |
15 | [codespell]
16 | skip = *.po,*.ts,*.ipynb
17 | count =
18 | quiet-level = 3
19 | ignore-words-list = formating,sur,hist,dota,warmup,damon
20 |
--------------------------------------------------------------------------------
/Segmentation/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_digit_version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmseg import digit_version
3 |
4 |
def test_digit_version():
    """Check ``digit_version`` parsing, normalisation and ordering."""
    # Versions with an exactly known parsed form.
    exact = [
        ('0.2.16', (0, 2, 16, 0, 0, 0)),
        ('1.2.3', (1, 2, 3, 0, 0, 0)),
        ('1.2.3rc0', (1, 2, 3, 0, -1, 0)),
        ('1.2.3rc1', (1, 2, 3, 0, -1, 1)),
        ('1.0rc0', (1, 0, 0, 0, -1, 0)),
    ]
    for version, parsed in exact:
        assert digit_version(version) == parsed

    # Different spellings of the same version.
    same = [('1.0', '1.0.0'), ('1.5.0+cuda90_cudnn7.6.3_lms', '1.5')]
    for left, right in same:
        assert digit_version(left) == digit_version(right)

    # Pairs where the first version sorts strictly before the second.
    ordered = [
        ('1.0.0dev', '1.0.0a'),
        ('1.0.0a', '1.0.0a1'),
        ('1.0.0a', '1.0.0b'),
        ('1.0.0b', '1.0.0rc'),
        ('1.0.0rc1', '1.0.0'),
        ('1.0.0', '1.0.0post'),
        ('1.0.0post', '1.0.0post1'),
    ]
    for lower, higher in ordered:
        assert digit_version(lower) < digit_version(higher)

    # A leading 'v' is accepted.
    prefixed = [('v1', (1, 0, 0, 0, 0, 0)), ('v1.1.5', (1, 1, 5, 0, 0, 0))]
    for version, parsed in prefixed:
        assert digit_version(version) == parsed
22 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_engine/test_optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmengine.optim import build_optim_wrapper
5 |
6 |
class ExampleModel(nn.Module):
    """Tiny model used as the target when building an optimizer wrapper.

    It owns a bare parameter, two 1x1 convolutions (the first without bias)
    and a BatchNorm layer; ``forward`` is the identity.
    """

    def __init__(self):
        super(ExampleModel, self).__init__()
        # Registration order fixes the order of ``parameters()``; keep it.
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1)
        self.bn = nn.BatchNorm2d(num_features=2)

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
18 |
19 |
# Hyper-parameters for the SGD optimizer config built in this module.
base_lr, base_wd, momentum = 0.01, 1e-4, 0.9
23 |
24 |
def test_build_optimizer():
    """An ``OptimWrapper`` config with SGD should wrap a ``torch.optim.SGD``."""
    sgd_cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
    wrapper_cfg = dict(type='OptimWrapper', optimizer=sgd_cfg)
    wrapper = build_optim_wrapper(ExampleModel(), wrapper_cfg)
    # The wrapped optimizer must be the concrete class named in the config.
    assert isinstance(wrapper.optimizer, torch.optim.SGD)
34 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/test_bisenetv2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from mmcv.cnn import ConvModule
4 |
5 | from mmseg.models.backbones import BiSeNetV2
6 | from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch,
7 | SemanticBranch)
8 |
9 |
def test_bisenetv2_backbone():
    """BiSeNetV2 returns one segment-head feature plus four auxiliary ones."""
    batch_size = 2
    model = BiSeNetV2()
    model.init_weights()
    model.train()

    outputs = model(torch.randn(batch_size, 3, 128, 256))
    # (channels, height, width) per output: the segment head first, then
    # auxiliary heads 1-4.
    expected_chw = [
        (128, 16, 32),
        (16, 32, 64),
        (32, 16, 32),
        (64, 8, 16),
        (128, 4, 8),
    ]
    assert len(outputs) == len(expected_chw)
    for out, chw in zip(outputs, expected_chw):
        assert out.shape == torch.Size([batch_size, *chw])

    # Odd spatial sizes must still run end to end.
    outputs = model(torch.randn(batch_size, 3, 95, 27))
    assert len(outputs) == 5
35 | assert len(feat) == 5
36 |
37 |
def test_bisenetv2_DetailBranch():
    """DetailBranch is built from ConvModules and maps 32x64 to 4x8."""
    branch = DetailBranch(detail_channels=(64, 16, 32))
    first_layer = branch.detail_branch[0][0]
    assert isinstance(first_layer, ConvModule)
    out = branch(torch.randn(1, 3, 32, 64))
    assert out.shape == torch.Size([1, 32, 4, 8])
44 |
45 |
def test_bisenetv2_SemanticBranch():
    """The pooling layer of the first semantic stage uses stride 2."""
    branch = SemanticBranch(semantic_channels=(16, 32, 64, 128))
    stage1_pool = branch.stage1.pool
    assert stage1_pool.stride == 2
49 |
50 |
def test_bisenetv2_BGALayer():
    """BGALayer fuses two maps and returns one at the first input's size."""
    layer = BGALayer(out_channels=8)
    assert isinstance(layer.conv, ConvModule)
    high_res = torch.randn(1, 8, 8, 16)
    low_res = torch.randn(1, 8, 2, 4)
    fused = layer(high_res, low_res)
    assert fused.shape == torch.Size([1, 8, 8, 16])
58 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/test_fast_scnn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.backbones import FastSCNN
6 |
7 |
def test_fastscnn_backbone():
    """FastSCNN rejects inconsistent channels and emits three feature maps."""
    # Fast-SCNN channel constraints: this configuration must be rejected
    # (presumably because global_out_channels != lower_in_channels).
    with pytest.raises(AssertionError):
        FastSCNN(
            3, (32, 48),
            64, (64, 96, 128), (2, 2, 1),
            global_out_channels=127,
            higher_in_channels=64,
            lower_in_channels=128)

    model = FastSCNN(
        in_channels=3,
        downsample_dw_channels=(4, 6),
        global_in_channels=8,
        global_block_channels=(8, 12, 16),
        global_block_strides=(2, 2, 1),
        global_out_channels=16,
        higher_in_channels=8,
        lower_in_channels=16,
        fusion_out_channels=16,
    )
    model.init_weights()
    model.train()
    batch_size = 4
    feats = model(torch.randn(batch_size, 3, 64, 128))

    # (label, channels, height, width) for each returned map.
    expected = [
        ('higher-res', 8, 8, 16),
        ('lower-res', 16, 2, 4),
        ('FFM output', 16, 8, 16),
    ]
    assert len(feats) == len(expected)
    for feat, (label, c, h, w) in zip(feats, expected):
        assert feat.shape == torch.Size([batch_size, c, h, w]), label
43 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/test_icnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.backbones import ICNet
6 |
7 |
def test_icnet_backbone():
    """ICNet requires a backbone config and builds the expected submodules."""
    with pytest.raises(TypeError):
        # Must give backbone dict in config file.
        ICNet(
            in_channels=3,
            layer_channels=(128, 512),
            light_branch_middle_channels=8,
            psp_out_channels=128,
            out_channels=(16, 128, 128),
            backbone_cfg=None)

    resnet_cfg = dict(
        type='ResNetV1c',
        in_channels=3,
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=False,
        style='pytorch',
        contract_dilation=True)
    model = ICNet(layer_channels=(128, 512), backbone_cfg=resnet_cfg)
    assert hasattr(model.backbone, 'maxpool')
    assert model.backbone.maxpool.ceil_mode is True

    model.init_weights()
    model.train()
    batch_size = 2
    feats = model(torch.randn(batch_size, 3, 32, 64))

    for idx, size in enumerate((1, 2, 3)):
        assert model.psp_modules[idx][0].output_size == size
    assert model.psp_bottleneck.padding == 1
    assert model.conv_sub1[0].padding == 1

    assert len(feats) == 3
    assert feats[0].shape == torch.Size([batch_size, 64, 4, 8])
51 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/test_mobilenet_v3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.backbones import MobileNetV3
6 |
7 |
def _check_mobilenet_v3_outputs(model, expected_shapes):
    """Init ``model``, run a (2, 3, 56, 56) batch and compare output shapes."""
    model.init_weights()
    model.train()
    feats = model(torch.randn(2, 3, 56, 56))
    assert len(feats) == len(expected_shapes)
    for feat, shape in zip(feats, expected_shapes):
        assert feat.shape == shape


def test_mobilenet_v3():
    """MobileNetV3 argument validation and output shapes for several archs."""
    with pytest.raises(AssertionError):
        # check invalid arch
        MobileNetV3('big')

    with pytest.raises(AssertionError):
        # check invalid reduction_factor
        MobileNetV3(reduction_factor=0)

    with pytest.raises(ValueError):
        # check invalid out_indices
        MobileNetV3(out_indices=(0, 1, 15))

    with pytest.raises(ValueError):
        # check invalid frozen_stages
        MobileNetV3(frozen_stages=15)

    # check invalid pretrained. Build the model outside ``pytest.raises`` so
    # that only ``init_weights`` can satisfy the expected TypeError; a
    # TypeError from the constructor would otherwise pass silently.
    model = MobileNetV3()
    with pytest.raises(TypeError):
        model.init_weights(pretrained=8)

    # Test MobileNetV3 with default settings
    _check_mobilenet_v3_outputs(
        MobileNetV3(),
        [(2, 16, 28, 28), (2, 16, 14, 14), (2, 576, 7, 7)])

    # Test MobileNetV3 with arch = 'large'
    _check_mobilenet_v3_outputs(
        MobileNetV3(arch='large', out_indices=(1, 3, 16)),
        [(2, 16, 28, 28), (2, 24, 14, 14), (2, 960, 7, 7)])

    # Test MobileNetV3 with norm_eval True, with_cp True and frozen_stages=5
    model = MobileNetV3(norm_eval=True, with_cp=True, frozen_stages=5)
    with pytest.raises(TypeError):
        # check invalid pretrained
        model.init_weights(pretrained=8)
    _check_mobilenet_v3_outputs(
        model,
        [(2, 16, 28, 28), (2, 16, 14, 14), (2, 576, 7, 7)])
68 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/test_resnest.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.backbones import ResNeSt
6 | from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
7 |
8 |
def test_resnest_bottleneck():
    """ResNeSt Bottleneck: style validation, strided structure and forward."""
    with pytest.raises(AssertionError):
        # Style must be in ['pytorch', 'caffe']
        BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow')

    strided = BottleneckS(
        64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch')
    assert strided.avd_layer.stride == 2
    assert strided.conv2.channels == 256

    block = BottleneckS(64, 16, radix=2, reduction_factor=4)
    out = block(torch.randn(2, 64, 56, 56))
    assert out.shape == torch.Size([2, 64, 56, 56])
25 |
26 |
def test_resnest_backbone():
    """ResNeSt rejects unsupported depths and yields four stage outputs."""
    with pytest.raises(KeyError):
        # ResNeSt depth should be in [50, 101, 152, 200]
        ResNeSt(depth=18)

    model = ResNeSt(
        depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    feats = model(torch.randn(2, 3, 224, 224))
    # (channels, square side) for each stage output.
    stage_shapes = [(256, 56), (512, 28), (1024, 14), (2048, 7)]
    assert len(feats) == len(stage_shapes)
    for feat, (channels, side) in zip(feats, stage_shapes):
        assert feat.shape == torch.Size([2, channels, side, side])
45 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/test_resnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.backbones import ResNeXt
6 | from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
7 | from .utils import is_block
8 |
9 |
def test_renext_bottleneck():
    """ResNeXt Bottleneck: style check, grouped conv2, DCN config, forward."""
    with pytest.raises(AssertionError):
        # Style must be in ['pytorch', 'caffe']
        BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')

    block = BottleneckX(
        64, 64, groups=32, base_width=4, stride=2, style='pytorch')
    conv2 = block.conv2
    assert conv2.stride == (2, 2)
    assert conv2.groups == 32
    assert conv2.out_channels == 128

    dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
    with pytest.raises(AssertionError):
        # conv_cfg must be None if dcn is not None
        BottleneckX(
            64, 64, groups=32, base_width=4, dcn=dcn,
            conv_cfg=dict(type='Conv'))
    # Without a custom conv_cfg the DCN bottleneck must build cleanly.
    BottleneckX(64, 64, dcn=dcn)

    block = BottleneckX(64, 16, groups=32, base_width=4)
    out = block(torch.randn(1, 64, 56, 56))
    assert out.shape == torch.Size([1, 64, 56, 56])
40 |
41 |
def test_resnext_backbone():
    """ResNeXt rejects unsupported depths and builds grouped bottlenecks."""
    with pytest.raises(KeyError):
        # ResNeXt depth should be in [50, 101, 152]
        ResNeXt(depth=18)

    # Test ResNeXt with group 32, base_width 4
    model = ResNeXt(depth=50, groups=32, base_width=4)
    blocks = [m for m in model.modules() if is_block(m)]
    # Guard against a vacuous pass: the group check below must see blocks.
    assert blocks
    for block in blocks:
        assert block.conv2.groups == 32
    model.init_weights()
    model.train()

    feats = model(torch.randn(1, 3, 224, 224))
    assert len(feats) == 4
    assert feats[0].shape == torch.Size([1, 256, 56, 56])
    assert feats[1].shape == torch.Size([1, 512, 28, 28])
    assert feats[2].shape == torch.Size([1, 1024, 14, 14])
    assert feats[3].shape == torch.Size([1, 2048, 7, 7])
63 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from torch.nn.modules import GroupNorm
4 | from torch.nn.modules.batchnorm import _BatchNorm
5 |
6 | from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
7 | from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
8 |
9 |
def is_block(modules):
    """Return True if ``modules`` is a ResNet or ResNeXt residual block."""
    block_types = (BasicBlock, Bottleneck, BottleneckX)
    return isinstance(modules, block_types)
15 |
16 |
def is_norm(modules):
    """Return True if ``modules`` is a GroupNorm or any batch-norm layer."""
    norm_types = (GroupNorm, _BatchNorm)
    return isinstance(modules, norm_types)
22 |
23 |
def all_zeros(modules):
    """Check if the weight (and bias, when present) is all zero.

    Args:
        modules (nn.Module): A layer exposing a ``weight`` tensor and,
            optionally, a ``bias`` tensor.

    Returns:
        bool: True if ``weight`` and ``bias`` (when it exists and is not
        None) are all close to zero.
    """
    weight_zero = torch.allclose(modules.weight.data,
                                 torch.zeros_like(modules.weight.data))
    # Layers built with ``bias=False`` still have a ``bias`` attribute, set
    # to None, so ``hasattr`` alone would crash on ``None.data``.
    bias = getattr(modules, 'bias', None)
    if bias is not None:
        bias_zero = torch.allclose(bias.data, torch.zeros_like(bias.data))
    else:
        bias_zero = True

    return weight_zero and bias_zero
35 |
36 |
def check_norm_state(modules, train_state):
    """Return True if every batch-norm layer in ``modules`` has
    ``training == train_state``; non-batch-norm modules are ignored."""
    return all(mod.training == train_state
               for mod in modules if isinstance(mod, _BatchNorm))
44 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_ann_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import ANNHead
5 | from .utils import to_cuda
6 |
7 |
def test_ann_head():
    """ANNHead fuses two levels and predicts at the second level's size."""
    head = ANNHead(
        in_channels=[4, 8],
        channels=2,
        num_classes=19,
        in_index=[-2, -1],
        project_channels=8)
    feats = [torch.randn(1, 4, 45, 45), torch.randn(1, 8, 21, 21)]
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    logits = head(feats)
    assert logits.shape == (1, head.num_classes, 21, 21)
21 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_apc_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import APCHead
6 | from .utils import _conv_has_norm, to_cuda
7 |
8 |
def test_apc_head():
    """APCHead: arg validation, norm config and forward with/without fusion."""
    with pytest.raises(AssertionError):
        # pool_scales must be list|tuple
        APCHead(in_channels=8, channels=2, num_classes=19, pool_scales=1)

    # Default build reports no norm; a SyncBN norm_cfg reports one.
    plain = APCHead(in_channels=8, channels=2, num_classes=19)
    assert not _conv_has_norm(plain, sync_bn=False)
    synced = APCHead(
        in_channels=8,
        channels=2,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(synced, sync_bn=True)

    scales = (1, 2, 3)
    for fusion in (True, False):
        inputs = [torch.randn(1, 8, 45, 45)]
        head = APCHead(
            in_channels=8,
            channels=2,
            num_classes=19,
            pool_scales=scales,
            fusion=fusion)
        if torch.cuda.is_available():
            head, inputs = to_cuda(head, inputs)
        assert head.fusion is fusion
        for idx, scale in enumerate(scales):
            assert head.acm_modules[idx].pool_scale == scale
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 45, 45)
60 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_cc_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import CCHead
6 | from .utils import to_cuda
7 |
8 |
def test_cc_head():
    """CCHead builds two convs and a ``cca`` module; forward needs CUDA."""
    head = CCHead(in_channels=16, channels=8, num_classes=19)
    assert len(head.convs) == 2
    assert hasattr(head, 'cca')
    if not torch.cuda.is_available():
        pytest.skip('CCHead requires CUDA')
    head, inputs = to_cuda(head, [torch.randn(1, 16, 23, 23)])
    logits = head(inputs)
    assert logits.shape == (1, head.num_classes, 23, 23)
19 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_dm_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import DMHead
6 | from .utils import _conv_has_norm, to_cuda
7 |
8 |
def test_dm_head():
    """DMHead: arg validation, norm config and forward with/without fusion."""
    with pytest.raises(AssertionError):
        # filter_sizes must be list|tuple
        DMHead(in_channels=8, channels=4, num_classes=19, filter_sizes=1)

    # Default build reports no norm; a SyncBN norm_cfg reports one.
    plain = DMHead(in_channels=8, channels=4, num_classes=19)
    assert not _conv_has_norm(plain, sync_bn=False)
    synced = DMHead(
        in_channels=8,
        channels=4,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(synced, sync_bn=True)

    sizes = (1, 3, 5)
    for fusion in (True, False):
        inputs = [torch.randn(1, 8, 23, 23)]
        head = DMHead(
            in_channels=8,
            channels=4,
            num_classes=19,
            filter_sizes=sizes,
            fusion=fusion)
        if torch.cuda.is_available():
            head, inputs = to_cuda(head, inputs)
        assert head.fusion is fusion
        for idx, size in enumerate(sizes):
            assert head.dcm_modules[idx].filter_size == size
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 23, 23)
60 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_dnl_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import DNLHead
5 | from .utils import to_cuda
6 |
7 |
def test_dnl_head():
    """DNLHead keeps the input resolution for every attention mode."""
    # None means "use the constructor default" ('embedded_gaussian' per the
    # original test's comment); the structural checks run on that build only.
    for mode in (None, 'dot_product', 'gaussian', 'concatenation'):
        extra = {} if mode is None else {'mode': mode}
        head = DNLHead(in_channels=8, channels=4, num_classes=19, **extra)
        if mode is None:
            assert len(head.convs) == 2
            assert hasattr(head, 'dnl_block')
            assert head.dnl_block.temperature == 0.05
        inputs = [torch.randn(1, 8, 23, 23)]
        if torch.cuda.is_available():
            head, inputs = to_cuda(head, inputs)
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 23, 23)
45 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_dpt_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import DPTHead
6 |
7 |
def test_dpt_head():
    """DPTHead needs 'multiple_select' input and supports readout types."""

    def build(**extra):
        # Fresh list arguments on every call, as in separate literal calls.
        return DPTHead(
            in_channels=[768, 768, 768, 768],
            channels=4,
            num_classes=19,
            in_index=[0, 1, 2, 3],
            **extra)

    with pytest.raises(AssertionError):
        # input_transform must be 'multiple_select'
        build()

    # Each of the 4 levels is a pair: a (4, 768, 2, 2) map and a (4, 768)
    # tensor; the same inputs are reused for every head below.
    inputs = [[torch.randn(4, 768, 2, 2), torch.randn(4, 768)]
              for _ in range(4)]
    # None keeps the constructor's default readout_type.
    for readout_type in (None, 'add', 'project'):
        extra = {} if readout_type is None else {'readout_type': readout_type}
        head = build(input_transform='multiple_select', **extra)
        output = head(inputs)
        assert output.shape == torch.Size((4, 19, 16, 16))
50 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_ema_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import EMAHead
5 | from .utils import to_cuda
6 |
7 |
def test_emanet_head():
    """EMAHead freezes ``ema_mid_conv`` and keeps the input resolution."""
    head = EMAHead(
        in_channels=4,
        ema_channels=3,
        channels=2,
        num_stages=3,
        num_bases=2,
        num_classes=19)
    assert not any(p.requires_grad for p in head.ema_mid_conv.parameters())
    assert hasattr(head, 'ema_module')
    inputs = [torch.randn(1, 4, 23, 23)]
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    logits = head(inputs)
    assert logits.shape == (1, head.num_classes, 23, 23)
24 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_gc_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import GCHead
5 | from .utils import to_cuda
6 |
7 |
def test_gc_head():
    """GCHead builds two convs and a ``gc_block``; output keeps 23x23."""
    head = GCHead(in_channels=4, channels=4, num_classes=19)
    assert len(head.convs) == 2
    assert hasattr(head, 'gc_block')
    feats = [torch.randn(1, 4, 23, 23)]
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    logits = head(feats)
    assert logits.shape == (1, head.num_classes, 23, 23)
17 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_ham_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import LightHamHead
5 | from .utils import _conv_has_norm, to_cuda
6 |
# GroupNorm (32 groups, trainable) config used by the LightHamHead test.
ham_norm_cfg = {'type': 'GN', 'num_groups': 32, 'requires_grad': True}
8 |
9 |
def test_ham_head():
    """LightHamHead uses levels 1-3 and predicts at level 1's resolution."""
    ham_kwargs = dict(
        MD_S=1,
        MD_R=64,
        train_steps=6,
        eval_steps=7,
        inv_t=100,
        rand_init=True)
    loss_cfg = dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    # test without sync_bn
    head = LightHamHead(
        in_channels=[16, 32, 64],
        in_index=[1, 2, 3],
        channels=64,
        ham_channels=64,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=ham_norm_cfg,
        align_corners=False,
        loss_decode=loss_cfg,
        ham_kwargs=ham_kwargs)
    assert not _conv_has_norm(head, sync_bn=False)

    # One map per level; level 0 (8 channels) is not selected by in_index.
    inputs = [
        torch.randn(1, channels, side, side)
        for channels, side in ((8, 32), (16, 16), (32, 8), (64, 4))
    ]
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    assert head.in_channels == [16, 32, 64]
    assert head.hamburger.ham_in.in_channels == 64
    logits = head(inputs)
    assert logits.shape == (1, head.num_classes, 16, 16)
45 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_isa_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import ISAHead
5 | from .utils import to_cuda
6 |
7 |
def test_isa_head():
    """ISAHead keeps the input resolution for a non-divisible 23x23 map."""
    head = ISAHead(
        in_channels=8,
        channels=4,
        num_classes=19,
        isa_channels=4,
        down_factor=(8, 8))
    feats = [torch.randn(1, 8, 23, 23)]
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    logits = head(feats)
    assert logits.shape == (1, head.num_classes, 23, 23)
21 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_maskformer_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from os.path import dirname, join
3 |
4 | import torch
5 | from mmengine import Config
6 | from mmengine.registry import init_default_scope
7 | from mmengine.structures import PixelData
8 |
9 | from mmseg.registry import MODELS
10 | from mmseg.structures import SegDataSample
11 |
12 |
def test_maskformer_head():
    """Build the MaskFormer decode head from config; check predict and loss."""
    init_default_scope('mmseg')
    repo_dpath = dirname(dirname(__file__))
    cfg = Config.fromfile(
        join(
            repo_dpath,
            '../../configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py'  # noqa
        ))
    cfg.model.train_cfg = None
    decode_head = MODELS.build(cfg.model.decode_head)
    inputs = (torch.randn(1, 256, 32, 32), torch.randn(1, 512, 16, 16),
              torch.randn(1, 1024, 8, 8), torch.randn(1, 2048, 4, 4))
    # test inference
    batch_img_metas = [
        dict(
            scale_factor=(1.0, 1.0),
            img_shape=(512, 683),
            ori_shape=(512, 683))
    ]
    test_cfg = dict(mode='whole')
    output = decode_head.predict(inputs, batch_img_metas, test_cfg)
    assert output.shape == (1, 150, 512, 683)

    # test training
    inputs = (torch.randn(2, 256, 32, 32), torch.randn(2, 512, 16, 16),
              torch.randn(2, 1024, 8, 8), torch.randn(2, 2048, 4, 4))
    batch_data_samples = []
    img_meta = {
        'img_shape': (512, 512),
        'ori_shape': (480, 640),
        'pad_shape': (512, 512),
        'scale_factor': (1.425, 1.425),
    }
    for _ in range(2):
        data_sample = SegDataSample(
            gt_sem_seg=PixelData(data=torch.ones(512, 512).long()))
        data_sample.set_metainfo(img_meta)
        batch_data_samples.append(data_sample)
    train_cfg = {}
    losses = decode_head.loss(inputs, batch_data_samples, train_cfg)
    # Check each key explicitly: ``assert (... for ...)`` asserts a generator
    # object, which is always truthy and so could never fail.
    for loss_name in ('loss_cls', 'loss_mask', 'loss_dice'):
        assert loss_name in losses, loss_name
55 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_nl_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import NLHead
5 | from .utils import to_cuda
6 |
7 |
def test_nl_head():
    """NLHead builds two convs and an ``nl_block``; output keeps 23x23."""
    head = NLHead(in_channels=8, channels=4, num_classes=19)
    assert len(head.convs) == 2
    assert hasattr(head, 'nl_block')
    feats = [torch.randn(1, 8, 23, 23)]
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    logits = head(feats)
    assert logits.shape == (1, head.num_classes, 23, 23)
17 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_ocr_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import FCNHead, OCRHead
5 | from .utils import to_cuda
6 |
7 |
def test_ocr_head():
    """Run an OCRHead on top of an FCNHead prediction and check the shape."""
    inputs = [torch.randn(1, 8, 23, 23)]
    ocr_head = OCRHead(
        in_channels=8, channels=4, num_classes=19, ocr_channels=8)
    fcn_head = FCNHead(in_channels=8, channels=4, num_classes=19)
    if torch.cuda.is_available():
        # Rebind to the names used below instead of a throwaway variable, so
        # the test does not silently depend on ``cuda()`` acting in place.
        ocr_head, inputs = to_cuda(ocr_head, inputs)
        fcn_head, inputs = to_cuda(fcn_head, inputs)
    prev_output = fcn_head(inputs)
    output = ocr_head(inputs, prev_output)
    assert output.shape == (1, ocr_head.num_classes, 23, 23)
20 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_psp_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import PSPHead
6 | from .utils import _conv_has_norm, to_cuda
7 |
8 |
def test_psp_head():
    """Check PSPHead argument validation, norm config and output shape."""
    # A scalar pool_scales is rejected; it must be a list or tuple.
    with pytest.raises(AssertionError):
        PSPHead(in_channels=4, channels=2, num_classes=19, pool_scales=1)

    # Without norm_cfg the ConvModules carry no norm layer.
    plain_head = PSPHead(in_channels=4, channels=2, num_classes=19)
    assert not _conv_has_norm(plain_head, sync_bn=False)

    # With a SyncBN norm_cfg every ConvModule uses SyncBN.
    sync_head = PSPHead(
        in_channels=4,
        channels=2,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(sync_head, sync_bn=True)

    scales = (1, 2, 3)
    feats = [torch.randn(1, 4, 23, 23)]
    head = PSPHead(
        in_channels=4, channels=2, num_classes=19, pool_scales=scales)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    for idx, scale in enumerate(scales):
        assert head.psp_modules[idx][0].output_size == scale
    out = head(feats)
    assert out.shape == (1, head.num_classes, 23, 23)
37 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_segformer_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import SegformerHead
6 |
7 |
def test_segformer_head():
    """Check SegformerHead argument validation and output resolution."""
    # in_channels and in_index must have the same length.
    with pytest.raises(AssertionError):
        SegformerHead(
            in_channels=(1, 2, 3), in_index=(0, 1), channels=5, num_classes=2)

    H, W = (64, 64)
    in_channels = (32, 64, 160, 256)
    # Levels at 1/4, 1/8, 1/16, 1/32: (16, 16), (8, 8), (4, 4), (2, 2).
    shapes = [(H // 2**(i + 2), W // 2**(i + 2))
              for i in range(len(in_channels))]
    model = SegformerHead(
        in_channels=in_channels,
        in_index=[0, 1, 2, 3],
        channels=256,
        num_classes=19)

    def make_inputs():
        return [
            torch.randn((1, c, *hw)) for c, hw in zip(in_channels, shapes)
        ]

    # Fewer feature maps than in_index expects must fail.
    with pytest.raises(IndexError):
        model(make_inputs()[:3])

    out = model(make_inputs())
    assert out.shape == (1, 19, H // 4, W // 4)
41 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_segmenter_mask_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models.decode_heads import SegmenterMaskTransformerHead
5 | from .utils import _conv_has_norm, to_cuda
6 |
7 |
def test_segmenter_mask_transformer_head():
    """Build a SegmenterMaskTransformerHead and check its output shape."""
    head_kwargs = dict(
        in_channels=2,
        channels=2,
        num_classes=150,
        num_layers=2,
        num_heads=3,
        embed_dims=192,
        dropout_ratio=0.0)
    head = SegmenterMaskTransformerHead(**head_kwargs)
    # NOTE(review): no norm_cfg is given, so this may pass only because the
    # head contains no ConvModule (the helper then returns True) — confirm.
    assert _conv_has_norm(head, sync_bn=True)
    head.init_weights()

    feats = [torch.randn(1, 2, 32, 32)]
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert head(feats).shape == (1, head.num_classes, 32, 32)
25 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_setr_mla_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import SETRMLAHead
6 | from .utils import to_cuda
7 |
8 |
def test_setr_mla_head(capsys):
    """Check SETRMLAHead argument validation and output shapes."""
    # MLA needs features from several stages, so a single index fails.
    with pytest.raises(AssertionError):
        SETRMLAHead(in_channels=8, channels=4, num_classes=19, in_index=1)

    # Several in_index entries require the same number of in_channels.
    with pytest.raises(AssertionError):
        SETRMLAHead(
            in_channels=8, channels=4, num_classes=19, in_index=(0, 1, 2, 3))

    # channels must equal len(in_channels) * mla_channels (16 here, not 8).
    with pytest.raises(AssertionError):
        SETRMLAHead(
            in_channels=(8, 8, 8, 8),
            channels=8,
            mla_channels=4,
            in_index=(0, 1, 2, 3),
            num_classes=19)

    img_size = (8, 8)
    patch_size = 4
    head = SETRMLAHead(
        in_channels=(8, 8, 8, 8),
        channels=16,
        mla_channels=4,
        in_index=(0, 1, 2, 3),
        num_classes=19,
        norm_cfg=dict(type='BN'))
    h, w = img_size[0] // patch_size, img_size[1] // patch_size

    # Square (width_mult=1) and non-square (width_mult=2) NCHW inputs; the
    # output is expected at 4x the input resolution.
    for width_mult in (1, 2):
        feats = [torch.randn(1, 8, h, w * width_mult) for _ in range(4)]
        if torch.cuda.is_available():
            head, feats = to_cuda(head, feats)
        out = head(feats)
        assert out.shape == (1, head.num_classes, h * 4, w * width_mult * 4)
64 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_setr_up_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import SETRUPHead
6 | from .utils import to_cuda
7 |
8 |
def test_setr_up_head(capsys):
    """Check SETRUPHead argument validation, weight init and output shapes."""
    # Only kernel sizes 1 and 3 are accepted.
    with pytest.raises(AssertionError):
        SETRUPHead(num_classes=19, kernel_size=2)

    # in_channels must be a single int matching the embedding dim.
    with pytest.raises(AssertionError):
        SETRUPHead(in_channels=(4, 4), channels=2, num_classes=19)

    # Exercise the parent-class init_weights with a Kaiming init_cfg.
    head = SETRUPHead(
        in_channels=4,
        channels=2,
        norm_cfg=dict(type='SyncBN'),
        num_classes=19,
        init_cfg=dict(type='Kaiming'))
    super(SETRUPHead, head).init_weights()

    # Naive head (its auxiliary head has the same structure).
    img_size = (4, 4)
    patch_size = 2
    head = SETRUPHead(
        in_channels=4,
        channels=2,
        num_classes=19,
        num_convs=1,
        up_scale=4,
        kernel_size=1,
        norm_cfg=dict(type='BN'))
    h, w = img_size[0] // patch_size, img_size[1] // patch_size

    # Square (width_mult=1) and non-square (width_mult=2) NCHW inputs.
    for width_mult in (1, 2):
        feats = [torch.randn(1, 4, h, w * width_mult)]
        if torch.cuda.is_available():
            head, feats = to_cuda(head, feats)
        out = head(feats)
        assert out.shape == (1, head.num_classes, h * 4, w * width_mult * 4)
57 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/test_uper_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.decode_heads import UPerHead
6 | from .utils import _conv_has_norm, to_cuda
7 |
8 |
def test_uper_head():
    """Check UPerHead argument validation, norm config and output shape."""
    # A scalar in_channels is rejected; it must be a list or tuple.
    with pytest.raises(AssertionError):
        UPerHead(in_channels=4, channels=2, num_classes=19)

    def build(**extra):
        return UPerHead(
            in_channels=[4, 2],
            channels=2,
            num_classes=19,
            in_index=[-2, -1],
            **extra)

    assert not _conv_has_norm(build(), sync_bn=False)
    assert _conv_has_norm(build(norm_cfg=dict(type='SyncBN')), sync_bn=True)

    feats = [torch.randn(1, 4, 45, 45), torch.randn(1, 2, 21, 21)]
    head = build()
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert head(feats).shape == (1, head.num_classes, 45, 45)
36 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_heads/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.cnn import ConvModule
3 | from mmengine.utils.dl_utils.parrots_wrapper import SyncBatchNorm
4 |
5 |
def _conv_has_norm(module, sync_bn):
    """Return whether every ``ConvModule`` inside ``module`` has a norm layer.

    Args:
        module (nn.Module): Module whose sub-modules (itself included) are
            scanned.
        sync_bn (bool): If True, additionally require each ConvModule's norm
            layer to be a ``SyncBatchNorm``.

    Returns:
        bool: False as soon as a ConvModule without a norm layer (or, when
        ``sync_bn`` is set, without a SyncBatchNorm) is found; True
        otherwise, including when ``module`` contains no ConvModule at all.
    """
    for m in module.modules():
        if not isinstance(m, ConvModule):
            continue
        if not m.with_norm:
            return False
        # Use ConvModule's generic ``norm`` property instead of ``m.bn``: the
        # attribute name depends on the norm type (e.g. ``gn`` for
        # GroupNorm), so ``m.bn`` raised AttributeError for non-BN norms
        # instead of returning False.
        if sync_bn and not isinstance(m.norm, SyncBatchNorm):
            return False
    return True
15 |
16 |
def to_cuda(module, data):
    """Move ``module`` and the items of ``data`` to CUDA via ``.cuda()``.

    Args:
        module: Object with a ``cuda()`` method (typically an ``nn.Module``).
        data: A list or tuple of objects with a ``cuda()`` method. A list is
            updated in place; a tuple is rebuilt as a new tuple. Any other
            value is returned unchanged.

    Returns:
        tuple: ``(module.cuda(), data)`` with the items of ``data`` moved.
    """
    module = module.cuda()
    if isinstance(data, list):
        # Update in place so callers holding the same list see the change.
        for i, item in enumerate(data):
            data[i] = item.cuda()
    elif isinstance(data, tuple):
        # Tuples are immutable; without this branch they stayed on the CPU
        # and later caused device-mismatch errors.
        data = tuple(item.cuda() for item in data)
    return module, data
23 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_losses/test_huasdorff_distance_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.losses import HuasdorffDisstanceLoss
6 |
7 |
def test_huasdorff_distance_loss():
    """Smoke-test HuasdorffDisstanceLoss forward under several options."""
    loss_cls = HuasdorffDisstanceLoss
    pred = torch.rand((10, 8, 6, 6))
    target = torch.rand((10, 6, 6))
    class_weight = torch.rand(8)

    # Plain forward, then forward with avg_factor.
    assert isinstance(loss_cls()(pred, target), torch.Tensor)
    assert isinstance(loss_cls()(pred, target, avg_factor=10), torch.Tensor)

    # avg_factor combined with each reduction override.
    for reduction in (None, 'sum', 'mean'):
        out = loss_cls()(pred, target, avg_factor=10, reduction=reduction)
        assert isinstance(out, torch.Tensor)

    # Passing this class_weight is expected to raise AssertionError.
    with pytest.raises(AssertionError):
        loss_cls(class_weight=class_weight)(pred, target)
30 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_necks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_necks/test_feature2pyramid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models import Feature2Pyramid
6 |
7 |
def test_feature2pyramid():
    """Feature2Pyramid rescales same-size inputs into a feature pyramid."""
    embed_dim = 64
    # (rescales, expected square side of each 32x32 input after rescaling)
    cases = (
        ([4, 2, 1, 0.5], (128, 64, 32, 16)),
        ([2, 1, 0.5, 0.25], (64, 32, 16, 8)),
    )
    for rescales, sides in cases:
        inputs = [torch.randn(1, embed_dim, 32, 32) for _ in rescales]
        fpn = Feature2Pyramid(
            embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
        outputs = fpn(inputs)
        for idx, side in enumerate(sides):
            assert outputs[idx].shape == torch.Size([1, 64, side, side])

    # A rescale factor of 0 is unsupported and raises KeyError.
    with pytest.raises(KeyError):
        Feature2Pyramid(
            embed_dim, [4, 2, 0.25, 0],
            norm_cfg=dict(type='BN', requires_grad=True))
39 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_necks/test_fpn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models import FPN
5 |
6 |
def test_fpn():
    """FPN keeps per-level resolution with default and nearest upsampling."""
    in_channels = [64, 128, 256, 512]
    sides = [56 // 2**lvl for lvl in range(len(in_channels))]
    inputs = [torch.randn(1, c, s, s) for c, s in zip(in_channels, sides)]

    variants = (
        {},
        dict(upsample_cfg=dict(mode='nearest', scale_factor=2.0)),
    )
    for extra in variants:
        fpn = FPN(in_channels, 64, len(in_channels), **extra)
        outputs = fpn(inputs)
        for lvl, side in enumerate(sides):
            assert outputs[lvl].shape == torch.Size([1, 64, side, side])
31 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_necks/test_ic_neck.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.necks import ICNeck
6 | from mmseg.models.necks.ic_neck import CascadeFeatureFusion
7 | from ..test_heads.utils import _conv_has_norm, to_cuda
8 |
9 |
def test_ic_neck():
    """Check ICNeck's norm config and the shapes of its three outputs."""
    # With a SyncBN norm_cfg every ConvModule uses SyncBN.
    neck = ICNeck(
        in_channels=(4, 16, 16),
        out_channels=8,
        norm_cfg=dict(type='SyncBN'),
        align_corners=False)
    assert _conv_has_norm(neck, sync_bn=True)

    inputs = [
        torch.randn(1, 4, 32, 64),
        torch.randn(1, 16, 16, 32),
        torch.randn(1, 16, 8, 16)
    ]
    neck = ICNeck(
        in_channels=(4, 16, 16),
        out_channels=4,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False)
    if torch.cuda.is_available():
        neck, inputs = to_cuda(neck, inputs)

    outputs = neck(inputs)
    assert outputs[0].shape == (1, 4, 16, 32)
    assert outputs[1].shape == (1, 4, 32, 64)
    # The third output was never checked because outputs[1] was asserted
    # twice. It is presumably the fused output at the resolution of the
    # first (largest) input — confirm against ICNeck.forward.
    assert outputs[2].shape == (1, 4, 32, 64)
36 |
37 |
def test_ic_neck_cascade_feature_fusion():
    """Both CascadeFeatureFusion convs map 64 input channels to 32 outputs."""
    cff = CascadeFeatureFusion(64, 64, 32)
    for conv in (cff.conv_low, cff.conv_high):
        assert conv.in_channels == 64
        assert conv.out_channels == 32
44 |
45 |
def test_ic_neck_input_channels():
    """ICNeck rejects a four-element in_channels (valid configs use three)."""
    bad_channels = (16, 64, 64, 64)
    with pytest.raises(AssertionError):
        ICNeck(
            in_channels=bad_channels,
            out_channels=32,
            norm_cfg=dict(type='BN', requires_grad=True),
            align_corners=False)
54 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_necks/test_jpu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmseg.models.necks import JPU
6 |
7 |
def test_fastfcn_neck():
    """Check JPU (FastFCN neck) outputs for default and shifted start_level."""
    batch_size = 1

    def make_feats():
        # ``feats`` rather than ``input`` to avoid shadowing the builtin.
        return [
            torch.randn(batch_size, 64, 64, 128),
            torch.randn(batch_size, 128, 32, 64),
            torch.randn(batch_size, 256, 16, 32)
        ]

    # Standard forward in training mode.
    jpu = JPU(
        in_channels=(64, 128, 256),
        mid_channels=64,
        start_level=0,
        end_level=-1,
        dilations=(1, 2, 4, 8),
    )
    jpu.init_weights()
    jpu.train()
    outs = jpu(make_feats())
    assert len(outs) == 3
    assert outs[0].shape == torch.Size([batch_size, 64, 64, 128])
    assert outs[1].shape == torch.Size([batch_size, 128, 32, 64])
    assert outs[2].shape == torch.Size([batch_size, 256, 64, 128])

    # Inconsistent in_channels / end_level must be rejected.
    with pytest.raises(AssertionError):
        JPU(in_channels=(256, 64, 128), start_level=0, end_level=5)

    # Non-default start_level: the first input is not part of the output.
    jpu = JPU(in_channels=(64, 128, 256), start_level=1, end_level=-1)
    outs = jpu(make_feats())
    assert len(outs) == 2
    assert outs[0].shape == torch.Size([batch_size, 128, 32, 64])
    assert outs[1].shape == torch.Size([batch_size, 2048, 32, 64])
47 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_necks/test_mla_neck.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models import MLANeck
5 |
6 |
def test_mla():
    """MLANeck maps each 12x12 input level to 32 channels at the same size."""
    in_channels = [4, 4, 4, 4]
    mla = MLANeck(in_channels, 32)
    outputs = mla([torch.randn(1, c, 12, 12) for c in in_channels])
    for idx in range(4):
        assert outputs[idx].shape == torch.Size([1, 32, 12, 12])
17 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_necks/test_multilevel_neck.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmseg.models import MultiLevelNeck
5 |
6 |
def test_multilevel_neck():
    """MultiLevelNeck turns 14x14 inputs into four 32-channel levels."""
    # init_weights must run on a single-level configuration.
    MultiLevelNeck([266], 32).init_weights()

    # Expected output sides: 0.5x, 1x, 2x and 4x of the 14x14 inputs.
    expected_sides = (7, 14, 28, 56)

    def check_outputs(outputs):
        for lvl, side in enumerate(expected_sides):
            assert outputs[lvl].shape == torch.Size([1, 32, side, side])

    # Several input feature maps.
    in_channels = [32, 64, 128, 256]
    feats = [torch.randn(1, c, 14, 14) for c in in_channels]
    neck = MultiLevelNeck(in_channels, 32)
    check_outputs(neck(feats))

    # A single input feature map still yields all four levels.
    feats = [torch.randn(1, 768, 14, 14)]
    neck = MultiLevelNeck([768], 32)
    check_outputs(neck(feats))
33 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_segmentors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_segmentors/test_cascade_encoder_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmengine import ConfigDict
3 |
4 | from mmseg.models import build_segmentor
5 | from .utils import _segmentor_forward_train_test
6 |
7 |
def test_cascade_encoder_decoder():
    """Build CascadeEncoderDecoder variants and run forward train/test."""

    def build_cfg(**extra):
        cfg = ConfigDict(
            type='CascadeEncoderDecoder',
            num_stages=2,
            backbone=dict(type='ExampleBackbone'),
            decode_head=[
                dict(type='ExampleDecodeHead'),
                dict(type='ExampleCascadeDecodeHead')
            ],
            **extra)
        cfg.test_cfg = ConfigDict(mode='whole')
        return cfg

    def run(cfg):
        _segmentor_forward_train_test(build_segmentor(cfg))

    # Cascade decode heads without an auxiliary head: whole, then slide mode.
    cfg = build_cfg()
    run(cfg)
    cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
    run(cfg)

    # One auxiliary head.
    run(build_cfg(auxiliary_head=dict(type='ExampleDecodeHead')))

    # Two auxiliary heads.
    run(
        build_cfg(auxiliary_head=[
            dict(type='ExampleDecodeHead'),
            dict(type='ExampleDecodeHead')
        ]))
58 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_segmentors/test_seg_tta_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import tempfile
3 |
4 | import torch
5 | from mmengine import ConfigDict
6 | from mmengine.model import BaseTTAModel
7 | from mmengine.registry import init_default_scope
8 | from mmengine.structures import PixelData
9 |
10 | from mmseg.registry import MODELS
11 | from mmseg.structures import SegDataSample
12 | from .utils import * # noqa: F401,F403
13 |
14 | init_default_scope('mmseg')
15 |
16 |
def test_encoder_decoder_tta():
    """Run SegTTAModel.test_step on a multi-class and a binary segmentor."""
    import os.path as osp

    segmentor_cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        train_cfg=None,
        test_cfg=dict(mode='whole'))

    cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg)

    model: BaseTTAModel = MODELS.build(cfg)

    imgs = []
    data_samples = []
    directions = ['horizontal', 'vertical']
    for i in range(12):
        flip_direction = directions[0] if i % 3 == 0 else directions[1]
        imgs.append(torch.randn(1, 3, 10 + i, 10 + i))
        data_samples.append([
            SegDataSample(
                metainfo=dict(
                    ori_shape=(10, 10),
                    img_shape=(10 + i, 10 + i),
                    flip=(i % 2 == 0),
                    flip_direction=flip_direction,
                    # Only a path string is needed (no file is created).
                    # tempfile.mktemp() is deprecated and race-prone, so
                    # build a unique name under the temp dir instead.
                    img_path=osp.join(tempfile.gettempdir(),
                                      f'tta_test_img_{i}.png')),
                gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10))))
        ])

    model.test_step(dict(inputs=imgs, data_samples=data_samples))

    # Binary head: out_channels == 1 with num_classes=2 and threshold=0.4.
    segmentor_cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(
            type='ExampleDecodeHead',
            num_classes=2,
            out_channels=1,
            threshold=0.4),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    model.module = MODELS.build(segmentor_cfg)
    for data_sample in data_samples:
        # Ground truth must match the binary label space {0, 1}.
        data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10))
    model.test_step(dict(inputs=imgs, data_samples=data_samples))
64 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_models/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_utils/test_io.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | import numpy as np
5 | import pytest
6 | from mmengine import FileClient
7 |
8 | from mmseg.utils import datafrombytes
9 |
10 |
@pytest.mark.parametrize(
    ['backend', 'suffix'],
    [['nifti', '.nii.gz'], ['numpy', '.npy'], ['pickle', '.pkl']])
def test_datafrombytes(backend, suffix):
    """Decode the biomedical fixture file with each supported backend."""
    file_client = FileClient('disk')
    file_path = osp.join(osp.dirname(__file__), '../data/biomedical' + suffix)
    # ``content`` rather than ``bytes`` so the builtin is not shadowed.
    content = file_client.get(file_path)
    data = datafrombytes(content, backend)

    if backend == 'pickle':
        # The pickle fixture decodes to a dict.
        assert isinstance(data, dict)
        return

    assert isinstance(data, np.ndarray)
    if backend == 'nifti':
        # The NIfTI fixture decodes to a 3-D array.
        assert len(data.shape) == 3
    else:
        # biomedical.npy holds data and label, stacked on the first axis.
        assert len(data.shape) == 4
        assert data.shape[0] == 2
33 | assert data.shape[0] == 2
34 |
--------------------------------------------------------------------------------
/Segmentation/tests/test_utils/test_set_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import datetime
3 | import sys
4 | from unittest import TestCase
5 |
6 | from mmengine import DefaultScope
7 |
8 | from mmseg.utils import register_all_modules
9 |
10 |
class TestSetupEnv(TestCase):
    """Tests for ``mmseg.utils.register_all_modules``."""

    @staticmethod
    def _forget_ade20k(registry):
        """Unregister ADE20KDataset and unload the modules that define it.

        Popping the modules from ``sys.modules`` lets a later
        ``register_all_modules`` call re-import them and re-register the
        dataset (which the assertions below check). The ``None`` defaults
        make every pop a no-op when the entry is already absent, instead of
        raising KeyError as the second round of pops used to.
        """
        sys.modules.pop('mmseg.datasets', None)
        sys.modules.pop('mmseg.datasets.ade', None)
        registry._module_dict.pop('ADE20KDataset', None)

    def test_register_all_modules(self):
        from mmseg.registry import DATASETS

        # Without initializing the default scope.
        self._forget_ade20k(DATASETS)
        self.assertNotIn('ADE20KDataset', DATASETS.module_dict)
        register_all_modules(init_default_scope=False)
        self.assertIn('ADE20KDataset', DATASETS.module_dict)

        # Initializing the default scope makes "mmseg" the current scope.
        self._forget_ade20k(DATASETS)
        self.assertNotIn('ADE20KDataset', DATASETS.module_dict)
        register_all_modules(init_default_scope=True)
        self.assertIn('ADE20KDataset', DATASETS.module_dict)
        self.assertEqual(DefaultScope.get_current_instance().scope_name,
                         'mmseg')

        # Initializing while another scope is current only warns.
        name = f'test-{datetime.datetime.now()}'
        DefaultScope.get_instance(name, scope_name='test')
        with self.assertWarnsRegex(
                Warning, 'The current default scope "test" is not "mmseg"'):
            register_all_modules(init_default_scope=True)
40 |
--------------------------------------------------------------------------------
/Segmentation/tools/analysis_tools/utils.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 |
COLOR_RED = "91m"
COLOR_GREEN = "92m"
COLOR_YELLOW = "93m"


def colorful_print(fn_print, color=COLOR_RED):
    """Wrap a print-like callable so that its output appears in an ANSI color.

    Args:
        fn_print: Callable accepting ``print``-style arguments.
        color: ANSI SGR code suffix such as ``"91m"``. It is written right
            after the ``ESC[`` control-sequence introducer.

    Returns:
        A function with the same call signature as ``fn_print``. Each call
        writes the color sequence to stdout, forwards all arguments to
        ``fn_print`` and then writes the reset sequence ``ESC[00m``. The
        reset is written even if ``fn_print`` raises, so an exception does
        not leave the terminal colored. ``fn_print``'s return value is
        passed through.
    """
    def actual_call(*args, **kwargs):
        print(f"\033[{color}", end="")
        try:
            return fn_print(*args, **kwargs)
        finally:
            print("\033[00m", end="")

    return actual_call


# Ready-made colored variants of the builtin ``print``.
prRed = colorful_print(print, color=COLOR_RED)
prGreen = colorful_print(print, color=COLOR_GREEN)
prYellow = colorful_print(print, color=COLOR_YELLOW)
26 |
27 |
def clever_format(nums, format="%.2f"):
    """Format numbers with a T/G/M/K magnitude suffix ("B" below 1000).

    Args:
        nums: A single number or an iterable of numbers.
        format: ``%``-style format string applied to the scaled value.

    Returns:
        A single string when exactly one number was formatted, otherwise a
        tuple of strings (an empty tuple for an empty iterable).

    Example:
        >>> clever_format(1234)
        '1.23K'
        >>> clever_format([2e9, 12])
        ('2.00G', '12.00B')
    """
    # NOTE: ``format`` shadows the builtin but is kept for keyword callers.
    if not isinstance(nums, Iterable):
        nums = [nums]

    units = ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K"))
    clever_nums = []
    for num in nums:
        # Compare the magnitude so negative values are scaled as well, and
        # use >= so exact powers (e.g. 1000) become "1.00K", not "1000.00B".
        for scale, suffix in units:
            if abs(num) >= scale:
                clever_nums.append(format % (num / scale) + suffix)
                break
        else:
            clever_nums.append(format % num + "B")

    return clever_nums[0] if len(clever_nums) == 1 else tuple(clever_nums)
48 |
49 |
if __name__ == "__main__":
    # Visual smoke check of the three colored printers.
    for printer in (prRed, prGreen, prYellow):
        printer("hello", "world")
--------------------------------------------------------------------------------
/Segmentation/tools/dataset_converters/cityscapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 |
5 | from cityscapesscripts.preparation.json2labelImg import json2labelImg
6 | from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
7 | track_progress)
8 |
9 |
def convert_json_to_label(json_file):
    """Convert one Cityscapes polygon JSON file with the 'trainIds' encoding.

    The output path is ``json_file`` with its ``_polygons.json`` part
    replaced by ``_labelTrainIds.png``, so it lands next to the input.
    """
    json2labelImg(
        json_file,
        json_file.replace('_polygons.json', '_labelTrainIds.png'),
        'trainIds')
13 |
14 |
def parse_args():
    """Parse command-line options for the Cityscapes TrainIds converter."""
    parser = argparse.ArgumentParser(
        description='Convert Cityscapes annotations to TrainIds')
    parser.add_argument('cityscapes_path', help='cityscapes data path')
    parser.add_argument('--gt-dir', type=str, default='gtFine')
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', type=int, default=1, help='number of process')
    return parser.parse_args()
25 |
26 |
def main():
    """Convert Cityscapes polygons to TrainId PNGs and write split lists.

    Label PNGs are written next to their source JSON files under
    ``<cityscapes_path>/<gt_dir>`` (see ``convert_json_to_label``).
    ``train.txt``, ``val.txt`` and ``test.txt`` (one sample prefix per line)
    are written to ``--out-dir``, which defaults to ``cityscapes_path``.
    """
    args = parse_args()
    cityscapes_path = args.cityscapes_path
    out_dir = args.out_dir if args.out_dir else cityscapes_path
    mkdir_or_exist(out_dir)

    gt_dir = osp.join(cityscapes_path, args.gt_dir)

    # scandir yields paths relative to the scanned directory, hence the join.
    poly_files = [
        osp.join(gt_dir, poly)
        for poly in scandir(gt_dir, '_polygons.json', recursive=True)
    ]
    if args.nproc > 1:
        track_parallel_progress(convert_json_to_label, poly_files, args.nproc)
    else:
        track_progress(convert_json_to_label, poly_files)

    for split in ('train', 'val', 'test'):
        # NOTE(review): the '_gtFine' part is hard-coded even when --gt-dir
        # names a different annotation set.
        filenames = [
            poly.replace('_gtFine_polygons.json', '')
            for poly in scandir(
                osp.join(gt_dir, split), '_polygons.json', recursive=True)
        ]
        # ``split_file``/``name`` instead of reusing ``f`` for both the file
        # handle and the generator variable.
        with open(osp.join(out_dir, f'{split}.txt'), 'w') as split_file:
            split_file.writelines(name + '\n' for name in filenames)
53 |
54 |
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()
57 |
--------------------------------------------------------------------------------
/Segmentation/tools/dataset_converters/prophesee/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/tools/dataset_converters/prophesee/__init__.py
--------------------------------------------------------------------------------
/Segmentation/tools/dataset_converters/prophesee/io/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BICLab/Spike2Former/f7438600ca82f5ce664d8a8861c26012e7b9ca24/Segmentation/tools/dataset_converters/prophesee/io/__init__.py
--------------------------------------------------------------------------------
/Segmentation/tools/dataset_converters/prophesee/io/box_filtering.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Prophesee S.A.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
7 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 | # See the License for the specific language governing permissions and limitations under the License.
9 |
10 | """
11 | Define the same filtering that we apply in:
12 | "Learning to detect objects on a 1 Megapixel Event Camera" by Etienne Perot et al.
13 |
14 | Namely, we apply 2 different filters:
15 | 1. skip all boxes at or before 0.5s (until then we assume there is not yet enough history)
16 | 2. drop all boxes whose squared diagonal < min_box_diag**2 or whose width or height < min_box_side
17 | """
18 |
19 | from __future__ import print_function
20 | import numpy as np
21 |
22 |
def filter_boxes(boxes, skip_ts=int(5e5), min_box_diag=60, min_box_side=20):
    """Filters boxes according to the paper rule.

    A box is kept only if all of the following hold:

    * its timestamp ``t`` is strictly greater than ``skip_ts``;
    * its squared diagonal ``w**2 + h**2`` is at least ``min_box_diag**2``;
    * both its width ``w`` and height ``h`` are at least ``min_box_side``.

    To note: the default represents our threshold when evaluating GEN4 resolution (1280x720)
    To note: we assume the initial time of the video is always 0

    Args:
        boxes (np.ndarray): structured box array with fields ['t','x','y','w','h','class_id','track_id','class_confidence']
            (example BBOX_DTYPE is provided in src/box_loading.py).
            Only the ``t``, ``w`` and ``h`` fields are read.
        skip_ts (int): boxes with ``t <= skip_ts`` are dropped
            (default ``int(5e5)``).
        min_box_diag (float): minimum allowed box diagonal length.
        min_box_side (float): minimum allowed width and height.

    Returns:
        np.ndarray: the boxes that pass every filter, in their original
        order and with the input dtype.
    """
    ts = boxes['t']
    width = boxes['w']
    height = boxes['h']
    diag_square = width ** 2 + height ** 2
    # Combine the boolean masks with logical AND.
    mask = ((ts > skip_ts)
            & (diag_square >= min_box_diag ** 2)
            & (width >= min_box_side)
            & (height >= min_box_side))
    return boxes[mask]
42 |
43 |
--------------------------------------------------------------------------------
/Segmentation/tools/dataset_converters/prophesee/io/box_loading.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Prophesee S.A.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
7 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 | # See the License for the specific language governing permissions and limitations under the License.
9 |
10 | """
11 | Defines some tools to handle events.
12 | In particular:
13 | -> defines events' types
14 | -> defines functions to read events from binary .dat files using numpy
15 | -> defines functions to write events to binary .dat files using numpy
16 | """
17 |
18 | from __future__ import print_function
19 | import numpy as np
20 |
21 | BBOX_DTYPE = np.dtype({'names':['t','x','y','w','h','class_id','track_id','class_confidence'], 'formats':[' str:
21 | """Compute SHA256 message digest from a file."""
22 | hash_func = sha256()
23 | byte_array = bytearray(BLOCK_SIZE)
24 | memory_view = memoryview(byte_array)
25 | with open(filename, 'rb', buffering=0) as file:
26 | for block in iter(lambda: file.readinto(memory_view), 0):
27 | hash_func.update(memory_view[:block])
28 | return hash_func.hexdigest()
29 |
30 |
def process_checkpoint(in_file, out_file):
    """Strip the optimizer state from a checkpoint and publish it.

    The slimmed checkpoint is saved to ``out_file`` and then renamed to
    ``<out_file without .pth>-<first 8 hex chars of its SHA256>.pth``.

    Args:
        in_file (str): Path of the checkpoint to load.
        out_file (str): Intermediate output path. It no longer exists after
            this call because it is renamed to the hash-suffixed name.

    Returns:
        str: Path of the published (renamed) checkpoint. Returning it is
        backward compatible with callers that ignore the result.
    """
    import os

    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    # Hash the file that is actually published (not ``in_file``, which still
    # contains the optimizer): consumers verify the download against this
    # suffix (e.g. torch.hub ``check_hash``).
    sha = sha256sum(out_file)
    # ``str.rstrip('.pth')`` strips any trailing '.', 'p', 't', 'h'
    # characters (e.g. 'depth.pth' -> 'de'), so remove the suffix explicitly.
    stem = out_file[:-len('.pth')] if out_file.endswith('.pth') else out_file
    final_file = f'{stem}-{sha[:8]}.pth'
    # Rename synchronously and portably instead of a detached ``mv`` process.
    os.replace(out_file, final_file)
    return final_file
42 |
43 |
def main():
    """Command-line entry point: publish ``in_file`` through ``out_file``."""
    cli_args = parse_args()
    src, dst = cli_args.in_file, cli_args.out_file
    process_checkpoint(src, dst)


if __name__ == '__main__':
    main()
51 |
--------------------------------------------------------------------------------
/Segmentation/tools/model_converters/beit2mmseg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 | from collections import OrderedDict
5 |
6 | import mmengine
7 | import torch
8 | from mmengine.runner import CheckpointLoader
9 |
10 |
def convert_beit(ckpt):
    """Rename official BEiT state-dict keys to MMSegmentation naming.

    Each key is mapped once (values are passed through unchanged and the
    input order is preserved):

    * ``patch_embed.*``: ``patch_embed.proj`` -> ``patch_embed.projection``.
    * ``blocks.*``: ``blocks`` -> ``layers``, then the first matching rule of
      ``norm`` -> ``ln``, ``mlp.fc1`` -> ``ffn.layers.0.0``,
      ``mlp.fc2`` -> ``ffn.layers.1``.
    * any other key is copied unchanged.

    Args:
        ckpt (Mapping[str, Any]): Source state dict.

    Returns:
        OrderedDict: State dict with renamed keys.
    """
    new_ckpt = OrderedDict()

    for k, v in ckpt.items():
        if k.startswith('patch_embed'):
            new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
        # ``elif`` so that patch_embed keys are not also copied under their
        # old name by the ``else`` branch below.
        elif k.startswith('blocks'):
            new_key = k.replace('blocks', 'layers')
            if 'norm' in new_key:
                new_key = new_key.replace('norm', 'ln')
            elif 'mlp.fc1' in new_key:
                new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in new_key:
                new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
        else:
            new_key = k
        new_ckpt[new_key] = v

    return new_ckpt
32 |
33 |
def main():
    """Convert an official BEiT checkpoint into an MMSegmentation one.

    Loads ``src`` (path or URL) on CPU, uses its ``state_dict`` or ``model``
    entry when present (otherwise the whole checkpoint), renames the keys
    with ``convert_beit`` and saves the result to ``dst``, calling
    ``mkdir_or_exist`` on its parent directory first.
    """
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained beit models to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    # Use the nested weights when the checkpoint wraps them; otherwise treat
    # the whole checkpoint as the state dict.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    weight = convert_beit(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)


if __name__ == '__main__':
    main()
57 |
--------------------------------------------------------------------------------
/Segmentation/tools/slurm_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch tools/test.py on a Slurm cluster.
#
# Usage:
#   [GPUS=4] [GPUS_PER_NODE=4] [CPUS_PER_TASK=5] [SRUN_ARGS="..."] \
#       tools/slurm_test.sh PARTITION JOB_NAME CONFIG CHECKPOINT [TEST_ARGS...]

set -x

if [ $# -lt 4 ]; then
    echo "Usage: $0 PARTITION JOB_NAME CONFIG CHECKPOINT [TEST_ARGS...]" >&2
    exit 1
fi

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Remaining arguments are forwarded to test.py as "$@", which keeps
# arguments containing spaces intact (an unquoted ${@:5} would re-split them).
shift 4

# SRUN_ARGS is left unquoted on purpose so it can carry several srun options.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p "${PARTITION}" \
    --job-name="${JOB_NAME}" \
    --gres=gpu:"${GPUS_PER_NODE}" \
    --ntasks="${GPUS}" \
    --ntasks-per-node="${GPUS_PER_NODE}" \
    --cpus-per-task="${CPUS_PER_TASK}" \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/test.py "${CONFIG}" "${CHECKPOINT}" --launcher="slurm" "$@"
25 |
--------------------------------------------------------------------------------
/Segmentation/tools/slurm_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch tools/train.py on a Slurm cluster.
#
# Usage:
#   [GPUS=4] [GPUS_PER_NODE=4] [CPUS_PER_TASK=5] [SRUN_ARGS="..."] \
#       tools/slurm_train.sh PARTITION JOB_NAME CONFIG [TRAIN_ARGS...]

set -x

if [ $# -lt 3 ]; then
    echo "Usage: $0 PARTITION JOB_NAME CONFIG [TRAIN_ARGS...]" >&2
    exit 1
fi

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Remaining arguments are forwarded to train.py as "$@", which keeps
# arguments containing spaces intact (an unquoted ${@:4} would re-split them).
shift 3

# SRUN_ARGS is left unquoted on purpose so it can carry several srun options.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p "${PARTITION}" \
    --job-name="${JOB_NAME}" \
    --gres=gpu:"${GPUS_PER_NODE}" \
    --ntasks="${GPUS}" \
    --ntasks-per-node="${GPUS_PER_NODE}" \
    --cpus-per-task="${CPUS_PER_TASK}" \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py "${CONFIG}" --launcher="slurm" "$@"
24 |
--------------------------------------------------------------------------------
/Segmentation/tools/test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate a checkpoint with tools/test.py through torch.distributed.launch.
# CONFIG, CHECKPOINT, GPUS and PORT can be overridden from the environment;
# all command-line arguments are forwarded to test.py.
#
# Other configs / checkpoints used previously:
# CONFIG=../configs/Spike2Former/SDTv2_maskformer_DCNpixelDecoder_ade20k.py
# CONFIG=../configs/Spike2Former/SDTv2_Spike2former_voc_512x512.py
# CHECKPOINT='/public/liguoqi/lzx/code/mmseg/tools/work_dirs/SDTv2_maskformer_DCNpixelDecoder_ade20k/best_mIoU_iter_102500.pth' # ADE20k Train 46.3 Test 44.5 ()
# CHECKPOINT='/public/liguoqi/lzx/code/mmseg/tools/work_dirs/Ablation/v2_spike2former_voc2012_1x4/best_mIoU_iter_97500.pth' # VOC2012 Train 76.3 Test (Q_IFNode)
# CHECKPOINT='/public/liguoqi/lzx/code/mmseg/tools/work_dirs/V2_Spike2former_withoutshortcut/best_mIoU_iter_152500.pth' # ADE20k 1x4 without shortcut

# Resolve the default config relative to this script, not the current dir.
CONFIG=${CONFIG:-"$(dirname "$0")/../configs/Spike2Former/SDTv2_maskformer_DCNPixelDecoder_CityScapes.py"}
CHECKPOINT=${CHECKPOINT:-'/public/liguoqi/lzx/code/mmseg/tools/work_dirs/SDTv2_maskformer_DCNPixelDecoder_CityScapes/best_mIoU_iter_47500.pth'} # CityScapes 74.2 (Multi-Spikenorm)

GPUS=${GPUS:-1}

PORT=${PORT:-29500}


PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nproc_per_node="$GPUS" \
    --master_port="$PORT" \
    "$(dirname "$0")/test.py" \
    "$CONFIG" \
    "$CHECKPOINT" \
    --launcher pytorch \
    "$@" \
    --out ./work_dir/vis_result/CityScapes \
    --show-dir ./work_dirs/vis_result/CityScapes/1x4
25 |
--------------------------------------------------------------------------------