├── README.md
├── assets
└── concept.jpg
├── classification_distill
├── .DS_Store
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   ├── celebahq_bs32.py
│   │   │   ├── celebahq_bs32_strongweak.py
│   │   │   ├── celebahq_bs64.py
│   │   │   ├── celebahq_cls_bs64.py
│   │   │   ├── cifar10_bs128.py
│   │   │   ├── cifar10_bs128_2strong.py
│   │   │   ├── cifar10_bs128_2task.py
│   │   │   ├── cifar10_bs128_weakstrong.py
│   │   │   ├── cifar10_bs128_weakstrong2.py
│   │   │   ├── dsprite.py
│   │   │   ├── imagenet_bs256.py
│   │   │   ├── imagenet_bs256_randaug.py
│   │   │   ├── imagenet_bs32.py
│   │   │   ├── imagenet_bs32_randaug.py
│   │   │   ├── imagenet_bs64.py
│   │   │   ├── imagenet_bs64_randaug.py
│   │   │   ├── pipelines
│   │   │   │   ├── rand_aug.py
│   │   │   │   └── rand_aug_cifar.py
│   │   │   ├── shape3d.py
│   │   │   └── tinyimagenet_bs128.py
│   │   ├── default_runtime.py
│   │   ├── models
│   │   │   ├── mobilenet_v2_1x.py
│   │   │   ├── resnet18.py
│   │   │   ├── resnet18_cifar.py
│   │   │   ├── resnet18_dsprite.py
│   │   │   ├── resnet18_shape3d.py
│   │   │   ├── resnet50.py
│   │   │   ├── shufflenet_v2_1x.py
│   │   │   ├── wide-resnet28-10.py
│   │   │   └── wide-resnet28-2.py
│   │   └── schedules
│   │   │   ├── cifar10_bs128.py
│   │   │   ├── dsprite_bs128.py
│   │   │   ├── imagenet_bs1024_adamw_swin.py
│   │   │   ├── imagenet_bs1024_coslr.py
│   │   │   ├── imagenet_bs1024_linearlr_bn_nowd.py
│   │   │   ├── imagenet_bs2048.py
│   │   │   ├── imagenet_bs2048_AdamW.py
│   │   │   ├── imagenet_bs2048_coslr.py
│   │   │   ├── imagenet_bs256.py
│   │   │   ├── imagenet_bs256_140e.py
│   │   │   ├── imagenet_bs256_coslr.py
│   │   │   ├── imagenet_bs256_coslr_300e.py
│   │   │   ├── imagenet_bs256_coslr_mobilenetv2.py
│   │   │   ├── imagenet_bs256_epochstep.py
│   │   │   ├── imagenet_bs4096_AdamW.py
│   │   │   └── shape3d_bs128.py
│   ├── baseline
│   │   ├── mbnv2_b128x1_cifar10_finetune.py
│   │   ├── mbnv2_b128x1_tinyimagenet_finetune.py
│   │   ├── resnet18_b128x1_cifar10_finetune.py
│   │   └── resnet18_b128x1_tinyimagenet_finetune.py
│   ├── celebahq-repfusion
│   │   ├── ddpm_hrnet18_b32x2_celebahq_pretrain_maxtime1_clean.py
│   │   ├── ddpm_resnet18_b16x4_celebahq_pretrain_maxtime1_clean.py
│   │   └── ddpm_resnet50_b8x8_celebahq_pretrain_maxtime1_clean.py
│   ├── cifar10-at
│   │   ├── ddpm-mbnv2_at_cifar10.py
│   │   └── ddpm-r18_at_cifar10.py
│   ├── cifar10-hint
│   │   ├── ddpm-mbnv2_hint_cifar10.py
│   │   └── ddpm-r18_hint_cifar10.py
│   ├── cifar10-rkd
│   │   ├── ddpm-mbnv2_rkd_cifar10.py
│   │   └── ddpm-r18_rkd_cifar10.py
│   ├── imagenet_hint
│   │   └── ddpm-r18_hint_imagenet.py
│   ├── tinyimagenet_at
│   │   ├── ddpm-mbnv2_at_b32x2_tinyimagenet.py
│   │   └── ddpm-r18_at_b32x2_tinyimagenet.py
│   ├── tinyimagenet_hint
│   │   ├── ddpm-mbnv2_hint_tinyimagenet.py
│   │   └── ddpm-r18_hint_tinyimagenet.py
│   └── tinyimagenet_rkd
│   │   ├── ddpm-mbnv2_rkd_tinyimagenet.py
│   │   └── ddpm-r18_rkd_tinyimagenet.py
├── mmcls
│   ├── .DS_Store
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── inference.py
│   │   ├── multitask_test.py
│   │   ├── test.py
│   │   └── train.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── eval_hooks.py
│   │   │   ├── eval_metrics.py
│   │   │   ├── mean_ap.py
│   │   │   ├── multilabel_eval_metrics.py
│   │   │   └── multitask_eval_hooks.py
│   │   ├── export
│   │   │   ├── __init__.py
│   │   │   └── test.py
│   │   ├── fp16
│   │   │   ├── __init__.py
│   │   │   ├── decorators.py
│   │   │   ├── hooks.py
│   │   │   └── utils.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── dist_utils.py
│   │   │   ├── kd_hook.py
│   │   │   ├── misc.py
│   │   │   ├── timestep_decay.py
│   │   │   └── visualize.py
│   │   └── visualization
│   │   │   ├── __init__.py
│   │   │   └── image.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── builder.py
│   │   ├── celeba.py
│   │   ├── cifar.py
│   │   ├── custom.py
│   │   ├── dataset_wrappers.py
│   │   ├── disentangle_data
│   │   │   ├── __init__.py
│   │   │   ├── dsprites.py
│   │   │   ├── mpi3d.py
│   │   │   ├── multi_task.py
│   │   │   └── shape3d.py
│   │   ├── imagenet.py
│   │   ├── images.py
│   │   ├── multi_label.py
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   ├── auto_augment.py
│   │   │   ├── compose.py
│   │   │   ├── formating.py
│   │   │   ├── loading.py
│   │   │   ├── multiview.py
│   │   │   └── transforms.py
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   └── distributed_sampler.py
│   │   └── utils.py
│   ├── models
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── base_backbone.py
│   │   │   ├── disentangle.py
│   │   │   ├── hrnet.py
│   │   │   ├── mobilenet_v2.py
│   │   │   ├── mobilenet_v2_cifar.py
│   │   │   ├── resnet.py
│   │   │   ├── resnet_cifar.py
│   │   │   ├── shufflenet_v2.py
│   │   │   ├── tsn.py
│   │   │   └── wideresnet.py
│   │   ├── builder.py
│   │   ├── classifiers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── image.py
│   │   │   ├── kd.py
│   │   │   ├── repfusion.py
│   │   │   └── utils.py
│   │   ├── guided_diffusion
│   │   │   ├── __init__.py
│   │   │   ├── fp16_util.py
│   │   │   ├── gaussian_diffusion.py
│   │   │   ├── load_model.py
│   │   │   ├── logger.py
│   │   │   ├── losses.py
│   │   │   ├── nn.py
│   │   │   ├── respace.py
│   │   │   ├── script_util.py
│   │   │   └── unet.py
│   │   ├── heads
│   │   │   ├── __init__.py
│   │   │   ├── base_head.py
│   │   │   ├── cls_head.py
│   │   │   ├── linear_head.py
│   │   │   ├── multi_label_head.py
│   │   │   └── multitask_linear_head.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── accuracy.py
│   │   │   ├── afd.py
│   │   │   ├── asymmetric_loss.py
│   │   │   ├── at.py
│   │   │   ├── crd.py
│   │   │   ├── crd_loss
│   │   │   │   ├── __init__.py
│   │   │   │   ├── criterion.py
│   │   │   │   └── memory.py
│   │   │   ├── cross_entropy_loss.py
│   │   │   ├── focal_loss.py
│   │   │   ├── kd_loss.py
│   │   │   ├── label_smooth_loss.py
│   │   │   ├── norm_l2.py
│   │   │   ├── rkd.py
│   │   │   └── utils.py
│   │   ├── necks
│   │   │   ├── __init__.py
│   │   │   ├── gap.py
│   │   │   └── hr_fuse.py
│   │   └── utils
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── augment
│   │   │   ├── __init__.py
│   │   │   ├── augments.py
│   │   │   ├── builder.py
│   │   │   ├── cutmix.py
│   │   │   ├── identity.py
│   │   │   └── mixup.py
│   │   │   ├── channel_shuffle.py
│   │   │   ├── embed.py
│   │   │   ├── helpers.py
│   │   │   ├── inverted_residual.py
│   │   │   ├── make_divisible.py
│   │   │   └── se_layer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── collect_env.py
│   │   └── logger.py
│   └── version.py
└── tools
│   ├── analysis_tools
│   ├── analysis_para.py
│   ├── analyze_logs.py
│   ├── analyze_results.py
│   ├── eval_metric.py
│   ├── get_flops.py
│   └── run.sh
│   ├── convert_models
│   ├── mobilenetv2_to_mmcls.py
│   ├── publish_model.py
│   ├── shufflenetv2_to_mmcls.py
│   └── vgg_to_mmcls.py
│   ├── deployment
│   ├── mmcls2torchserve.py
│   ├── mmcls_handler.py
│   ├── onnx2tensorrt.py
│   ├── pytorch2onnx.py
│   ├── pytorch2torchscript.py
│   └── test.py
│   ├── dist_test.sh
│   ├── dist_train.py
│   ├── dist_train.sh
│   ├── find_duplicate.py
│   ├── load_ckp.py
│   ├── misc
│   ├── print_config.py
│   └── print_model.py
│   ├── plot
│   ├── CIFAR10_S.npy
│   ├── CIFAR10_rank.npy
│   ├── mnist_S.npy
│   ├── mnist_rank.npy
│   ├── plot_CIFAR10.py
│   └── plot_mnist.py
│   ├── slurm_test.sh
│   ├── slurm_train.sh
│   ├── test.py
│   └── train.py
├── landmark
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   ├── 300w.py
│   │   │   ├── aflw.py
│   │   │   ├── aic.py
│   │   │   ├── animalpose.py
│   │   │   ├── ap10k.py
│   │   │   ├── atrw.py
│   │   │   ├── campus.py
│   │   │   ├── coco.py
│   │   │   ├── coco_wholebody.py
│   │   │   ├── coco_wholebody_face.py
│   │   │   ├── coco_wholebody_hand.py
│   │   │   ├── cofw.py
│   │   │   ├── crowdpose.py
│   │   │   ├── deepfashion2.py
│   │   │   ├── deepfashion_full.py
│   │   │   ├── deepfashion_lower.py
│   │   │   ├── deepfashion_upper.py
│   │   │   ├── fly.py
│   │   │   ├── freihand2d.py
│   │   │   ├── h36m.py
│   │   │   ├── halpe.py
│   │   │   ├── horse10.py
│   │   │   ├── interhand2d.py
│   │   │   ├── interhand3d.py
│   │   │   ├── jhmdb.py
│   │   │   ├── locust.py
│   │   │   ├── macaque.py
│   │   │   ├── mhp.py
│   │   │   ├── mpi_inf_3dhp.py
│   │   │   ├── mpii.py
│   │   │   ├── mpii_trb.py
│   │   │   ├── nvgesture.py
│   │   │   ├── ochuman.py
│   │   │   ├── onehand10k.py
│   │   │   ├── panoptic_body3d.py
│   │   │   ├── panoptic_hand2d.py
│   │   │   ├── posetrack18.py
│   │   │   ├── rhd2d.py
│   │   │   ├── shelf.py
│   │   │   ├── wflw.py
│   │   │   └── zebra.py
│   │   ├── default_runtime.py
│   │   └── filters
│   │   │   ├── gaussian.py
│   │   │   ├── one_euro.py
│   │   │   ├── savizky_golay.py
│   │   │   ├── smoothnet_h36m.md
│   │   │   ├── smoothnet_t16_h36m.py
│   │   │   ├── smoothnet_t32_h36m.py
│   │   │   ├── smoothnet_t64_h36m.py
│   │   │   └── smoothnet_t8_h36m.py
│   └── face
│   │   └── 2d_kpt_sview_rgb_img
│   │   └── topdown_heatmap
│   │   └── wflw
│   │   ├── hrnetv2_w18_wflw_256x256_baseline_bs128x2.py
│   │   ├── hrnetv2_w18_wflw_256x256_bs128x2_repfussion.py
│   │   ├── res50_wflw_256x256_baseline_bs64x2.py
│   │   └── res50_wflw_256x256_bs64x2_repfusion.py
├── mmpose
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── inference.py
│   │   ├── inference_3d.py
│   │   ├── inference_tracking.py
│   │   ├── test.py
│   │   ├── train.py
│   │   └── webcam
│   │   │   ├── __init__.py
│   │   │   ├── nodes
│   │   │   ├── __init__.py
│   │   │   ├── base_visualizer_node.py
│   │   │   ├── helper_nodes
│   │   │   │   ├── __init__.py
│   │   │   │   ├── monitor_node.py
│   │   │   │   ├── object_assigner_node.py
│   │   │   │   └── recorder_node.py
│   │   │   ├── model_nodes
│   │   │   │   ├── __init__.py
│   │   │   │   ├── detector_node.py
│   │   │   │   ├── hand_gesture_node.py
│   │   │   │   ├── pose_estimator_node.py
│   │   │   │   └── pose_tracker_node.py
│   │   │   ├── node.py
│   │   │   ├── registry.py
│   │   │   └── visualizer_nodes
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bigeye_effect_node.py
│   │   │   │   ├── notice_board_node.py
│   │   │   │   ├── object_visualizer_node.py
│   │   │   │   └── sunglasses_effect_node.py
│   │   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── buffer.py
│   │   │   ├── event.py
│   │   │   ├── image_capture.py
│   │   │   ├── message.py
│   │   │   ├── misc.py
│   │   │   └── pose.py
│   │   │   └── webcam_executor.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── bbox
│   │   │   ├── __init__.py
│   │   │   └── transforms.py
│   │   ├── camera
│   │   │   ├── __init__.py
│   │   │   ├── camera_base.py
│   │   │   ├── single_camera.py
│   │   │   └── single_camera_torch.py
│   │   ├── distributed_wrapper.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── bottom_up_eval.py
│   │   │   ├── eval_hooks.py
│   │   │   ├── mesh_eval.py
│   │   │   ├── pose3d_eval.py
│   │   │   └── top_down_eval.py
│   │   ├── fp16
│   │   │   ├── __init__.py
│   │   │   ├── decorators.py
│   │   │   ├── hooks.py
│   │   │   └── utils.py
│   │   ├── optimizers
│   │   │   ├── __init__.py
│   │   │   ├── builder.py
│   │   │   └── layer_decay_optimizer_constructor.py
│   │   ├── post_processing
│   │   │   ├── __init__.py
│   │   │   ├── group.py
│   │   │   ├── nms.py
│   │   │   ├── one_euro_filter.py
│   │   │   ├── post_transforms.py
│   │   │   ├── smoother.py
│   │   │   └── temporal_filters
│   │   │   │   ├── __init__.py
│   │   │   │   ├── builder.py
│   │   │   │   ├── filter.py
│   │   │   │   ├── gaussian_filter.py
│   │   │   │   ├── one_euro_filter.py
│   │   │   │   ├── savizky_golay_filter.py
│   │   │   │   └── smoothnet_filter.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── dist_utils.py
│   │   │   ├── model_util_hooks.py
│   │   │   └── regularizations.py
│   │   └── visualization
│   │   │   ├── __init__.py
│   │   │   └── image.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── dataset_info.py
│   │   ├── dataset_wrappers.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── animal
│   │   │   │   ├── __init__.py
│   │   │   │   ├── animal_ap10k_dataset.py
│   │   │   │   ├── animal_atrw_dataset.py
│   │   │   │   ├── animal_base_dataset.py
│   │   │   │   ├── animal_fly_dataset.py
│   │   │   │   ├── animal_horse10_dataset.py
│   │   │   │   ├── animal_locust_dataset.py
│   │   │   │   ├── animal_macaque_dataset.py
│   │   │   │   ├── animal_pose_dataset.py
│   │   │   │   └── animal_zebra_dataset.py
│   │   │   ├── base
│   │   │   │   ├── __init__.py
│   │   │   │   ├── kpt_2d_sview_rgb_img_bottom_up_dataset.py
│   │   │   │   ├── kpt_2d_sview_rgb_img_top_down_dataset.py
│   │   │   │   ├── kpt_2d_sview_rgb_vid_top_down_dataset.py
│   │   │   │   ├── kpt_3d_mview_rgb_img_direct_dataset.py
│   │   │   │   ├── kpt_3d_sview_kpt_2d_dataset.py
│   │   │   │   └── kpt_3d_sview_rgb_img_top_down_dataset.py
│   │   │   ├── body3d
│   │   │   │   ├── __init__.py
│   │   │   │   ├── body3d_base_dataset.py
│   │   │   │   ├── body3d_h36m_dataset.py
│   │   │   │   ├── body3d_mpi_inf_3dhp_dataset.py
│   │   │   │   ├── body3d_mview_direct_campus_dataset.py
│   │   │   │   ├── body3d_mview_direct_panoptic_dataset.py
│   │   │   │   ├── body3d_mview_direct_shelf_dataset.py
│   │   │   │   └── body3d_semi_supervision_dataset.py
│   │   │   ├── bottom_up
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bottom_up_aic.py
│   │   │   │   ├── bottom_up_base_dataset.py
│   │   │   │   ├── bottom_up_coco.py
│   │   │   │   ├── bottom_up_coco_wholebody.py
│   │   │   │   ├── bottom_up_crowdpose.py
│   │   │   │   └── bottom_up_mhp.py
│   │   │   ├── face
│   │   │   │   ├── __init__.py
│   │   │   │   ├── face_300w_dataset.py
│   │   │   │   ├── face_aflw_dataset.py
│   │   │   │   ├── face_base_dataset.py
│   │   │   │   ├── face_coco_wholebody_dataset.py
│   │   │   │   ├── face_cofw_dataset.py
│   │   │   │   └── face_wflw_dataset.py
│   │   │   ├── fashion
│   │   │   │   ├── __init__.py
│   │   │   │   ├── deepfashion_dataset.py
│   │   │   │   └── fashion_base_dataset.py
│   │   │   ├── gesture
│   │   │   │   ├── __init__.py
│   │   │   │   ├── gesture_base_dataset.py
│   │   │   │   └── nvgesture_dataset.py
│   │   │   ├── hand
│   │   │   │   ├── __init__.py
│   │   │   │   ├── freihand_dataset.py
│   │   │   │   ├── hand_base_dataset.py
│   │   │   │   ├── hand_coco_wholebody_dataset.py
│   │   │   │   ├── interhand2d_dataset.py
│   │   │   │   ├── interhand3d_dataset.py
│   │   │   │   ├── onehand10k_dataset.py
│   │   │   │   ├── panoptic_hand2d_dataset.py
│   │   │   │   └── rhd2d_dataset.py
│   │   │   ├── mesh
│   │   │   │   ├── __init__.py
│   │   │   │   ├── mesh_adv_dataset.py
│   │   │   │   ├── mesh_base_dataset.py
│   │   │   │   ├── mesh_h36m_dataset.py
│   │   │   │   ├── mesh_mix_dataset.py
│   │   │   │   └── mosh_dataset.py
│   │   │   └── top_down
│   │   │   │   ├── __init__.py
│   │   │   │   ├── topdown_aic_dataset.py
│   │   │   │   ├── topdown_base_dataset.py
│   │   │   │   ├── topdown_coco_dataset.py
│   │   │   │   ├── topdown_coco_wholebody_dataset.py
│   │   │   │   ├── topdown_crowdpose_dataset.py
│   │   │   │   ├── topdown_h36m_dataset.py
│   │   │   │   ├── topdown_halpe_dataset.py
│   │   │   │   ├── topdown_jhmdb_dataset.py
│   │   │   │   ├── topdown_mhp_dataset.py
│   │   │   │   ├── topdown_mpii_dataset.py
│   │   │   │   ├── topdown_mpii_trb_dataset.py
│   │   │   │   ├── topdown_ochuman_dataset.py
│   │   │   │   ├── topdown_posetrack18_dataset.py
│   │   │   │   └── topdown_posetrack18_video_dataset.py
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   ├── bottom_up_transform.py
│   │   │   ├── gesture_transform.py
│   │   │   ├── hand_transform.py
│   │   │   ├── loading.py
│   │   │   ├── mesh_transform.py
│   │   │   ├── pose3d_transform.py
│   │   │   ├── shared_transform.py
│   │   │   └── top_down_transform.py
│   │   ├── registry.py
│   │   └── samplers
│   │   │   ├── __init__.py
│   │   │   └── distributed_sampler.py
│   ├── deprecated.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── alexnet.py
│   │   │   ├── base_backbone.py
│   │   │   ├── cpm.py
│   │   │   ├── hourglass.py
│   │   │   ├── hourglass_ae.py
│   │   │   ├── hrformer.py
│   │   │   ├── hrnet.py
│   │   │   ├── i3d.py
│   │   │   ├── litehrnet.py
│   │   │   ├── mobilenet_v2.py
│   │   │   ├── mobilenet_v3.py
│   │   │   ├── mspn.py
│   │   │   ├── pvt.py
│   │   │   ├── regnet.py
│   │   │   ├── resnest.py
│   │   │   ├── resnet.py
│   │   │   ├── resnext.py
│   │   │   ├── rsn.py
│   │   │   ├── scnet.py
│   │   │   ├── seresnet.py
│   │   │   ├── seresnext.py
│   │   │   ├── shufflenet_v1.py
│   │   │   ├── shufflenet_v2.py
│   │   │   ├── swin.py
│   │   │   ├── tcformer.py
│   │   │   ├── tcn.py
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── channel_shuffle.py
│   │   │   │   ├── ckpt_convert.py
│   │   │   │   ├── inverted_residual.py
│   │   │   │   ├── make_divisible.py
│   │   │   │   ├── se_layer.py
│   │   │   │   └── utils.py
│   │   │   ├── v2v_net.py
│   │   │   ├── vgg.py
│   │   │   ├── vipnas_mbv3.py
│   │   │   └── vipnas_resnet.py
│   │   ├── builder.py
│   │   ├── detectors
│   │   │   ├── __init__.py
│   │   │   ├── associative_embedding.py
│   │   │   ├── base.py
│   │   │   ├── cid.py
│   │   │   ├── gesture_recognizer.py
│   │   │   ├── interhand_3d.py
│   │   │   ├── mesh.py
│   │   │   ├── multi_task.py
│   │   │   ├── multiview_pose.py
│   │   │   ├── one_stage.py
│   │   │   ├── pose_lifter.py
│   │   │   ├── posewarper.py
│   │   │   └── top_down.py
│   │   ├── heads
│   │   │   ├── __init__.py
│   │   │   ├── ae_higher_resolution_head.py
│   │   │   ├── ae_multi_stage_head.py
│   │   │   ├── ae_simple_head.py
│   │   │   ├── cid_head.py
│   │   │   ├── deconv_head.py
│   │   │   ├── deeppose_regression_head.py
│   │   │   ├── dekr_head.py
│   │   │   ├── hmr_head.py
│   │   │   ├── interhand_3d_head.py
│   │   │   ├── mtut_head.py
│   │   │   ├── temporal_regression_head.py
│   │   │   ├── topdown_heatmap_base_head.py
│   │   │   ├── topdown_heatmap_multi_stage_head.py
│   │   │   ├── topdown_heatmap_simple_head.py
│   │   │   ├── vipnas_heatmap_simple_head.py
│   │   │   └── voxelpose_head.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── classfication_loss.py
│   │   │   ├── heatmap_loss.py
│   │   │   ├── mesh_loss.py
│   │   │   ├── mse_loss.py
│   │   │   ├── multi_loss_factory.py
│   │   │   └── regression_loss.py
│   │   ├── misc
│   │   │   ├── __init__.py
│   │   │   └── discriminator.py
│   │   ├── necks
│   │   │   ├── __init__.py
│   │   │   ├── fpn.py
│   │   │   ├── gap_neck.py
│   │   │   ├── posewarper_neck.py
│   │   │   └── tcformer_mta_neck.py
│   │   ├── registry.py
│   │   └── utils
│   │   │   ├── __init__.py
│   │   │   ├── ckpt_convert.py
│   │   │   ├── geometry.py
│   │   │   ├── misc.py
│   │   │   ├── ops.py
│   │   │   ├── realnvp.py
│   │   │   ├── rescore.py
│   │   │   ├── smpl.py
│   │   │   ├── tcformer_utils.py
│   │   │   └── transformer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── collect_env.py
│   │   ├── hooks.py
│   │   ├── logger.py
│   │   ├── setup_env.py
│   │   └── timer.py
│   └── version.py
├── requirements.txt
├── requirements
│   ├── albu.txt
│   ├── build.txt
│   ├── docs.txt
│   ├── mminstall.txt
│   ├── optional.txt
│   ├── poseval.txt
│   ├── readthedocs.txt
│   ├── runtime.txt
│   └── tests.txt
├── setup.cfg
├── setup.py
└── tools
│   ├── analysis
│   ├── analyze_logs.py
│   ├── benchmark_inference.py
│   ├── benchmark_processing.py
│   ├── get_flops.py
│   └── print_config.py
│   ├── dataset
│   ├── h36m_to_coco.py
│   ├── mat2json.py
│   ├── parse_animalpose_dataset.py
│   ├── parse_cofw_dataset.py
│   ├── parse_deepposekit_dataset.py
│   ├── parse_macaquepose_dataset.py
│   ├── preprocess_h36m.py
│   └── preprocess_mpi_inf_3dhp.py
│   ├── deployment
│   ├── mmpose2torchserve.py
│   ├── mmpose_handler.py
│   ├── pytorch2onnx.py
│   └── test_torchserver.py
│   ├── dist_test.sh
│   ├── dist_train.sh
│   ├── evaluate_wlfw.sh
│   ├── misc
│   ├── keypoints2coco_without_mmdet.py
│   └── publish_model.py
│   ├── slurm_test.sh
│   ├── slurm_train.sh
│   ├── test.py
│   └── train.py
└── segmentation
├── configs
├── _base_
│   ├── datasets
│   │   ├── ade20k.py
│   │   ├── ade20k_640x640.py
│   │   ├── celebahqmask.py
│   │   ├── chase_db1.py
│   │   ├── cityscapes.py
│   │   ├── cityscapes_1024x1024.py
│   │   ├── cityscapes_768x768.py
│   │   ├── cityscapes_769x769.py
│   │   ├── cityscapes_832x832.py
│   │   ├── coco-stuff10k.py
│   │   ├── coco-stuff164k.py
│   │   ├── drive.py
│   │   ├── hrf.py
│   │   ├── imagenets.py
│   │   ├── isaid.py
│   │   ├── loveda.py
│   │   ├── occlude_face.py
│   │   ├── pascal_context.py
│   │   ├── pascal_context_59.py
│   │   ├── pascal_voc12.py
│   │   ├── pascal_voc12_aug.py
│   │   ├── potsdam.py
│   │   ├── stare.py
│   │   └── vaihingen.py
│   ├── default_runtime.py
│   ├── models
│   │   ├── ann_r50-d8.py
│   │   ├── apcnet_r50-d8.py
│   │   ├── bisenetv1_r18-d32.py
│   │   ├── bisenetv2.py
│   │   ├── ccnet_r50-d8.py
│   │   ├── cgnet.py
│   │   ├── danet_r50-d8.py
│   │   ├── deeplabv3_r50-d8.py
│   │   ├── deeplabv3_unet_s5-d16.py
│   │   ├── deeplabv3plus_r50-d8.py
│   │   ├── dmnet_r50-d8.py
│   │   ├── dnl_r50-d8.py
│   │   ├── dpt_vit-b16.py
│   │   ├── emanet_r50-d8.py
│   │   ├── encnet_r50-d8.py
│   │   ├── erfnet_fcn.py
│   │   ├── fast_scnn.py
│   │   ├── fastfcn_r50-d32_jpu_psp.py
│   │   ├── fcn_hr18.py
│   │   ├── fcn_r50-d8.py
│   │   ├── fcn_unet_s5-d16.py
│   │   ├── fpn_poolformer_s12.py
│   │   ├── fpn_r50.py
│   │   ├── gcnet_r50-d8.py
│   │   ├── icnet_r50-d8.py
│   │   ├── isanet_r50-d8.py
│   │   ├── lraspp_m-v3-d8.py
│   │   ├── nonlocal_r50-d8.py
│   │   ├── ocrnet_hr18.py
│   │   ├── ocrnet_r50-d8.py
│   │   ├── pointrend_r50.py
│   │   ├── psanet_r50-d8.py
│   │   ├── pspnet_r50-d8.py
│   │   ├── pspnet_unet_s5-d16.py
│   │   ├── segformer_mit-b0.py
│   │   ├── segmenter_vit-b16_mask.py
│   │   ├── setr_mla.py
│   │   ├── setr_naive.py
│   │   ├── setr_pup.py
│   │   ├── stdc.py
│   │   ├── twins_pcpvt-s_fpn.py
│   │   ├── twins_pcpvt-s_upernet.py
│   │   ├── upernet_beit.py
│   │   ├── upernet_convnext.py
│   │   ├── upernet_mae.py
│   │   ├── upernet_r50.py
│   │   ├── upernet_swin.py
│   │   └── upernet_vit-b16_ln_mln.py
│   └── schedules
│   │   ├── schedule_160k.py
│   │   ├── schedule_20k.py
│   │   ├── schedule_320k.py
│   │   ├── schedule_40k.py
│   │   └── schedule_80k.py
└── celebahq_mask
│   ├── bisenetv1_r18-d32_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_baseline.py
│   ├── bisenetv1_r18-d32_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_repfusion.py
│   ├── bisenetv1_r50-d32_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_baseline.py
│   └── bisenetv1_r50-d32_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_repfussion.py
├── mmseg
├── __init__.py
├── apis
│   ├── __init__.py
│   ├── inference.py
│   ├── test.py
│   └── train.py
├── core
│   ├── __init__.py
│   ├── builder.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── class_names.py
│   │   ├── eval_hooks.py
│   │   └── metrics.py
│   ├── hook
│   │   ├── __init__.py
│   │   └── wandblogger_hook.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   └── layer_decay_optimizer_constructor.py
│   ├── seg
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   └── sampler
│   │   │   ├── __init__.py
│   │   │   ├── base_pixel_sampler.py
│   │   │   └── ohem_pixel_sampler.py
│   └── utils
│   │   ├── __init__.py
│   │   ├── dist_util.py
│   │   └── misc.py
├── datasets
│   ├── __init__.py
│   ├── ade.py
│   ├── builder.py
│   ├── celebahqmask.py
│   ├── chase_db1.py
│   ├── cityscapes.py
│   ├── coco_stuff.py
│   ├── custom.py
│   ├── dark_zurich.py
│   ├── dataset_wrappers.py
│   ├── drive.py
│   ├── face.py
│   ├── hrf.py
│   ├── imagenets.py
│   ├── isaid.py
│   ├── isprs.py
│   ├── loveda.py
│   ├── night_driving.py
│   ├── pascal_context.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── compose.py
│   │   ├── formating.py
│   │   ├── formatting.py
│   │   ├── loading.py
│   │   ├── test_time_aug.py
│   │   └── transforms.py
│   ├── potsdam.py
│   ├── samplers
│   │   ├── __init__.py
│   │   └── distributed_sampler.py
│   ├── stare.py
│   └── voc.py
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── beit.py
│   │   ├── bisenetv1.py
│   │   ├── bisenetv2.py
│   │   ├── cgnet.py
│   │   ├── erfnet.py
│   │   ├── fast_scnn.py
│   │   ├── hrnet.py
│   │   ├── icnet.py
│   │   ├── mae.py
│   │   ├── mit.py
│   │   ├── mobilenet_v2.py
│   │   ├── mobilenet_v3.py
│   │   ├── resnest.py
│   │   ├── resnet.py
│   │   ├── resnext.py
│   │   ├── stdc.py
│   │   ├── swin.py
│   │   ├── timm_backbone.py
│   │   ├── twins.py
│   │   ├── unet.py
│   │   └── vit.py
│   ├── builder.py
│   ├── decode_heads
│   │   ├── __init__.py
│   │   ├── ann_head.py
│   │   ├── apc_head.py
│   │   ├── aspp_head.py
│   │   ├── cascade_decode_head.py
│   │   ├── cc_head.py
│   │   ├── da_head.py
│   │   ├── decode_head.py
│   │   ├── dm_head.py
│   │   ├── dnl_head.py
│   │   ├── dpt_head.py
│   │   ├── ema_head.py
│   │   ├── enc_head.py
│   │   ├── fcn_head.py
│   │   ├── fpn_head.py
│   │   ├── gc_head.py
│   │   ├── isa_head.py
│   │   ├── knet_head.py
│   │   ├── lraspp_head.py
│   │   ├── nl_head.py
│   │   ├── ocr_head.py
│   │   ├── point_head.py
│   │   ├── psa_head.py
│   │   ├── psp_head.py
│   │   ├── segformer_head.py
│   │   ├── segmenter_mask_head.py
│   │   ├── sep_aspp_head.py
│   │   ├── sep_fcn_head.py
│   │   ├── setr_mla_head.py
│   │   ├── setr_up_head.py
│   │   ├── stdc_head.py
│   │   └── uper_head.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   ├── cross_entropy_loss.py
│   │   ├── dice_loss.py
│   │   ├── focal_loss.py
│   │   ├── lovasz_loss.py
│   │   ├── tversky_loss.py
│   │   └── utils.py
│   ├── necks
│   │   ├── __init__.py
│   │   ├── featurepyramid.py
│   │   ├── fpn.py
│   │   ├── ic_neck.py
│   │   ├── jpu.py
│   │   ├── mla_neck.py
│   │   └── multilevel_neck.py
│   ├── segmentors
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── cascade_encoder_decoder.py
│   │   └── encoder_decoder.py
│   └── utils
│   │   ├── __init__.py
│   │   ├── embed.py
│   │   ├── inverted_residual.py
│   │   ├── make_divisible.py
│   │   ├── res_layer.py
│   │   ├── se_layer.py
│   │   ├── self_attention_block.py
│   │   ├── shape_convert.py
│   │   └── up_conv_block.py
├── ops
│   ├── __init__.py
│   ├── encoding.py
│   └── wrappers.py
├── utils
│   ├── __init__.py
│   ├── collect_env.py
│   ├── logger.py
│   ├── misc.py
│   ├── set_env.py
│   └── util_distribution.py
└── version.py
├── requirements.txt
├── requirements
├── docs.txt
├── mminstall.txt
├── optional.txt
├── readthedocs.txt
├── runtime.txt
└── tests.txt
├── setup.cfg
├── setup.py
└── tools
├── analyze_logs.py
├── benchmark.py
├── browse_dataset.py
├── confusion_matrix.py
├── convert_datasets
├── chase_db1.py
├── cityscapes.py
├── coco_stuff10k.py
├── coco_stuff164k.py
├── drive.py
├── hrf.py
├── isaid.py
├── loveda.py
├── pascal_context.py
├── potsdam.py
├── stare.py
├── vaihingen.py
└── voc_aug.py
├── dataset
└── process_celebahqmask.py
├── deploy_test.py
├── dist_test.sh
├── dist_train.sh
├── get_flops.py
├── model_converters
├── beit2mmseg.py
├── mit2mmseg.py
├── stdc2mmseg.py
├── swin2mmseg.py
├── twins2mmseg.py
├── vit2mmseg.py
└── vitjax2mmseg.py
├── model_ensemble.py
├── onnx2tensorrt.py
├── print_config.py
├── publish_model.py
├── pytorch2onnx.py
├── pytorch2torchscript.py
├── slurm_test.sh
├── slurm_train.sh
├── test.py
├── torchserve
├── mmseg2torchserve.py
├── mmseg_handler.py
└── test_torchserve.py
└── train.py

/assets/concept.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/assets/concept.jpg
--------------------------------------------------------------------------------
/classification_distill/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/.DS_Store
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/datasets/cifar10_bs128.py:
--------------------------------------------------------------------------------
# dataset settings
dataset_type = 'CIFAR10'
img_norm_cfg = dict(
    mean=[128, 128, 128],
    std=[128, 128, 128],
    to_rgb=False)

train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
    # dict(type='ToTensor', keys=['gt_label']),
    # dict(type='Collect', keys=['img', 'gt_label'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type, data_prefix='data/cifar10',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/cifar10',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_prefix='data/cifar10',
        pipeline=test_pipeline,
        test_mode=True))
evaluation = dict(interval=5, metric='accuracy')
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/datasets/cifar10_bs128_2task.py:
--------------------------------------------------------------------------------
# dataset settings
dataset_type = 'CIFAR10_2Task'
img_norm_cfg = dict(
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    to_rgb=False)
train_pipeline = [
    dict(type='RandomCrop', size=32, padding=4),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type, data_prefix='data/cifar10',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/cifar10',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_prefix='data/cifar10',
        pipeline=test_pipeline,
        test_mode=True))
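
A minimal sketch of how these `_base_` dataset configs are consumed, assuming an mmcv (<2.0) installation and that this repo's `mmcls` package is importable, with paths relative to `classification_distill/`:

from mmcv import Config
from mmcls.datasets import build_dataset

cfg = Config.fromfile('configs/_base_/datasets/cifar10_bs128.py')
train_set = build_dataset(cfg.data.train)  # CIFAR10 with train_pipeline applied
val_set = build_dataset(cfg.data.val)      # test_mode=True, test_pipeline applied
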
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/datasets/dsprite.py:
--------------------------------------------------------------------------------
# dataset settings
dataset_type = 'dSprites'
multi_task = True
img_norm_cfg = dict(
    mean=[0.5],
    std=[0.5],
    to_rgb=False)
train_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type, data_prefix='data/dsprite',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/dsprite',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_prefix='data/dsprite',
        pipeline=test_pipeline,
        test_mode=True))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/datasets/shape3d.py:
--------------------------------------------------------------------------------
# dataset settings
dataset_type = 'Shape3D'
multi_task = True
img_norm_cfg = dict(
    mean=[127.0, 127.0, 127.0],
    std=[127.0, 127.0, 127.0],
    to_rgb=False)
train_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=256,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type, data_prefix='data/shape3d',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/shape3d',
        pipeline=test_pipeline,
        test_mode=True),
    test=dict(
        type=dataset_type,
        data_prefix='data/shape3d',
        pipeline=test_pipeline,
        test_mode=True))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/datasets/tinyimagenet_bs128.py:
--------------------------------------------------------------------------------
# dataset settings
dataset_type = 'TinyImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=64),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(64, -1)),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_prefix='data/tiny-imagenet-200/train',
        pipeline=train_pipeline),
    val=dict(
        type='ImageNet',
        data_prefix='data/tiny-imagenet-200/val/images',
        ann_file=None,
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type='ImageNet',
        data_prefix='data/tiny-imagenet-200/val/images',
        ann_file=None,
        pipeline=test_pipeline))
evaluation = dict(interval=5, metric='accuracy')
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
# checkpoint saving
checkpoint_config = dict(interval=20)
# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/mobilenet_v2_1x.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/resnet18.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/resnet18_cifar.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/resnet18_dsprite.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        in_channels=1,
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskLinearClsHead',
        num_classes=[3, 6, 40, 32, 32],
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/resnet18_shape3d.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskLinearClsHead',
        num_classes=[10, 10, 10, 8, 4, 15],
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/resnet50.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/shufflenet_v2_1x.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ShuffleNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/wide-resnet28-10.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='WideResNet_CIFAR',
        depth=28,
        stem_channels=16,
        base_channels=16 * 10,
        num_stages=3,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        out_indices=(2, ),
        out_channel=640,
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=640,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/models/wide-resnet28-2.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='WideResNet_CIFAR',
        depth=28,
        stem_channels=16,
        base_channels=16 * 2,
        num_stages=3,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        out_indices=(2, ),
        out_channel=128,
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=128,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/cifar10_bs128.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[100, 150])
runner = dict(type='EpochBasedRunner', max_epochs=200)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/dsprite_bs128.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[10, 15])
runner = dict(type='EpochBasedRunner', max_epochs=20)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs1024_adamw_swin.py:
--------------------------------------------------------------------------------
paramwise_cfg = dict(
    norm_decay_mult=0.0,
    bias_decay_mult=0.0,
    custom_keys={
        '.absolute_pos_embed': dict(decay_mult=0.0),
        '.relative_position_bias_table': dict(decay_mult=0.0)
    })

# for batch in each gpu is 128, 8 gpu
# lr = 5e-4 * 128 * 8 / 512 = 0.001
optimizer = dict(
    type='AdamW',
    lr=5e-4 * 128 * 8 / 512,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=paramwise_cfg)
optimizer_config = dict(grad_clip=dict(max_norm=5.0))

# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    by_epoch=False,
    min_lr_ratio=1e-2,
    warmup='linear',
    warmup_ratio=1e-3,
    warmup_iters=20 * 1252,
    warmup_by_epoch=False)

runner = dict(type='EpochBasedRunner', max_epochs=300)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs1024_coslr.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(
    type='SGD', lr=0.5, momentum=0.9, weight_decay=0.0001, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=2500,
    warmup_ratio=0.25)
runner = dict(type='EpochBasedRunner', max_epochs=150)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs1024_linearlr_bn_nowd.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(
    type='SGD',
    lr=0.5,
    momentum=0.9,
    weight_decay=0.00004,
    paramwise_cfg=dict(norm_decay_mult=0))
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='poly',
    min_lr=0,
    by_epoch=False,
    warmup='constant',
    warmup_iters=5000,
)
runner = dict(type='EpochBasedRunner', max_epochs=300)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs2048.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(
    type='SGD', lr=0.8, momentum=0.9, weight_decay=0.0001, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=2500,
    warmup_ratio=0.25,
    step=[30, 60, 90])
runner = dict(type='EpochBasedRunner', max_epochs=100)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs2048_AdamW.py:
--------------------------------------------------------------------------------
# optimizer
# In ClassyVision, the lr is set to 0.003 for bs4096.
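# (linear scaling rule: lr = base_lr * actual_batch_size / base_batch_size)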
# In this implementation(bs2048), lr = 0.003 / 4096 * (32bs * 64gpus) = 0.0015
optimizer = dict(type='AdamW', lr=0.0015, weight_decay=0.3)
optimizer_config = dict(grad_clip=dict(max_norm=1.0))

# specific to vit pretrain
paramwise_cfg = dict(
    custom_keys={
        '.backbone.cls_token': dict(decay_mult=0.0),
        '.backbone.pos_embed': dict(decay_mult=0.0)
    })
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=10000,
    warmup_ratio=1e-4)
runner = dict(type='EpochBasedRunner', max_epochs=300)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs2048_coslr.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(
    type='SGD', lr=0.8, momentum=0.9, weight_decay=0.0001, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=2500,
    warmup_ratio=0.25)
runner = dict(type='EpochBasedRunner', max_epochs=100)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs256.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[30, 60, 90])
runner = dict(type='EpochBasedRunner', max_epochs=100)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs256_140e.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[40, 80, 120])
runner = dict(type='EpochBasedRunner', max_epochs=140)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs256_coslr.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0)
runner = dict(type='EpochBasedRunner', max_epochs=150)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs256_coslr_300e.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0)
runner = dict(type='EpochBasedRunner', max_epochs=300)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs256_coslr_mobilenetv2.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.00004)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0)
runner = dict(type='EpochBasedRunner', max_epochs=200)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs256_epochstep.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.045, momentum=0.9, weight_decay=0.00004)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', gamma=0.98, step=1)
runner = dict(type='EpochBasedRunner', max_epochs=300)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/imagenet_bs4096_AdamW.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='AdamW', lr=0.003, weight_decay=0.3)
optimizer_config = dict(grad_clip=dict(max_norm=1.0))

# specific to vit pretrain
paramwise_cfg = dict(
    custom_keys={
        '.backbone.cls_token': dict(decay_mult=0.0),
        '.backbone.pos_embed': dict(decay_mult=0.0)
    })
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=10000,
    warmup_ratio=1e-4)
runner = dict(type='EpochBasedRunner', max_epochs=300)
--------------------------------------------------------------------------------
/classification_distill/configs/_base_/schedules/shape3d_bs128.py:
--------------------------------------------------------------------------------
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3])
runner = dict(type='EpochBasedRunner', max_epochs=5)
--------------------------------------------------------------------------------
/classification_distill/configs/baseline/mbnv2_b128x1_tinyimagenet_finetune.py:
--------------------------------------------------------------------------------
_base_ = [
    '../_base_/datasets/tinyimagenet_bs128.py'
]

# checkpoint saving
checkpoint_config = dict(interval=50)
# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
# yapf:enable

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

# optimizer
optimizer = dict(type='SGD',
                 lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[100, 150])
runner = dict(type='EpochBasedRunner', max_epochs=200)


# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV2_CIFAR',
                  out_indices=(7, ),
                  widen_factor=1.0,
                  init_cfg=dict(type='Pretrained',
                                prefix='student.backbone.',
                                checkpoint='')
                  ),
    neck=dict(
        type='GlobalAveragePooling'
    ),
    head=dict(
        type='LinearClsHead',
        num_classes=200,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    )
)
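
The empty `checkpoint=''` in the `init_cfg` above is left for the user to point at a distillation checkpoint; `type='Pretrained'` with `prefix='student.backbone.'` loads only the student-backbone sub-weights from it. A sketch with a hypothetical path (the actual file depends on where a pretraining run saved its checkpoints):

model = dict(
    backbone=dict(
        init_cfg=dict(
            type='Pretrained',
            prefix='student.backbone.',
            # hypothetical path to a RepFusion pretraining checkpoint
            checkpoint='work_dirs/repfusion_pretrain/latest.pth')))
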
--------------------------------------------------------------------------------
/classification_distill/configs/baseline/resnet18_b128x1_tinyimagenet_finetune.py:
--------------------------------------------------------------------------------
_base_ = [
    '../_base_/datasets/tinyimagenet_bs128.py'
]

# checkpoint saving
checkpoint_config = dict(interval=50)
# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
# yapf:enable

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

# optimizer
optimizer = dict(type='SGD',
                 lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[100, 150])
runner = dict(type='EpochBasedRunner', max_epochs=200)


# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch',
        init_cfg=dict(type='Pretrained',
                      prefix='student.backbone.',
                      checkpoint='')
    ),
    neck=dict(
        type='GlobalAveragePooling'
    ),
    head=dict(
        type='LinearClsHead',
        num_classes=200,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    )
)
--------------------------------------------------------------------------------
/classification_distill/configs/celebahq-repfusion/ddpm_hrnet18_b32x2_celebahq_pretrain_maxtime1_clean.py:
--------------------------------------------------------------------------------
_base_ = [
    '../_base_/datasets/celebahq_bs32.py'
]

data = dict(samples_per_gpu=32)

# checkpoint saving
checkpoint_config = dict(interval=50, max_keep_ckpts=3)
# yapf:disable
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
# yapf:enable

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

# optimizer
optimizer = dict(type='SGD',
                 lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[100, 150])
runner = dict(type='EpochBasedRunner', max_epochs=200)

fp16 = dict(loss_scale=512.)
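
# The RepFusion pretraining model below pairs one teacher feature with one
# student feature: `teacher_layers` names a module (and its channel width)
# inside the frozen DDPM UNet, `student_layers` the matching student
# activation, and `distill_fn` the distillation loss (L1 with weight 10.0).
# `max_time_step=1` means the teacher only ever sees clean (t=0) images.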

# model settings
model = dict(
    type='KDDDPM_Pretrain_CleanDense_HRNet_ImageClassifier',
    teacher_layers=[["mid_block.resnets.1.conv2", 512]],
    student_layers=[['backbone.stage4.2.relu', 2048]],
    distill_fn=[['l1', 10.0]],
    model_id='google/ddpm-ema-celebahq-256',
    max_time_step=1,
    backbone=dict(type='HRNet', arch='w18'),
    neck=[
        dict(type='HRFuseScales', in_channels=(18, 36, 72, 144)),
        # dict(type='GlobalAveragePooling'),
    ],
    head=None
)

evaluation = dict(interval=200, metric='accuracy')
--------------------------------------------------------------------------------
/classification_distill/mmcls/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/mmcls/.DS_Store
--------------------------------------------------------------------------------
/classification_distill/mmcls/apis/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test
from .multitask_test import multitask_multi_gpu_test, multitask_single_gpu_test
from .train import set_random_seed, train_model

__all__ = [
    'set_random_seed', 'train_model', 'init_model', 'inference_model',
    'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot',
    'multitask_multi_gpu_test', 'multitask_single_gpu_test'
]
--------------------------------------------------------------------------------
/classification_distill/mmcls/core/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import *  # noqa: F401, F403
from .fp16 import *  # noqa: F401, F403
from .utils import *  # noqa: F401, F403
--------------------------------------------------------------------------------
/classification_distill/mmcls/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import DistEvalHook, EvalHook
from .multitask_eval_hooks import MultiTaskEvalHook, DistMultiTaskEvalHook
from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
                           precision_recall_f1, recall, support)
from .mean_ap import average_precision, mAP
from .multilabel_eval_metrics import average_performance

__all__ = [
    'DistEvalHook', 'EvalHook', 'precision', 'recall', 'f1_score', 'support',
    'average_precision', 'mAP', 'average_performance',
    'calculate_confusion_matrix', 'precision_recall_f1', 'MultiTaskEvalHook',
    'DistMultiTaskEvalHook'
]
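
Referring back to the `KDDDPM_Pretrain_CleanDense_HRNet_ImageClassifier` config above: a sketch of how a teacher feature such as `mid_block.resnets.1.conv2` can be captured from the DDPM UNet with a forward hook. This assumes the `diffusers` package and that `google/ddpm-ema-celebahq-256` loads as a `DDPMPipeline`; it is illustrative only, not the repo's actual implementation.

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained('google/ddpm-ema-celebahq-256')
unet = pipe.unet.eval()

feats = {}
hook = unet.mid_block.resnets[1].conv2.register_forward_hook(
    lambda module, inputs, output: feats.update(teacher=output))

x = torch.randn(1, 3, 256, 256)        # stand-in for a normalized face image
t = torch.zeros(1, dtype=torch.long)   # max_time_step=1 -> t = 0 (clean image)
with torch.no_grad():
    unet(x, t)
hook.remove()
print(feats['teacher'].shape)  # channel dim should match the 512 in teacher_layers
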
--------------------------------------------------------------------------------
/classification_distill/mmcls/core/export/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .test import ONNXRuntimeClassifier, TensorRTClassifier

__all__ = ['ONNXRuntimeClassifier', 'TensorRTClassifier']
--------------------------------------------------------------------------------
/classification_distill/mmcls/core/fp16/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .decorators import auto_fp16, force_fp32
from .hooks import Fp16OptimizerHook, wrap_fp16_model

__all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model']
--------------------------------------------------------------------------------
/classification_distill/mmcls/core/fp16/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from collections import abc

import numpy as np
import torch


def cast_tensor_type(inputs, src_type, dst_type):
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    elif isinstance(inputs, str):
        return inputs
    elif isinstance(inputs, np.ndarray):
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        return inputs
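
`cast_tensor_type` above recurses through mappings and iterables, casting only torch tensors and passing strings and numpy arrays through untouched. A small usage sketch:

import torch
from mmcls.core.fp16.utils import cast_tensor_type

batch = {'img': torch.randn(2, 3, 32, 32), 'filename': 'demo.png'}
half = cast_tensor_type(batch, torch.float, torch.half)
assert half['img'].dtype == torch.half   # tensor cast to fp16
assert half['filename'] == 'demo.png'    # non-tensor passed through
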
--------------------------------------------------------------------------------
/classification_distill/mmcls/core/utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import DistOptimizerHook, allreduce_grads
from .misc import multi_apply
from .kd_hook import KDOptimizerBuilder
from .visualize import TensorboardVisLoggerHook
from .timestep_decay import TimeDecayHook, EntropyDecayHook

__all__ = ['allreduce_grads', 'DistOptimizerHook',
           'multi_apply', 'KDOptimizerBuilder',
           'TensorboardVisLoggerHook', 'TimeDecayHook', 'EntropyDecayHook']
--------------------------------------------------------------------------------
/classification_distill/mmcls/core/utils/kd_hook.py:
--------------------------------------------------------------------------------
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import (HOOKS, OPTIMIZER_BUILDERS, OPTIMIZERS,
                         DefaultOptimizerConstructor, Hook, OptimizerHook)
from mmcv.utils import build_from_cfg


@OPTIMIZER_BUILDERS.register_module()
class KDOptimizerBuilder(DefaultOptimizerConstructor):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        super(KDOptimizerBuilder, self).__init__(optimizer_cfg,
                                                 paramwise_cfg)

    def __call__(self, model):
        if hasattr(model, 'module'):
            model = model.module

        optimizer_cfg = self.optimizer_cfg.copy()

        # if no paramwise option is specified, just use the global setting
        if not self.paramwise_cfg:
            optimizer_cfg['params'] = model.student.parameters()
            student_optimizer = build_from_cfg(optimizer_cfg,
                                               OPTIMIZERS)
            return student_optimizer
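
`KDOptimizerBuilder` hands only `model.student.parameters()` to the optimizer, so a frozen teacher wrapped in the same model receives no updates. A sketch of selecting it from a config, assuming mmcv's standard `build_optimizer` flow:

optimizer = dict(
    constructor='KDOptimizerBuilder',  # replaces DefaultOptimizerConstructor
    type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
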
2 | from .base_dataset import BaseDataset 3 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 4 | from .cifar import CIFAR10, CIFAR100, CIFAR10_MultiTask, CIFAR10_2Task, CIFAR10_Select 5 | from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, 6 | RepeatDataset) 7 | from .imagenet import ImageNet, ImageNet_MultiTask, TinyImageNet 8 | from .samplers import DistributedSampler 9 | from .disentangle_data import dSprites, Shape3D 10 | from .images import ImageList 11 | from .custom import CustomDataset 12 | from .celeba import CelebA 13 | 14 | -------------------------------------------------------------------------------- /classification_distill/mmcls/datasets/disentangle_data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dsprites import dSprites 2 | from .shape3d import Shape3D -------------------------------------------------------------------------------- /classification_distill/mmcls/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .auto_augment import (AutoAugment, AutoContrast, Brightness, 3 | ColorTransform, Contrast, Cutout, Equalize, Invert, 4 | Posterize, RandAugment, Rotate, Sharpness, Shear, 5 | Solarize, SolarizeAdd, Translate) 6 | from .compose import Compose 7 | from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, 8 | Transpose, to_tensor) 9 | from .loading import LoadImageFromFile 10 | from .transforms import (CenterCrop, ColorJitter, Lighting, RandomCrop, 11 | RandomErasing, RandomFlip, RandomGrayscale, 12 | RandomResizedCrop, Resize) 13 | from .multiview import MultiBranch 14 | 15 | __all__ = [ 16 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', 17 | 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop', 18 | 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', 19 | 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert', 20 | 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', 21 | 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', 22 | 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'MultiBranch' 23 | ] 24 | -------------------------------------------------------------------------------- /classification_distill/mmcls/datasets/pipelines/multiview.py: -------------------------------------------------------------------------------- 1 | from ..builder import PIPELINES 2 | from .compose import Compose 3 | import copy 4 | 5 | 6 | @PIPELINES.register_module() 7 | class MultiBranch(object): 8 | def __init__(self, **transform_group): 9 | self.transform_group = {k: Compose(v) 10 | for k, v in transform_group.items()} 11 | 12 | def __call__(self, results): 13 | multi_results = dict() 14 | for k, v in self.transform_group.items(): 15 | res = v(copy.deepcopy(results)) 16 | if res is None: 17 | return None 18 | # res["img_metas"]["tag"] = k 19 | multi_results[k] = res 20 | return multi_results 21 | -------------------------------------------------------------------------------- /classification_distill/mmcls/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/mmcls/models/.DS_Store -------------------------------------------------------------------------------- /classification_distill/mmcls/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS, 4 | build_backbone, build_classifier, build_head, build_loss, 5 | build_neck) 6 | from .classifiers import * # noqa: F401,F403 7 | from .heads import * # noqa: F401,F403 8 | from .losses import * # noqa: F401,F403 9 | from .necks import * # noqa: F401,F403 10 | 11 | __all__ = [ 12 | 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone', 13 | 'build_head', 'build_neck', 'build_loss', 'build_classifier' 14 | ] 15 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .mobilenet_v2 import MobileNetV2 3 | from .mobilenet_v2_cifar import MobileNetV2_CIFAR 4 | from .resnet import ResNet, ResNetV1d 5 | from .resnet_cifar import ResNet_CIFAR 6 | from .shufflenet_v2 import ShuffleNetV2 7 | from .tsn import TSN_backbone 8 | from .wideresnet import WideResNet_CIFAR 9 | from .disentangle import SimpleConv64, SimpleGaussianConv64 10 | from .hrnet import HRNet 11 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/backbones/base_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from mmcv.runner import BaseModule 5 | 6 | 7 | class BaseBackbone(BaseModule, metaclass=ABCMeta): 8 | """Base backbone. 9 | 10 | This class defines the basic functions of a backbone. Any backbone that 11 | inherits this class should at least define its own `forward` function. 12 | """ 13 | 14 | def __init__(self, init_cfg=None): 15 | super(BaseBackbone, self).__init__(init_cfg) 16 | 17 | @abstractmethod 18 | def forward(self, x): 19 | """Forward computation. 20 | 21 | Args: 22 | x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of 23 | Torch.tensor, containing input data for forward computation. 24 | """ 25 | pass 26 | 27 | def train(self, mode=True): 28 | """Set module status before forward computation. 
29 | 30 | Args: 31 | mode (bool): Whether it is train_mode or test_mode 32 | """ 33 | super(BaseBackbone, self).train(mode) 34 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/backbones/tsn.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import torch.nn as nn 3 | import torch 4 | 5 | from ..builder import BACKBONES, build_backbone 6 | from .base_backbone import BaseBackbone 7 | import torch.nn.functional as F 8 | 9 | @BACKBONES.register_module() 10 | class TSN_backbone(BaseBackbone): 11 | def __init__(self, backbone, in_channels, out_channels): 12 | super().__init__() 13 | self.in_channels = in_channels 14 | self.out_channels = out_channels 15 | 16 | self.encoder = build_backbone(backbone) 17 | self.fc = nn.Linear(self.in_channels, self.out_channels, bias=False) 18 | 19 | 20 | def forward(self, x): 21 | x = self.encoder(x) 22 | if isinstance(x, tuple): 23 | x = x[-1] 24 | 25 | x = F.adaptive_avg_pool2d(x, (1,1)) 26 | x = x.view(x.size(0), -1) 27 | x = self.fc(x) 28 | 29 | mu = torch.mean(x, 0) 30 | log_var = torch.log(torch.var(x, 0)) 31 | return (mu, log_var), x 32 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import MODELS as MMCV_MODELS 3 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 4 | from mmcv.utils import Registry 5 | 6 | MODELS = Registry('models', parent=MMCV_MODELS) 7 | 8 | BACKBONES = MODELS 9 | NECKS = MODELS 10 | HEADS = MODELS 11 | LOSSES = MODELS 12 | CLASSIFIERS = MODELS 13 | 14 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 15 | 16 | 17 | def build_backbone(cfg): 18 | """Build backbone.""" 19 | return BACKBONES.build(cfg) 20 | 21 | 22 | def build_neck(cfg): 23 | """Build neck.""" 24 | return NECKS.build(cfg) 25 | 26 | 27 | def build_head(cfg): 28 | """Build head.""" 29 | return HEADS.build(cfg) 30 | 31 | 32 | def build_loss(cfg): 33 | """Build loss.""" 34 | return LOSSES.build(cfg) 35 | 36 | 37 | def build_classifier(cfg): 38 | return CLASSIFIERS.build(cfg) 39 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base import BaseClassifier 3 | from .image import ImageClassifier 4 | from .kd import * 5 | from .repfusion import * 6 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/mmcls/models/guided_diffusion/__init__.py -------------------------------------------------------------------------------- /classification_distill/mmcls/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .cls_head import ClsHead 3 | from .linear_head import LinearBCEClsHead, LinearClsHead 4 | from .multi_label_head import MultiLabelClsHead 5 | from .multitask_linear_head import MultiTaskLinearClsHead 6 | 7 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/heads/base_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from mmcv.runner import BaseModule 5 | 6 | 7 | class BaseHead(BaseModule, metaclass=ABCMeta): 8 | """Base head.""" 9 | 10 | def __init__(self, init_cfg=None): 11 | super(BaseHead, self).__init__(init_cfg) 12 | 13 | @abstractmethod 14 | def forward_train(self, x, gt_label, **kwargs): 15 | pass 16 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .accuracy import Accuracy, accuracy 3 | from .asymmetric_loss import AsymmetricLoss, asymmetric_loss 4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 5 | cross_entropy) 6 | from .focal_loss import FocalLoss, sigmoid_focal_loss 7 | from .label_smooth_loss import LabelSmoothLoss 8 | from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, 9 | weighted_loss) 10 | from .kd_loss import Logits, SoftTarget 11 | from .norm_l2 import MSE_Norm_Loss 12 | from .at import AT 13 | from .rkd import RKD 14 | from .crd_loss import CRDLoss 15 | __all__ = [ 16 | 'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss', 17 | 'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 18 | 'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss', 19 | 'sigmoid_focal_loss', 'convert_to_one_hot', 'Logits', 'SoftTarget', 'MSE_Norm_Loss' 20 | ] 21 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/losses/at.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ''' 6 | AT with sum of absolute values with power p 7 | ''' 8 | class AT(nn.Module): 9 | ''' 10 | Paying More Attention to Attention: Improving the Performance of Convolutional 11 | Neural Netkworks wia Attention Transfer 12 | https://arxiv.org/pdf/1612.03928.pdf 13 | ''' 14 | def __init__(self, p): 15 | super(AT, self).__init__() 16 | self.p = p 17 | 18 | def forward(self, fm_s, fm_t): 19 | if fm_s.shape[2] != fm_t.shape[2] or fm_s.shape[3] != fm_t.shape[3]: 20 | fm_s = F.interpolate(fm_s, (fm_t.size(2), fm_t.size(3)), mode='bilinear') 21 | loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t)) 22 | 23 | return loss 24 | 25 | def attention_map(self, fm, eps=1e-6): 26 | am = torch.pow(torch.abs(fm), self.p) 27 | am = torch.sum(am, dim=1, keepdim=True) 28 | norm = torch.norm(am, dim=(2,3), keepdim=True) 29 | am = torch.div(am, norm+eps) 30 | 31 | return am -------------------------------------------------------------------------------- /classification_distill/mmcls/models/losses/crd_loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterion import CRDLoss 
-------------------------------------------------------------------------------- /classification_distill/mmcls/models/losses/norm_l2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MSE_Norm_Loss(nn.Module): 7 | def __init__(self) -> None: 8 | super().__init__() 9 | 10 | def forward(self, x, y): 11 | x = F.normalize(x, p=2, dim=1) 12 | y = F.normalize(y, p=2, dim=1) 13 | return F.mse_loss(x, y) -------------------------------------------------------------------------------- /classification_distill/mmcls/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .gap import GlobalAveragePooling 3 | from .hr_fuse import HRFuseScales 4 | 5 | __all__ = ['GlobalAveragePooling', 'HRFuseScales'] 6 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .attention import ShiftWindowMSA 3 | from .augment.augments import Augments 4 | from .channel_shuffle import channel_shuffle 5 | from .embed import HybridEmbed, PatchEmbed, PatchMerging 6 | from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple 7 | from .inverted_residual import InvertedResidual 8 | from .make_divisible import make_divisible 9 | from .se_layer import SELayer 10 | 11 | __all__ = [ 12 | 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', 13 | 'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed', 14 | 'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing' 15 | ] 16 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/utils/augment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .augments import Augments 3 | from .cutmix import BatchCutMixLayer 4 | from .identity import Identity 5 | from .mixup import BatchMixupLayer 6 | 7 | __all__ = ['Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer'] 8 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/utils/augment/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | AUGMENT = Registry('augment') 5 | 6 | 7 | def build_augment(cfg, default_args=None): 8 | return build_from_cfg(cfg, AUGMENT, default_args) 9 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/utils/augment/identity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn.functional as F 3 | 4 | from .builder import AUGMENT 5 | 6 | 7 | @AUGMENT.register_module(name='Identity') 8 | class Identity(object): 9 | """Change gt_label to one_hot encoding and keep img as the same. 10 | 11 | Args: 12 | num_classes (int): The number of classes. 13 | prob (float): MixUp probability. It should be in range [0, 1]. 
14 | Default to 1.0 15 | """ 16 | 17 | def __init__(self, num_classes, prob=1.0): 18 | super(Identity, self).__init__() 19 | 20 | assert isinstance(num_classes, int) 21 | assert isinstance(prob, float) and 0.0 <= prob <= 1.0 22 | 23 | self.num_classes = num_classes 24 | self.prob = prob 25 | 26 | def one_hot(self, gt_label): 27 | return F.one_hot(gt_label, num_classes=self.num_classes) 28 | 29 | def __call__(self, img, gt_label): 30 | return img, self.one_hot(gt_label) 31 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/utils/channel_shuffle.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | 5 | def channel_shuffle(x, groups): 6 | """Channel Shuffle operation. 7 | 8 | This function enables cross-group information flow for multiple groups 9 | convolution layers. 10 | 11 | Args: 12 | x (Tensor): The input tensor. 13 | groups (int): The number of groups to divide the input tensor 14 | in the channel dimension. 15 | 16 | Returns: 17 | Tensor: The output tensor after channel shuffle operation. 18 | """ 19 | 20 | batch_size, num_channels, height, width = x.size() 21 | assert (num_channels % groups == 0), ('num_channels should be ' 22 | 'divisible by groups') 23 | channels_per_group = num_channels // groups 24 | 25 | x = x.view(batch_size, groups, channels_per_group, height, width) 26 | x = torch.transpose(x, 1, 2).contiguous() 27 | x = x.view(batch_size, -1, height, width) 28 | 29 | return x 30 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/utils/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import collections.abc 3 | import warnings 4 | from distutils.version import LooseVersion 5 | from itertools import repeat 6 | 7 | import torch 8 | 9 | 10 | def is_tracing() -> bool: 11 | if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'): 12 | on_trace = torch.jit.is_tracing() 13 | # In PyTorch 1.6, torch.jit.is_tracing has a bug. 14 | # Refers to https://github.com/pytorch/pytorch/issues/42448 15 | if isinstance(on_trace, bool): 16 | return on_trace 17 | else: 18 | return torch._C._is_tracing() 19 | else: 20 | warnings.warn( 21 | 'torch.jit.is_tracing is only supported after v1.6.0. ' 22 | 'Therefore is_tracing returns False automatically. Please ' 23 | 'set on_trace manually if you are using trace.', UserWarning) 24 | return False 25 | 26 | 27 | # From PyTorch internals 28 | def _ntuple(n): 29 | 30 | def parse(x): 31 | if isinstance(x, collections.abc.Iterable): 32 | return x 33 | return tuple(repeat(x, n)) 34 | 35 | return parse 36 | 37 | 38 | to_1tuple = _ntuple(1) 39 | to_2tuple = _ntuple(2) 40 | to_3tuple = _ntuple(3) 41 | to_4tuple = _ntuple(4) 42 | to_ntuple = _ntuple 43 | -------------------------------------------------------------------------------- /classification_distill/mmcls/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number down to the nearest value that can 6 | be divisible by the divisor. 7 | 8 | Args: 9 | value (int): The original channel number. 
10 | divisor (int): The divisor to fully divide the channel number. 11 | min_value (int, optional): The minimum value of the output channel. 12 | Default: None, means that the minimum value equal to the divisor. 13 | min_ratio (float): The minimum ratio of the rounded channel 14 | number to the original channel number. Default: 0.9. 15 | Returns: 16 | int: The modified output channel number 17 | """ 18 | 19 | if min_value is None: 20 | min_value = divisor 21 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 22 | # Make sure that round down does not go down by more than (1-min_ratio). 23 | if new_value < min_ratio * value: 24 | new_value += divisor 25 | return new_value 26 | -------------------------------------------------------------------------------- /classification_distill/mmcls/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .collect_env import collect_env 3 | from .logger import get_root_logger 4 | 5 | __all__ = ['collect_env', 'get_root_logger'] 6 | -------------------------------------------------------------------------------- /classification_distill/mmcls/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmcls 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMClassification'] = mmcls.__version__ + '+' + get_git_hash()[:7] 12 | return env_info 13 | 14 | 15 | if __name__ == '__main__': 16 | for name, val in collect_env().items(): 17 | print(f'{name}: {val}') 18 | -------------------------------------------------------------------------------- /classification_distill/mmcls/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | return get_logger('mmcls', log_file, log_level) 9 | -------------------------------------------------------------------------------- /classification_distill/mmcls/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved 2 | 3 | __version__ = '0.15.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | """Parse a version string into a tuple. 8 | 9 | Args: 10 | version_str (str): The version string. 11 | Returns: 12 | tuple[int | str]: The version info, e.g., "1.3.0" is parsed into 13 | (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 
14 | """ 15 | version_info = [] 16 | for x in version_str.split('.'): 17 | if x.isdigit(): 18 | version_info.append(int(x)) 19 | elif x.find('rc') != -1: 20 | patch_version = x.split('rc') 21 | version_info.append(int(patch_version[0])) 22 | version_info.append(f'rc{patch_version[1]}') 23 | return tuple(version_info) 24 | 25 | 26 | version_info = parse_version_info(__version__) 27 | 28 | __all__ = ['__version__', 'version_info', 'parse_version_info'] 29 | -------------------------------------------------------------------------------- /classification_distill/tools/analysis_tools/analysis_para.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from mmcv import Config 4 | from prettytable import PrettyTable 5 | 6 | from mmcls.models.builder import build_classifier 7 | 8 | 9 | def count_parameters(model): 10 | table = PrettyTable(["Modules", "Parameters"]) 11 | total_params = 0 12 | for name, parameter in model.named_parameters(): 13 | if not parameter.requires_grad: 14 | continue 15 | param = parameter.numel() 16 | table.add_row([name, param]) 17 | total_params += param 18 | print(table) 19 | print(f"Total Trainable Params: {total_params}") 20 | return total_params 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description='Explain a model') 25 | parser.add_argument('config', help='train config file path') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def main(): 31 | args = parse_args() 32 | cfg = Config.fromfile(args.config) 33 | print(cfg) 34 | model = build_classifier(cfg.model) 35 | 36 | count_parameters(model) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /classification_distill/tools/analysis_tools/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON_FILE=$1 4 | CONFIG=$2 5 | 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python $(dirname "$0")/$PYTHON_FILE $CONFIG -------------------------------------------------------------------------------- /classification_distill/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /classification_distill/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29502} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /classification_distill/tools/find_duplicate.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | root = "/Users/xingyiyang/Library/Containers/com.tencent.xinWeChat/Data/Library/Application 
Support/com.tencent.xinWeChat/2.0b4.0.9/7a3671d9cc49f8fbe0d260c990887f29/Message/MessageTemp/60cdc86db4ee39ae13d0a0c6c0c19385/File/" 4 | ee2026 = pd.read_excel(os.path.join(root, "ee2026 NameList_130223.xlsx")) 5 | # print(ee2026.head()) 6 | ee2211 = pd.read_csv(os.path.join(root, "2023-02-14T1417_Grades-EE2211.csv")) 7 | # print(ee2211.head()) 8 | both_df = pd.read_excel(os.path.join(root, "副本Copy of EE2211 n EE2026 emrolm_students taking both modules.xlsx")) 9 | both_df_gongfan = pd.read_excel(os.path.join(root, "results_ID.xlsx")) 10 | 11 | 12 | ee2026_student = ee2026["Student ID"].values 13 | ee2211_student = ee2211['Integration ID'].values 14 | # print(ee2026_student) 15 | # print(ee2211_student) 16 | 17 | def intersection(lst1, lst2): 18 | return list(set(lst1) & set(lst2)) 19 | 20 | both_student_myresults = intersection(ee2211_student, ee2026_student) 21 | 22 | print(both_student_myresults) 23 | print("Num students {}".format(len(both_student_myresults))) 24 | 25 | both_df_student = both_df["Student ID"].values 26 | print(both_df_student) 27 | print("Num students {}".format(len(both_df_student))) 28 | 29 | both_df_gongfan = both_df_gongfan["EE2211"].values[:105] 30 | print(both_df_gongfan) 31 | print("Num students {}".format(len(both_df_gongfan))) 32 | 33 | assert set(both_student_myresults) == set(both_df_student) == set(both_df_gongfan) -------------------------------------------------------------------------------- /classification_distill/tools/load_ckp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def main(): 4 | # PATH = 'resnet18_R182R18_common_network_20211112v2.pth' 5 | # ckp_path = '/home/yangxingyi/NeuralFactor/Multi-task-Depth-Seg/result/NYUD/kd_resnet_50_to_resnet_18/multi_task_baseline/best_model.pth.tar' 6 | ckp_path = '' 7 | model_dict = torch.load(ckp_path) 8 | # save_dict= dict(common_network=model_dict) 9 | # torch.save(save_dict, PATH, _use_new_zipfile_serialization=False) 10 | print(model_dict.keys()) 11 | 12 | def ckp_to_load(): 13 | ckp_path = '/home/yangxingyi/.cache/torch/checkpoints/resnet50_8xb32_in1k_20210831-ea4938fc.pth' 14 | save_path = '/home/yangxingyi/.cache/torch/checkpoints/resnet50_8xb32_in1k_20210831-ea4938fc_converted.pth' 15 | model_dict = torch.load(ckp_path) 16 | if 'state_dict' in model_dict.keys(): 17 | model_dict = model_dict['state_dict'] 18 | new_dict = dict() 19 | for k, v in model_dict.items(): 20 | if k.startswith('fc'): 21 | new_k = 'head.{}'.format(k) 22 | else: 23 | new_k = 'backbone.{}'.format(k) 24 | print('Old Key:', k, '-> New Key:', new_k) 25 | new_dict[new_k] = v 26 | save_dict= dict(state_dict=new_dict) 27 | torch.save(save_dict, save_path, _use_new_zipfile_serialization=False) 28 | 29 | if __name__ == '__main__': 30 | # main() 31 | ckp_to_load() 32 | -------------------------------------------------------------------------------- /classification_distill/tools/plot/CIFAR10_S.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/tools/plot/CIFAR10_S.npy -------------------------------------------------------------------------------- /classification_distill/tools/plot/CIFAR10_rank.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/tools/plot/CIFAR10_rank.npy 
-------------------------------------------------------------------------------- /classification_distill/tools/plot/mnist_S.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/tools/plot/mnist_S.npy -------------------------------------------------------------------------------- /classification_distill/tools/plot/mnist_rank.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Adamdad/Repfusion/2fe77c4c3c75592b4ea308488a926cc408e1a116/classification_distill/tools/plot/mnist_rank.npy -------------------------------------------------------------------------------- /classification_distill/tools/plot/plot_CIFAR10.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | 5 | with open('cls/tools/plot/CIFAR10_S.npy','rb') as f: 6 | Ss = np.load(f) 7 | with open('cls/tools/plot/CIFAR10_rank.npy','rb') as f: 8 | ranks = np.load(f) 9 | 10 | sns.set_style("whitegrid") 11 | plt.rcParams["font.family"] = "Times New Roman" 12 | ts= [0,1,2, 5, 8, 10,20,50,300,500, 900,999] 13 | plt.figure(figsize=(8,6)) 14 | for S, t in zip(Ss,ts): 15 | if t in [0, 300,500, 900,999]: 16 | plt.plot(S, label=f"t={t}") 17 | plt.yscale('symlog') 18 | plt.legend(fontsize=20) 19 | plt.xlabel("Singular Values Index",fontsize=32) 20 | plt.ylabel("Singular Values",fontsize=32) 21 | plt.xticks(fontsize=22) 22 | plt.yticks(fontsize=22) 23 | # plt.show() 24 | plt.tight_layout() 25 | plt.savefig("Singular Value_CIFAR10.pdf") 26 | 27 | 28 | plt.figure(figsize=(8,6)) 29 | # for ranks, t in zip(Ss,[0,10,50,300,500, 900]): 30 | # plt.plot(ts[:6],ranks[:6], 'o-',linewidth=3) 31 | # ts = np.array(ts)+0.0000001 32 | plt.plot(ts,ranks, 'o-') 33 | 34 | # plt.legend(fontsize=16) 35 | plt.xlabel("TimeStep",fontsize=32) 36 | plt.ylabel("Effective rank",fontsize=32) 37 | plt.xticks(fontsize=22) 38 | plt.yticks(fontsize=22) 39 | # plt.xticks(fontsize=14) 40 | # plt.yticks(fontsize=14) 41 | plt.xlim((-0.1,1100)) 42 | # plt.show() 43 | plt.xscale("symlog") 44 | plt.tight_layout() 45 | plt.savefig("Effective rank_CIFAR10.pdf") -------------------------------------------------------------------------------- /classification_distill/tools/plot/plot_mnist.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | 5 | with open('cls/tools/plot/mnist_S.npy','rb') as f: 6 | Ss = np.load(f) 7 | with open('cls/tools/plot/mnist_rank.npy','rb') as f: 8 | ranks = np.load(f) 9 | 10 | sns.set_style("whitegrid") 11 | plt.rcParams["font.family"] = "Times New Roman" 12 | ts= [0,1,2, 5, 8, 10,20,50,300,500, 900,999] 13 | plt.figure(figsize=(8,6)) 14 | for S, t in zip(Ss,ts): 15 | if t in [0, 2,10, 300,500, 900,999]: 16 | plt.plot(S, label=f"t={t}") 17 | plt.yscale('symlog') 18 | plt.legend(fontsize=20) 19 | plt.xlabel("Singular Values Index",fontsize=32) 20 | plt.ylabel("Singular Values",fontsize=32) 21 | plt.xticks(fontsize=22) 22 | plt.yticks(fontsize=22) 23 | # plt.show() 24 | plt.tight_layout() 25 | plt.savefig("Singular Value_mnist.pdf") 26 | 27 | 28 | plt.figure(figsize=(8,6)) 29 | # for ranks, t in zip(Ss,[0,10,50,300,500, 900]): 30 | # plt.plot(ts[:6],ranks[:6], 'o-',linewidth=3) 31 | # ts = np.array(ts)+0.0000001 32 | 
plt.plot(ts,ranks, 'o-') 33 | 34 | # plt.legend(fontsize=16) 35 | plt.xlabel("TimeStep",fontsize=32) 36 | plt.ylabel("Effective rank",fontsize=32) 37 | plt.xticks(fontsize=22) 38 | plt.yticks(fontsize=22) 39 | # plt.xticks(fontsize=14) 40 | # plt.yticks(fontsize=14) 41 | plt.xlim((-0.1,1100)) 42 | # plt.show() 43 | plt.xscale("symlog") 44 | plt.tight_layout() 45 | plt.savefig("Effective rank_mnist.pdf") -------------------------------------------------------------------------------- /classification_distill/tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /classification_distill/tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-1} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-1} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-4} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | PY_ARGS=${@:5} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /landmark/configs/_base_/datasets/deepfashion_lower.py: -------------------------------------------------------------------------------- 1 | dataset_info = dict( 2 | dataset_name='deepfashion_lower', 3 | paper_info=dict( 4 | author='Liu, Ziwei and Luo, Ping and Qiu, Shi ' 5 | 'and Wang, Xiaogang and Tang, Xiaoou', 6 | title='DeepFashion: Powering Robust Clothes Recognition ' 7 | 'and Retrieval with Rich Annotations', 8 | container='Proceedings of IEEE Conference on Computer ' 9 | 'Vision and Pattern Recognition (CVPR)', 10 | year='2016', 11 | homepage='http://mmlab.ie.cuhk.edu.hk/projects/' 12 | 'DeepFashion/LandmarkDetection.html', 13 | ), 14 | keypoint_info={ 15 | 0: 16 | dict( 17 | name='left waistline', 18 | id=0, 19 | color=[255, 255, 255], 20 | type='', 21 | swap='right waistline'), 22 | 1: 23 | dict( 24 | name='right waistline', 25 | id=1, 26 | color=[255, 255, 255], 27 | type='', 28 | swap='left waistline'), 29 | 2: 30 | dict( 31 | name='left hem', 32 | id=2, 33 | color=[255, 255, 255], 34 | type='', 35 | swap='right hem'), 36 | 3: 37 | dict( 38 | name='right hem', 39 | id=3, 40 | color=[255, 255, 255], 41 | type='', 42 | swap='left hem'), 43 | }, 44 | skeleton_info={}, 45 | joint_weights=[1.] 
* 4, 46 | sigmas=[]) 47 | -------------------------------------------------------------------------------- /landmark/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=10) 2 | 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook') 8 | # dict(type='PaviLoggerHook') # for internal services 9 | ]) 10 | 11 | log_level = 'INFO' 12 | load_from = None 13 | resume_from = None 14 | dist_params = dict(backend='nccl') 15 | workflow = [('train', 1)] 16 | 17 | # disable opencv multithreading to avoid system being overloaded 18 | opencv_num_threads = 0 19 | # set multi-process start method as `fork` to speed up the training 20 | mp_start_method = 'fork' 21 | -------------------------------------------------------------------------------- /landmark/configs/_base_/filters/gaussian.py: -------------------------------------------------------------------------------- 1 | filter_cfg = dict( 2 | type='GaussianFilter', 3 | window_size=11, 4 | sigma=4.0, 5 | ) 6 | -------------------------------------------------------------------------------- /landmark/configs/_base_/filters/one_euro.py: -------------------------------------------------------------------------------- 1 | filter_cfg = dict( 2 | type='OneEuroFilter', 3 | min_cutoff=0.004, 4 | beta=0.7, 5 | ) 6 | -------------------------------------------------------------------------------- /landmark/configs/_base_/filters/savizky_golay.py: -------------------------------------------------------------------------------- 1 | filter_cfg = dict( 2 | type='SavizkyGolayFilter', 3 | window_size=11, 4 | polyorder=2, 5 | ) 6 | -------------------------------------------------------------------------------- /landmark/configs/_base_/filters/smoothnet_t16_h36m.py: -------------------------------------------------------------------------------- 1 | # Config for SmoothNet filter trained on Human3.6M data with a window size of 2 | # 16. The model is trained using root-centered keypoint coordinates around the 3 | # pelvis (index:0), thus we set root_index=0 for the filter 4 | filter_cfg = dict( 5 | type='SmoothNetFilter', 6 | window_size=16, 7 | output_size=16, 8 | checkpoint='https://download.openmmlab.com/mmpose/plugin/smoothnet/' 9 | 'smoothnet_ws16_h36m.pth', 10 | hidden_size=512, 11 | res_hidden_size=256, 12 | num_blocks=3, 13 | root_index=0) 14 | -------------------------------------------------------------------------------- /landmark/configs/_base_/filters/smoothnet_t32_h36m.py: -------------------------------------------------------------------------------- 1 | # Config for SmoothNet filter trained on Human3.6M data with a window size of 2 | # 32. The model is trained using root-centered keypoint coordinates around the 3 | # pelvis (index:0), thus we set root_index=0 for the filter 4 | filter_cfg = dict( 5 | type='SmoothNetFilter', 6 | window_size=32, 7 | output_size=32, 8 | checkpoint='https://download.openmmlab.com/mmpose/plugin/smoothnet/' 9 | 'smoothnet_ws32_h36m.pth', 10 | hidden_size=512, 11 | res_hidden_size=256, 12 | num_blocks=3, 13 | root_index=0) 14 | -------------------------------------------------------------------------------- /landmark/configs/_base_/filters/smoothnet_t64_h36m.py: -------------------------------------------------------------------------------- 1 | # Config for SmoothNet filter trained on Human3.6M data with a window size of 2 | # 64. 
The model is trained using root-centered keypoint coordinates around the 3 | # pelvis (index:0), thus we set root_index=0 for the filter 4 | filter_cfg = dict( 5 | type='SmoothNetFilter', 6 | window_size=64, 7 | output_size=64, 8 | checkpoint='https://download.openmmlab.com/mmpose/plugin/smoothnet/' 9 | 'smoothnet_ws64_h36m.pth', 10 | hidden_size=512, 11 | res_hidden_size=256, 12 | num_blocks=3, 13 | root_index=0) 14 | -------------------------------------------------------------------------------- /landmark/configs/_base_/filters/smoothnet_t8_h36m.py: -------------------------------------------------------------------------------- 1 | # Config for SmoothNet filter trained on Human3.6M data with a window size of 2 | # 8. The model is trained using root-centered keypoint coordinates around the 3 | # pelvis (index:0), thus we set root_index=0 for the filter 4 | filter_cfg = dict( 5 | type='SmoothNetFilter', 6 | window_size=8, 7 | output_size=8, 8 | checkpoint='https://download.openmmlab.com/mmpose/plugin/smoothnet/' 9 | 'smoothnet_ws8_h36m.pth', 10 | hidden_size=512, 11 | res_hidden_size=256, 12 | num_blocks=3, 13 | root_index=0) 14 | -------------------------------------------------------------------------------- /landmark/mmpose/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | 4 | from .version import __version__, short_version 5 | 6 | 7 | def digit_version(version_str): 8 | digit_version = [] 9 | for x in version_str.split('.'): 10 | if x.isdigit(): 11 | digit_version.append(int(x)) 12 | elif x.find('rc') != -1: 13 | patch_version = x.split('rc') 14 | digit_version.append(int(patch_version[0]) - 1) 15 | digit_version.append(int(patch_version[1])) 16 | return digit_version 17 | 18 | 19 | mmcv_minimum_version = '1.3.8' 20 | mmcv_maximum_version = '1.8.0' 21 | mmcv_version = digit_version(mmcv.__version__) 22 | 23 | 24 | assert (mmcv_version >= digit_version(mmcv_minimum_version) 25 | and mmcv_version <= digit_version(mmcv_maximum_version)), \ 26 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 27 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 28 | 29 | __all__ = ['__version__', 'short_version'] 30 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .inference import (collect_multi_frames, inference_bottom_up_pose_model, 3 | inference_gesture_model, inference_top_down_pose_model, 4 | init_pose_model, process_mmdet_results, 5 | vis_pose_result) 6 | from .inference_3d import (extract_pose_sequence, inference_interhand_3d_model, 7 | inference_mesh_model, inference_pose_lifter_model, 8 | vis_3d_mesh_result, vis_3d_pose_result) 9 | from .inference_tracking import get_track_id, vis_pose_tracking_result 10 | from .test import multi_gpu_test, single_gpu_test 11 | from .train import init_random_seed, train_model 12 | 13 | __all__ = [ 14 | 'train_model', 15 | 'init_pose_model', 16 | 'inference_top_down_pose_model', 17 | 'inference_bottom_up_pose_model', 18 | 'multi_gpu_test', 19 | 'single_gpu_test', 20 | 'vis_pose_result', 21 | 'get_track_id', 22 | 'vis_pose_tracking_result', 23 | 'inference_pose_lifter_model', 24 | 'vis_3d_pose_result', 25 | 'inference_interhand_3d_model', 26 | 'extract_pose_sequence', 27 | 'inference_mesh_model', 28 | 'vis_3d_mesh_result', 29 | 'process_mmdet_results', 30 | 'init_random_seed', 31 | 'collect_multi_frames', 32 | 'inference_gesture_model', 33 | ] 34 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .webcam_executor import WebcamExecutor 3 | 4 | __all__ = ['WebcamExecutor'] 5 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_visualizer_node import BaseVisualizerNode 3 | from .helper_nodes import MonitorNode, ObjectAssignerNode, RecorderNode 4 | from .model_nodes import (DetectorNode, PoseTrackerNode, 5 | TopDownPoseEstimatorNode) 6 | from .node import Node 7 | from .registry import NODES 8 | from .visualizer_nodes import (BigeyeEffectNode, NoticeBoardNode, 9 | ObjectVisualizerNode, SunglassesEffectNode) 10 | 11 | __all__ = [ 12 | 'BaseVisualizerNode', 'NODES', 'MonitorNode', 'ObjectAssignerNode', 13 | 'RecorderNode', 'DetectorNode', 'PoseTrackerNode', 14 | 'TopDownPoseEstimatorNode', 'Node', 'BigeyeEffectNode', 'NoticeBoardNode', 15 | 'ObjectVisualizerNode', 'ObjectAssignerNode', 'SunglassesEffectNode' 16 | ] 17 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/nodes/helper_nodes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .monitor_node import MonitorNode 3 | from .object_assigner_node import ObjectAssignerNode 4 | from .recorder_node import RecorderNode 5 | 6 | __all__ = ['MonitorNode', 'ObjectAssignerNode', 'RecorderNode'] 7 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/nodes/model_nodes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .detector_node import DetectorNode 3 | from .hand_gesture_node import HandGestureRecognizerNode 4 | from .pose_estimator_node import TopDownPoseEstimatorNode 5 | from .pose_tracker_node import PoseTrackerNode 6 | 7 | __all__ = [ 8 | 'DetectorNode', 'TopDownPoseEstimatorNode', 'PoseTrackerNode', 9 | 'HandGestureRecognizerNode' 10 | ] 11 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/nodes/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry 3 | 4 | NODES = Registry('node') 5 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/nodes/visualizer_nodes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .bigeye_effect_node import BigeyeEffectNode 3 | from .notice_board_node import NoticeBoardNode 4 | from .object_visualizer_node import ObjectVisualizerNode 5 | from .sunglasses_effect_node import SunglassesEffectNode 6 | 7 | __all__ = [ 8 | 'ObjectVisualizerNode', 'NoticeBoardNode', 'SunglassesEffectNode', 9 | 'BigeyeEffectNode' 10 | ] 11 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .buffer import BufferManager 3 | from .event import EventManager 4 | from .image_capture import ImageCapture 5 | from .message import FrameMessage, Message, VideoEndingMessage 6 | from .misc import (copy_and_paste, expand_and_clamp, get_cached_file_path, 7 | get_config_path, is_image_file, limit_max_fps, 8 | load_image_from_disk_or_url, screen_matting) 9 | from .pose import (get_eye_keypoint_ids, get_face_keypoint_ids, 10 | get_hand_keypoint_ids, get_mouth_keypoint_ids, 11 | get_wrist_keypoint_ids) 12 | 13 | __all__ = [ 14 | 'BufferManager', 'EventManager', 'FrameMessage', 'Message', 15 | 'limit_max_fps', 'VideoEndingMessage', 'load_image_from_disk_or_url', 16 | 'get_cached_file_path', 'screen_matting', 'get_config_path', 17 | 'expand_and_clamp', 'copy_and_paste', 'is_image_file', 'ImageCapture', 18 | 'get_eye_keypoint_ids', 'get_face_keypoint_ids', 'get_wrist_keypoint_ids', 19 | 'get_mouth_keypoint_ids', 'get_hand_keypoint_ids' 20 | ] 21 | -------------------------------------------------------------------------------- /landmark/mmpose/apis/webcam/utils/image_capture.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Union 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from .misc import load_image_from_disk_or_url 8 | 9 | 10 | class ImageCapture: 11 | """A mock-up of cv2.VideoCapture that always return a const image. 
12 | 13 | Args: 14 | image (str | ndarray): The image path or image data 15 | """ 16 | 17 | def __init__(self, image: Union[str, np.ndarray]): 18 | if isinstance(image, str): 19 | self.image = load_image_from_disk_or_url(image) 20 | else: 21 | self.image = image 22 | 23 | def isOpened(self): 24 | return (self.image is not None) 25 | 26 | def read(self): 27 | return True, self.image.copy() 28 | 29 | def release(self): 30 | pass 31 | 32 | def get(self, propId): 33 | if propId == cv2.CAP_PROP_FRAME_WIDTH: 34 | return self.image.shape[1] 35 | elif propId == cv2.CAP_PROP_FRAME_HEIGHT: 36 | return self.image.shape[0] 37 | elif propId == cv2.CAP_PROP_FPS: 38 | return np.nan 39 | else: 40 | raise NotImplementedError() 41 | -------------------------------------------------------------------------------- /landmark/mmpose/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .bbox import * # noqa: F401, F403 3 | from .camera import * # noqa: F401, F403 4 | from .evaluation import * # noqa: F401, F403 5 | from .fp16 import * # noqa: F401, F403 6 | from .optimizers import * # noqa: F401, F403 7 | from .post_processing import * # noqa: F401, F403 8 | from .utils import * # noqa: F401, F403 9 | from .visualization import * # noqa: F401, F403 10 | -------------------------------------------------------------------------------- /landmark/mmpose/core/bbox/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .transforms import (bbox_cs2xywh, bbox_xywh2cs, bbox_xywh2xyxy, 3 | bbox_xyxy2xywh) 4 | 5 | __all__ = ['bbox_xywh2xyxy', 'bbox_xyxy2xywh', 'bbox_xywh2cs', 'bbox_cs2xywh'] 6 | -------------------------------------------------------------------------------- /landmark/mmpose/core/camera/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .camera_base import CAMERAS 3 | from .single_camera import SimpleCamera 4 | from .single_camera_torch import SimpleCameraTorch 5 | 6 | __all__ = ['CAMERAS', 'SimpleCamera', 'SimpleCameraTorch'] 7 | -------------------------------------------------------------------------------- /landmark/mmpose/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .bottom_up_eval import (aggregate_scale, aggregate_stage_flip, 3 | flip_feature_maps, get_group_preds, 4 | split_ae_outputs) 5 | from .eval_hooks import DistEvalHook, EvalHook 6 | from .mesh_eval import compute_similarity_transform 7 | from .pose3d_eval import keypoint_3d_auc, keypoint_3d_pck, keypoint_mpjpe 8 | from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_pck_accuracy, 9 | keypoints_from_heatmaps, keypoints_from_heatmaps3d, 10 | keypoints_from_regression, 11 | multilabel_classification_accuracy, 12 | pose_pck_accuracy, post_dark_udp) 13 | 14 | __all__ = [ 15 | 'EvalHook', 'DistEvalHook', 'pose_pck_accuracy', 'keypoints_from_heatmaps', 16 | 'keypoints_from_regression', 'keypoint_pck_accuracy', 'keypoint_3d_pck', 17 | 'keypoint_3d_auc', 'keypoint_auc', 'keypoint_epe', 'get_group_preds', 18 | 'split_ae_outputs', 'flip_feature_maps', 'aggregate_stage_flip', 19 | 'aggregate_scale', 'compute_similarity_transform', 'post_dark_udp', 20 | 'keypoint_mpjpe', 'keypoints_from_heatmaps3d', 21 | 'multilabel_classification_accuracy' 22 | ] 23 | -------------------------------------------------------------------------------- /landmark/mmpose/core/fp16/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .decorators import auto_fp16, force_fp32 3 | from .hooks import Fp16OptimizerHook, wrap_fp16_model 4 | from .utils import cast_tensor_type 5 | 6 | __all__ = [ 7 | 'auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model', 8 | 'cast_tensor_type' 9 | ] 10 | -------------------------------------------------------------------------------- /landmark/mmpose/core/fp16/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections import abc 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def cast_tensor_type(inputs, src_type, dst_type): 9 | """Recursively convert Tensor in inputs from src_type to dst_type. 10 | 11 | Args: 12 | inputs: Inputs that to be casted. 13 | src_type (torch.dtype): Source type. 14 | dst_type (torch.dtype): Destination type. 15 | 16 | Returns: 17 | The same type with inputs, but all contained Tensors have been cast. 18 | """ 19 | if isinstance(inputs, torch.Tensor): 20 | return inputs.to(dst_type) 21 | elif isinstance(inputs, str): 22 | return inputs 23 | elif isinstance(inputs, np.ndarray): 24 | return inputs 25 | elif isinstance(inputs, abc.Mapping): 26 | return type(inputs)({ 27 | k: cast_tensor_type(v, src_type, dst_type) 28 | for k, v in inputs.items() 29 | }) 30 | elif isinstance(inputs, abc.Iterable): 31 | return type(inputs)( 32 | cast_tensor_type(item, src_type, dst_type) for item in inputs) 33 | 34 | return inputs 35 | -------------------------------------------------------------------------------- /landmark/mmpose/core/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, 3 | build_optimizer_constructor, build_optimizers) 4 | 5 | __all__ = [ 6 | 'build_optimizers', 'build_optimizer_constructor', 'OPTIMIZERS', 7 | 'OPTIMIZER_BUILDERS' 8 | ] 9 | -------------------------------------------------------------------------------- /landmark/mmpose/core/post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .nms import nearby_joints_nms, oks_iou, oks_nms, soft_oks_nms 4 | from .one_euro_filter import OneEuroFilter 5 | from .post_transforms import (affine_transform, flip_back, fliplr_joints, 6 | fliplr_regression, get_affine_transform, 7 | get_warp_matrix, rotate_point, transform_preds, 8 | warp_affine_joints) 9 | from .smoother import Smoother 10 | 11 | __all__ = [ 12 | 'oks_nms', 'soft_oks_nms', 'nearby_joints_nms', 'affine_transform', 13 | 'rotate_point', 'flip_back', 'fliplr_joints', 'fliplr_regression', 14 | 'transform_preds', 'get_affine_transform', 'get_warp_matrix', 15 | 'warp_affine_joints', 'oks_iou', 'OneEuroFilter', 'Smoother' 16 | ] 17 | -------------------------------------------------------------------------------- /landmark/mmpose/core/post_processing/temporal_filters/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_filter 3 | from .gaussian_filter import GaussianFilter 4 | from .one_euro_filter import OneEuroFilter 5 | from .savizky_golay_filter import SavizkyGolayFilter 6 | from .smoothnet_filter import SmoothNetFilter 7 | 8 | __all__ = [ 9 | 'build_filter', 'GaussianFilter', 'OneEuroFilter', 'SavizkyGolayFilter', 10 | 'SmoothNetFilter' 11 | ] 12 | -------------------------------------------------------------------------------- /landmark/mmpose/core/post_processing/temporal_filters/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry 3 | 4 | FILTERS = Registry('filters') 5 | 6 | 7 | def build_filter(cfg): 8 | """Build filters function.""" 9 | return FILTERS.build(cfg) 10 | -------------------------------------------------------------------------------- /landmark/mmpose/core/post_processing/temporal_filters/filter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class TemporalFilter(metaclass=ABCMeta): 6 | """Base class of temporal filter. 7 | 8 | A subclass should implement the method __call__(). 9 | 10 | Parameters: 11 | window_size (int): the size of the sliding window. 12 | """ 13 | 14 | # If the filter can be shared by multiple humans or targets 15 | _shareable: bool = True 16 | 17 | def __init__(self, window_size=1): 18 | self._window_size = window_size 19 | 20 | @property 21 | def window_size(self): 22 | return self._window_size 23 | 24 | @property 25 | def shareable(self): 26 | return self._shareable 27 | 28 | @abstractmethod 29 | def __call__(self, x): 30 | """Apply filter to a pose sequence. 
31 | 32 | Note: 33 | T: The temporal length of the pose sequence 34 | K: The keypoint number of each target 35 | C: The keypoint coordinate dimension 36 | 37 | Args: 38 | x (np.ndarray): input pose sequence in shape [T, K, C] 39 | 40 | Returns: 41 | np.ndarray: Smoothed pose sequence in shape [T, K, C] 42 | """ 43 | -------------------------------------------------------------------------------- /landmark/mmpose/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_utils import allreduce_grads, sync_random_seed 3 | from .model_util_hooks import ModelSetEpochHook 4 | from .regularizations import WeightNormClipHook 5 | 6 | __all__ = [ 7 | 'allreduce_grads', 'WeightNormClipHook', 'sync_random_seed', 8 | 'ModelSetEpochHook' 9 | ] 10 | -------------------------------------------------------------------------------- /landmark/mmpose/core/utils/model_util_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.runner import HOOKS, Hook 3 | 4 | 5 | @HOOKS.register_module() 6 | class ModelSetEpochHook(Hook): 7 | """The hook that tells the model the current epoch during training.""" 8 | 9 | def __init__(self): 10 | pass 11 | 12 | def before_epoch(self, runner): 13 | runner.model.module.set_train_epoch(runner.epoch + 1) 14 | -------------------------------------------------------------------------------- /landmark/mmpose/core/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .image import (imshow_bboxes, imshow_keypoints, imshow_keypoints_3d, 3 | imshow_mesh_3d, imshow_multiview_keypoints_3d) 4 | 5 | __all__ = [ 6 | 'imshow_keypoints', 'imshow_keypoints_3d', 'imshow_bboxes', 7 | 'imshow_mesh_3d', 'imshow_multiview_keypoints_3d' 8 | ] 9 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import DATASETS 3 | 4 | 5 | @DATASETS.register_module() 6 | class RepeatDataset: 7 | """A wrapper of repeated dataset. 8 | 9 | The length of the repeated dataset will be `times` times that of the 10 | original dataset. This is useful when the data loading time is long but 11 | the dataset is small. Using RepeatDataset can reduce the data loading 12 | time between epochs. 13 | 14 | Args: 15 | dataset (:obj:`Dataset`): The dataset to be repeated. 16 | times (int): Repeat times. 17 | """ 18 | 19 | def __init__(self, dataset, times): 20 | self.dataset = dataset 21 | self.times = times 22 | 23 | self._ori_len = len(self.dataset) 24 | 25 | def __getitem__(self, idx): 26 | """Get data.""" 27 | return self.dataset[idx % self._ori_len] 28 | 29 | def __len__(self): 30 | """Length after repetition.""" 31 | return self.times * self._ori_len 32 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/animal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
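# Illustrative sketch (assumed typical config usage, not part of the original
# file): the datasets below register into the DATASETS registry, so configs
# refer to them by type name and may wrap them with the RepeatDataset wrapper
# shown above, e.g.
#   data = dict(train=dict(
#       type='RepeatDataset',
#       times=10,
#       dataset=dict(type='AnimalPoseDataset', ...)))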
2 | from .animal_ap10k_dataset import AnimalAP10KDataset 3 | from .animal_atrw_dataset import AnimalATRWDataset 4 | from .animal_fly_dataset import AnimalFlyDataset 5 | from .animal_horse10_dataset import AnimalHorse10Dataset 6 | from .animal_locust_dataset import AnimalLocustDataset 7 | from .animal_macaque_dataset import AnimalMacaqueDataset 8 | from .animal_pose_dataset import AnimalPoseDataset 9 | from .animal_zebra_dataset import AnimalZebraDataset 10 | 11 | __all__ = [ 12 | 'AnimalHorse10Dataset', 'AnimalMacaqueDataset', 'AnimalFlyDataset', 13 | 'AnimalLocustDataset', 'AnimalZebraDataset', 'AnimalATRWDataset', 14 | 'AnimalPoseDataset', 'AnimalAP10KDataset' 15 | ] 16 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/animal/animal_base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class AnimalBaseDataset(Dataset, metaclass=ABCMeta): 8 | """This class has been deprecated and replaced by 9 | Kpt2dSviewRgbImgTopDownDataset.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | raise (ImportError( 13 | 'AnimalBaseDataset has been replaced by ' 14 | 'Kpt2dSviewRgbImgTopDownDataset, ' 15 | 'check https://github.com/open-mmlab/mmpose/pull/663 for details.') 16 | ) 17 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .kpt_2d_sview_rgb_img_bottom_up_dataset import \ 3 | Kpt2dSviewRgbImgBottomUpDataset 4 | from .kpt_2d_sview_rgb_img_top_down_dataset import \ 5 | Kpt2dSviewRgbImgTopDownDataset 6 | from .kpt_2d_sview_rgb_vid_top_down_dataset import \ 7 | Kpt2dSviewRgbVidTopDownDataset 8 | from .kpt_3d_mview_rgb_img_direct_dataset import Kpt3dMviewRgbImgDirectDataset 9 | from .kpt_3d_sview_kpt_2d_dataset import Kpt3dSviewKpt2dDataset 10 | from .kpt_3d_sview_rgb_img_top_down_dataset import \ 11 | Kpt3dSviewRgbImgTopDownDataset 12 | 13 | __all__ = [ 14 | 'Kpt3dMviewRgbImgDirectDataset', 'Kpt2dSviewRgbImgTopDownDataset', 15 | 'Kpt3dSviewRgbImgTopDownDataset', 'Kpt2dSviewRgbImgBottomUpDataset', 16 | 'Kpt3dSviewKpt2dDataset', 'Kpt2dSviewRgbVidTopDownDataset' 17 | ] 18 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/body3d/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .body3d_h36m_dataset import Body3DH36MDataset 3 | from .body3d_mpi_inf_3dhp_dataset import Body3DMpiInf3dhpDataset 4 | from .body3d_mview_direct_campus_dataset import Body3DMviewDirectCampusDataset 5 | from .body3d_mview_direct_panoptic_dataset import \ 6 | Body3DMviewDirectPanopticDataset 7 | from .body3d_mview_direct_shelf_dataset import Body3DMviewDirectShelfDataset 8 | from .body3d_semi_supervision_dataset import Body3DSemiSupervisionDataset 9 | 10 | __all__ = [ 11 | 'Body3DH36MDataset', 'Body3DSemiSupervisionDataset', 12 | 'Body3DMpiInf3dhpDataset', 'Body3DMviewDirectPanopticDataset', 13 | 'Body3DMviewDirectShelfDataset', 'Body3DMviewDirectCampusDataset' 14 | ] 15 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/body3d/body3d_base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class Body3DBaseDataset(Dataset, metaclass=ABCMeta): 8 | """This class has been deprecated and replaced by 9 | Kpt3dSviewKpt2dDataset.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | raise (ImportError( 13 | 'Body3DBaseDataset has been replaced by ' 14 | 'Kpt3dSviewKpt2dDataset, ' 15 | 'check https://github.com/open-mmlab/mmpose/pull/663 for details.') 16 | ) 17 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/bottom_up/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .bottom_up_aic import BottomUpAicDataset 3 | from .bottom_up_coco import BottomUpCocoDataset 4 | from .bottom_up_coco_wholebody import BottomUpCocoWholeBodyDataset 5 | from .bottom_up_crowdpose import BottomUpCrowdPoseDataset 6 | from .bottom_up_mhp import BottomUpMhpDataset 7 | 8 | __all__ = [ 9 | 'BottomUpCocoDataset', 'BottomUpCrowdPoseDataset', 'BottomUpMhpDataset', 10 | 'BottomUpAicDataset', 'BottomUpCocoWholeBodyDataset' 11 | ] 12 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/bottom_up/bottom_up_base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class BottomUpBaseDataset(Dataset): 6 | """This class has been deprecated and replaced by 7 | Kpt2dSviewRgbImgBottomUpDataset.""" 8 | 9 | def __init__(self, *args, **kwargs): 10 | raise (ImportError( 11 | 'BottomUpBaseDataset has been replaced by ' 12 | 'Kpt2dSviewRgbImgBottomUpDataset, ' 13 | 'check https://github.com/open-mmlab/mmpose/pull/663 for details.') 14 | ) 15 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/face/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .face_300w_dataset import Face300WDataset 3 | from .face_aflw_dataset import FaceAFLWDataset 4 | from .face_coco_wholebody_dataset import FaceCocoWholeBodyDataset 5 | from .face_cofw_dataset import FaceCOFWDataset 6 | from .face_wflw_dataset import FaceWFLWDataset 7 | 8 | __all__ = [ 9 | 'Face300WDataset', 'FaceAFLWDataset', 'FaceWFLWDataset', 'FaceCOFWDataset', 10 | 'FaceCocoWholeBodyDataset' 11 | ] 12 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/face/face_base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class FaceBaseDataset(Dataset, metaclass=ABCMeta): 8 | """This class has been deprecated and replaced by 9 | Kpt2dSviewRgbImgTopDownDataset.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | raise (ImportError( 13 | 'FaceBaseDataset has been replaced by ' 14 | 'Kpt2dSviewRgbImgTopDownDataset, ' 15 | 'check https://github.com/open-mmlab/mmpose/pull/663 for details.') 16 | ) 17 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/fashion/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .deepfashion_dataset import DeepFashionDataset 3 | 4 | __all__ = ['DeepFashionDataset'] 5 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/fashion/fashion_base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class FashionBaseDataset(Dataset, metaclass=ABCMeta): 8 | """This class has been deprecated and replaced by 9 | Kpt2dSviewRgbImgTopDownDataset.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | raise (ImportError( 13 | 'FashionBaseDataset has been replaced by ' 14 | 'Kpt2dSviewRgbImgTopDownDataset, ' 15 | 'check https://github.com/open-mmlab/mmpose/pull/663 for details.') 16 | ) 17 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/gesture/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .nvgesture_dataset import NVGestureDataset 3 | 4 | __all__ = ['NVGestureDataset'] 5 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/hand/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .freihand_dataset import FreiHandDataset 3 | from .hand_coco_wholebody_dataset import HandCocoWholeBodyDataset 4 | from .interhand2d_dataset import InterHand2DDataset 5 | from .interhand3d_dataset import InterHand3DDataset 6 | from .onehand10k_dataset import OneHand10KDataset 7 | from .panoptic_hand2d_dataset import PanopticDataset 8 | from .rhd2d_dataset import Rhd2DDataset 9 | 10 | __all__ = [ 11 | 'FreiHandDataset', 'InterHand2DDataset', 'InterHand3DDataset', 12 | 'OneHand10KDataset', 'PanopticDataset', 'Rhd2DDataset', 13 | 'HandCocoWholeBodyDataset' 14 | ] 15 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/hand/hand_base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class HandBaseDataset(Dataset, metaclass=ABCMeta): 8 | """This class has been deprecated and replaced by 9 | Kpt2dSviewRgbImgTopDownDataset.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | raise (ImportError( 13 | 'HandBaseDataset has been replaced by ' 14 | 'Kpt2dSviewRgbImgTopDownDataset, ' 15 | 'check https://github.com/open-mmlab/mmpose/pull/663 for details.') 16 | ) 17 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/mesh/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .mesh_adv_dataset import MeshAdversarialDataset 3 | from .mesh_h36m_dataset import MeshH36MDataset 4 | from .mesh_mix_dataset import MeshMixDataset 5 | from .mosh_dataset import MoshDataset 6 | 7 | __all__ = [ 8 | 'MeshH36MDataset', 'MoshDataset', 'MeshMixDataset', 9 | 'MeshAdversarialDataset' 10 | ] 11 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/top_down/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .topdown_aic_dataset import TopDownAicDataset 3 | from .topdown_coco_dataset import TopDownCocoDataset 4 | from .topdown_coco_wholebody_dataset import TopDownCocoWholeBodyDataset 5 | from .topdown_crowdpose_dataset import TopDownCrowdPoseDataset 6 | from .topdown_h36m_dataset import TopDownH36MDataset 7 | from .topdown_halpe_dataset import TopDownHalpeDataset 8 | from .topdown_jhmdb_dataset import TopDownJhmdbDataset 9 | from .topdown_mhp_dataset import TopDownMhpDataset 10 | from .topdown_mpii_dataset import TopDownMpiiDataset 11 | from .topdown_mpii_trb_dataset import TopDownMpiiTrbDataset 12 | from .topdown_ochuman_dataset import TopDownOCHumanDataset 13 | from .topdown_posetrack18_dataset import TopDownPoseTrack18Dataset 14 | from .topdown_posetrack18_video_dataset import TopDownPoseTrack18VideoDataset 15 | 16 | __all__ = [ 17 | 'TopDownAicDataset', 18 | 'TopDownCocoDataset', 19 | 'TopDownCocoWholeBodyDataset', 20 | 'TopDownCrowdPoseDataset', 21 | 'TopDownMpiiDataset', 22 | 'TopDownMpiiTrbDataset', 23 | 'TopDownOCHumanDataset', 24 | 'TopDownPoseTrack18Dataset', 25 | 'TopDownJhmdbDataset', 26 | 'TopDownMhpDataset', 27 | 'TopDownH36MDataset', 28 | 'TopDownHalpeDataset', 29 | 'TopDownPoseTrack18VideoDataset', 30 | ] 31 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/datasets/top_down/topdown_base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class TopDownBaseDataset(Dataset, metaclass=ABCMeta): 8 | """This class has been deprecated and replaced by 9 | Kpt2dSviewRgbImgTopDownDataset.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | raise (ImportError( 13 | 'TopDownBaseDataset has been replaced by ' 14 | 'Kpt2dSviewRgbImgTopDownDataset, ' 15 | 'check https://github.com/open-mmlab/mmpose/pull/663 for details.') 16 | ) 17 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .bottom_up_transform import * # noqa 3 | from .gesture_transform import * # noqa 4 | from .hand_transform import * # noqa 5 | from .loading import * # noqa 6 | from .mesh_transform import * # noqa 7 | from .pose3d_transform import * # noqa 8 | from .shared_transform import * # noqa 9 | from .top_down_transform import * # noqa 10 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | from .builder import DATASETS, PIPELINES 5 | 6 | __all__ = ['DATASETS', 'PIPELINES'] 7 | 8 | warnings.simplefilter('once', DeprecationWarning) 9 | warnings.warn( 10 | 'Registries (DATASETS, PIPELINES) have been moved to ' 11 | 'mmpose.datasets.builder. Importing from ' 12 | 'mmpose.datasets.registry will be deprecated in the future.', 13 | DeprecationWarning) 14 | -------------------------------------------------------------------------------- /landmark/mmpose/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
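# Illustrative note (an assumption based on common mmcv-style training loops,
# not part of the original file): this sampler shards the dataset across
# ranks in distributed training, analogous to
# torch.utils.data.DistributedSampler:
#   sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
#                                shuffle=True)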
2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /landmark/mmpose/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa 3 | from .builder import (BACKBONES, HEADS, LOSSES, MESH_MODELS, NECKS, POSENETS, 4 | build_backbone, build_head, build_loss, build_mesh_model, 5 | build_neck, build_posenet) 6 | from .detectors import * # noqa 7 | from .heads import * # noqa 8 | from .losses import * # noqa 9 | from .necks import * # noqa 10 | from .utils import * # noqa 11 | 12 | __all__ = [ 13 | 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'POSENETS', 'MESH_MODELS', 14 | 'build_backbone', 'build_head', 'build_loss', 'build_posenet', 15 | 'build_neck', 'build_mesh_model' 16 | ] 17 | -------------------------------------------------------------------------------- /landmark/mmpose/models/backbones/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .channel_shuffle import channel_shuffle 3 | from .inverted_residual import InvertedResidual 4 | from .make_divisible import make_divisible 5 | from .se_layer import SELayer 6 | from .utils import load_checkpoint 7 | 8 | __all__ = [ 9 | 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer', 10 | 'load_checkpoint' 11 | ] 12 | -------------------------------------------------------------------------------- /landmark/mmpose/models/backbones/utils/channel_shuffle.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | 5 | def channel_shuffle(x, groups): 6 | """Channel Shuffle operation. 7 | 8 | This function enables cross-group information flow for multiple groups 9 | convolution layers. 10 | 11 | Args: 12 | x (Tensor): The input tensor. 13 | groups (int): The number of groups to divide the input tensor 14 | in the channel dimension. 15 | 16 | Returns: 17 | Tensor: The output tensor after channel shuffle operation. 18 | """ 19 | 20 | batch_size, num_channels, height, width = x.size() 21 | assert (num_channels % groups == 0), ('num_channels should be ' 22 | 'divisible by groups') 23 | channels_per_group = num_channels // groups 24 | 25 | x = x.view(batch_size, groups, channels_per_group, height, width) 26 | x = torch.transpose(x, 1, 2).contiguous() 27 | x = x.view(batch_size, groups * channels_per_group, height, width) 28 | 29 | return x 30 | -------------------------------------------------------------------------------- /landmark/mmpose/models/backbones/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number to the nearest value that can 6 | be divisible by the divisor. 7 | 8 | Args: 9 | value (int): The original channel number. 10 | divisor (int): The divisor to fully divide the channel number. 11 | min_value (int, optional): The minimum value of the output channel. 12 | Default: None, which means the minimum value equals the divisor.
13 | min_ratio (float, optional): The minimum ratio of the rounded channel 14 | number to the original channel number. Default: 0.9. 15 | Returns: 16 | int: The modified output channel number 17 | """ 18 | 19 | if min_value is None: 20 | min_value = divisor 21 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 22 | # Make sure that round down does not go down by more than (1-min_ratio). 23 | if new_value < min_ratio * value: 24 | new_value += divisor 25 | return new_value 26 | -------------------------------------------------------------------------------- /landmark/mmpose/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import MODELS as MMCV_MODELS 3 | from mmcv.cnn import build_model_from_cfg 4 | from mmcv.utils import Registry 5 | 6 | MODELS = Registry( 7 | 'models', build_func=build_model_from_cfg, parent=MMCV_MODELS) 8 | 9 | BACKBONES = MODELS 10 | NECKS = MODELS 11 | HEADS = MODELS 12 | LOSSES = MODELS 13 | POSENETS = MODELS 14 | MESH_MODELS = MODELS 15 | 16 | 17 | def build_backbone(cfg): 18 | """Build backbone.""" 19 | return BACKBONES.build(cfg) 20 | 21 | 22 | def build_neck(cfg): 23 | """Build neck.""" 24 | return NECKS.build(cfg) 25 | 26 | 27 | def build_head(cfg): 28 | """Build head.""" 29 | return HEADS.build(cfg) 30 | 31 | 32 | def build_loss(cfg): 33 | """Build loss.""" 34 | return LOSSES.build(cfg) 35 | 36 | 37 | def build_posenet(cfg): 38 | """Build posenet.""" 39 | return POSENETS.build(cfg) 40 | 41 | 42 | def build_mesh_model(cfg): 43 | """Build mesh model.""" 44 | return MESH_MODELS.build(cfg) 45 | -------------------------------------------------------------------------------- /landmark/mmpose/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .associative_embedding import AssociativeEmbedding 3 | from .cid import CID 4 | from .gesture_recognizer import GestureRecognizer 5 | from .interhand_3d import Interhand3D 6 | from .mesh import ParametricMesh 7 | from .multi_task import MultiTask 8 | from .multiview_pose import (DetectAndRegress, VoxelCenterDetector, 9 | VoxelSinglePose) 10 | from .one_stage import DisentangledKeypointRegressor 11 | from .pose_lifter import PoseLifter 12 | from .posewarper import PoseWarper 13 | from .top_down import TopDown 14 | 15 | __all__ = [ 16 | 'TopDown', 'AssociativeEmbedding', 'CID', 'ParametricMesh', 'MultiTask', 17 | 'PoseLifter', 'Interhand3D', 'PoseWarper', 'DetectAndRegress', 18 | 'VoxelCenterDetector', 'VoxelSinglePose', 'GestureRecognizer', 19 | 'DisentangledKeypointRegressor' 20 | ] 21 | -------------------------------------------------------------------------------- /landmark/mmpose/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
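# Illustrative sketch (assumed usage, not part of the original file): the
# heads below register into the HEADS registry, so build_head() from
# builder.py above can construct them from config dicts, e.g.
#   head = build_head(dict(
#       type='TopdownHeatmapSimpleHead', in_channels=2048, out_channels=17))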
2 | from .ae_higher_resolution_head import AEHigherResolutionHead 3 | from .ae_multi_stage_head import AEMultiStageHead 4 | from .ae_simple_head import AESimpleHead 5 | from .cid_head import CIDHead 6 | from .deconv_head import DeconvHead 7 | from .deeppose_regression_head import DeepposeRegressionHead 8 | from .dekr_head import DEKRHead 9 | from .hmr_head import HMRMeshHead 10 | from .interhand_3d_head import Interhand3DHead 11 | from .mtut_head import MultiModalSSAHead 12 | from .temporal_regression_head import TemporalRegressionHead 13 | from .topdown_heatmap_base_head import TopdownHeatmapBaseHead 14 | from .topdown_heatmap_multi_stage_head import (TopdownHeatmapMSMUHead, 15 | TopdownHeatmapMultiStageHead) 16 | from .topdown_heatmap_simple_head import TopdownHeatmapSimpleHead 17 | from .vipnas_heatmap_simple_head import ViPNASHeatmapSimpleHead 18 | from .voxelpose_head import CuboidCenterHead, CuboidPoseHead 19 | 20 | __all__ = [ 21 | 'TopdownHeatmapSimpleHead', 'TopdownHeatmapMultiStageHead', 22 | 'TopdownHeatmapMSMUHead', 'TopdownHeatmapBaseHead', 23 | 'AEHigherResolutionHead', 'AESimpleHead', 'AEMultiStageHead', 'CIDHead', 24 | 'DeepposeRegressionHead', 'TemporalRegressionHead', 'Interhand3DHead', 25 | 'HMRMeshHead', 'DeconvHead', 'ViPNASHeatmapSimpleHead', 'CuboidCenterHead', 26 | 'CuboidPoseHead', 'MultiModalSSAHead', 'DEKRHead' 27 | ] 28 | -------------------------------------------------------------------------------- /landmark/mmpose/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .classfication_loss import BCELoss 3 | from .heatmap_loss import AdaptiveWingLoss, FocalHeatmapLoss 4 | from .mesh_loss import GANLoss, MeshLoss 5 | from .mse_loss import JointsMSELoss, JointsOHKMMSELoss 6 | from .multi_loss_factory import AELoss, HeatmapLoss, MultiLossFactory 7 | from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss, RLELoss, 8 | SemiSupervisionLoss, SmoothL1Loss, 9 | SoftWeightSmoothL1Loss, SoftWingLoss, WingLoss) 10 | 11 | __all__ = [ 12 | 'JointsMSELoss', 'JointsOHKMMSELoss', 'HeatmapLoss', 'AELoss', 13 | 'MultiLossFactory', 'MeshLoss', 'GANLoss', 'SmoothL1Loss', 'WingLoss', 14 | 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss', 15 | 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss', 16 | 'SoftWeightSmoothL1Loss', 'FocalHeatmapLoss' 17 | ] 18 | -------------------------------------------------------------------------------- /landmark/mmpose/models/losses/classfication_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..builder import LOSSES 6 | 7 | 8 | @LOSSES.register_module() 9 | class BCELoss(nn.Module): 10 | """Binary Cross Entropy loss.""" 11 | 12 | def __init__(self, use_target_weight=False, loss_weight=1.): 13 | super().__init__() 14 | self.criterion = F.binary_cross_entropy 15 | self.use_target_weight = use_target_weight 16 | self.loss_weight = loss_weight 17 | 18 | def forward(self, output, target, target_weight=None): 19 | """Forward function. 20 | 21 | Note: 22 | - batch_size: N 23 | - num_labels: K 24 | 25 | Args: 26 | output (torch.Tensor[N, K]): Output classification. 27 | target (torch.Tensor[N, K]): Target classification. 28 | target_weight (torch.Tensor[N, K] or torch.Tensor[N]): 29 | Weights across different labels. 
30 | """ 31 | 32 | if self.use_target_weight: 33 | assert target_weight is not None 34 | loss = self.criterion(output, target, reduction='none') 35 | if target_weight.dim() == 1: 36 | target_weight = target_weight[:, None] 37 | loss = (loss * target_weight).mean() 38 | else: 39 | loss = self.criterion(output, target) 40 | 41 | return loss * self.loss_weight 42 | -------------------------------------------------------------------------------- /landmark/mmpose/models/misc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /landmark/mmpose/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .fpn import FPN 3 | from .gap_neck import GlobalAveragePooling 4 | from .posewarper_neck import PoseWarperNeck 5 | from .tcformer_mta_neck import MTA 6 | 7 | __all__ = ['GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'MTA'] 8 | -------------------------------------------------------------------------------- /landmark/mmpose/models/necks/gap_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..builder import NECKS 6 | 7 | 8 | @NECKS.register_module() 9 | class GlobalAveragePooling(nn.Module): 10 | """Global Average Pooling neck. 11 | 12 | Note that we use `view` to remove extra channel after pooling. We do not 13 | use `squeeze` as it will also remove the batch dimension when the tensor 14 | has a batch dimension of size 1, which can lead to unexpected errors. 15 | """ 16 | 17 | def __init__(self): 18 | super().__init__() 19 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) 20 | 21 | def init_weights(self): 22 | pass 23 | 24 | def forward(self, inputs): 25 | if isinstance(inputs, tuple): 26 | outs = tuple([self.gap(x) for x in inputs]) 27 | outs = tuple( 28 | [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) 29 | elif isinstance(inputs, list): 30 | outs = [self.gap(x) for x in inputs] 31 | outs = [out.view(x.size(0), -1) for out, x in zip(outs, inputs)] 32 | elif isinstance(inputs, torch.Tensor): 33 | outs = self.gap(inputs) 34 | outs = outs.view(inputs.size(0), -1) 35 | else: 36 | raise TypeError( 37 | 'neck inputs should be tuple, list or torch.Tensor') 38 | return outs 39 | -------------------------------------------------------------------------------- /landmark/mmpose/models/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | from .builder import BACKBONES, HEADS, LOSSES, NECKS, POSENETS 5 | 6 | __all__ = ['BACKBONES', 'HEADS', 'LOSSES', 'NECKS', 'POSENETS'] 7 | 8 | warnings.simplefilter('once', DeprecationWarning) 9 | warnings.warn( 10 | 'Registries (BACKBONES, NECKS, HEADS, LOSSES, POSENETS) have ' 11 | 'been moved to mmpose.models.builder. Importing from ' 12 | 'mmpose.models.registry will be deprecated in the future.', 13 | DeprecationWarning) 14 | -------------------------------------------------------------------------------- /landmark/mmpose/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
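# Illustrative note (editorial sketch, not part of the original file): the
# helpers re-exported below are shared across the models; for example the
# version-safe meshgrid defined in misc.py:
#   ys, xs = torch_meshgrid_ij(torch.arange(4), torch.arange(3))
#   # ys.shape == xs.shape == (4, 3) on all supported torch versions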
2 | from .ckpt_convert import pvt_convert, tcformer_convert 3 | from .geometry import batch_rodrigues, quat_to_rotmat, rot6d_to_rotmat 4 | from .misc import torch_meshgrid_ij 5 | from .ops import resize 6 | from .realnvp import RealNVP 7 | from .rescore import DekrRescoreNet 8 | from .smpl import SMPL 9 | from .tcformer_utils import (TCFormerDynamicBlock, TCFormerRegularBlock, 10 | TokenConv, cluster_dpc_knn, merge_tokens, 11 | token2map, token_interp) 12 | from .transformer import PatchEmbed, PatchMerging, nchw_to_nlc, nlc_to_nchw 13 | 14 | __all__ = [ 15 | 'SMPL', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'pvt_convert', 16 | 'PatchMerging', 'batch_rodrigues', 'quat_to_rotmat', 'rot6d_to_rotmat', 17 | 'resize', 'RealNVP', 'torch_meshgrid_ij', 'token2map', 'TokenConv', 18 | 'TCFormerRegularBlock', 'TCFormerDynamicBlock', 'cluster_dpc_knn', 19 | 'merge_tokens', 'token_interp', 'tcformer_convert', 'DekrRescoreNet' 20 | ] 21 | -------------------------------------------------------------------------------- /landmark/mmpose/models/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from packaging import version 4 | 5 | _torch_version_meshgrid_indexing = version.parse( 6 | torch.__version__) >= version.parse('1.10.0a0') 7 | 8 | 9 | def torch_meshgrid_ij(*tensors): 10 | if _torch_version_meshgrid_indexing: 11 | return torch.meshgrid(*tensors, indexing='ij') 12 | else: 13 | return torch.meshgrid(*tensors) # torch < 1.10 behaves like indexing='ij' 14 | -------------------------------------------------------------------------------- /landmark/mmpose/models/utils/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > input_w: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would be more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | if isinstance(size, torch.Size): 28 | size = tuple(int(x) for x in size) 29 | return F.interpolate(input, size, scale_factor, mode, align_corners) 30 | -------------------------------------------------------------------------------- /landmark/mmpose/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .collect_env import collect_env 3 | from .logger import get_root_logger 4 | from .setup_env import setup_multi_processes 5 | from .timer import StopWatch 6 | 7 | __all__ = [ 8 | 'get_root_logger', 'collect_env', 'StopWatch', 'setup_multi_processes' 9 | ] 10 | -------------------------------------------------------------------------------- /landmark/mmpose/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab.
All rights reserved. 2 | from mmcv.utils import collect_env as collect_basic_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmpose 6 | 7 | 8 | def collect_env(): 9 | env_info = collect_basic_env() 10 | env_info['MMPose'] = (mmpose.__version__ + '+' + get_git_hash(digits=7)) 11 | return env_info 12 | 13 | 14 | if __name__ == '__main__': 15 | for name, val in collect_env().items(): 16 | print(f'{name}: {val}') 17 | -------------------------------------------------------------------------------- /landmark/mmpose/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Use `get_logger` method in mmcv to get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmpose". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 24 | """ 25 | return get_logger(__name__.split('.')[0], log_file, log_level) 26 | -------------------------------------------------------------------------------- /landmark/mmpose/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
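# Illustrative behaviour of parse_version_info() below (editorial sketch, not
# part of the original file):
#   parse_version_info('0.29.0')  ->  (0, 29, 0)
#   parse_version_info('1.0rc1')  ->  (1, 0, 'rc1')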
2 | 3 | __version__ = '0.29.0' 4 | short_version = __version__ 5 | 6 | 7 | def parse_version_info(version_str): 8 | version_info = [] 9 | for x in version_str.split('.'): 10 | if x.isdigit(): 11 | version_info.append(int(x)) 12 | elif x.find('rc') != -1: 13 | patch_version = x.split('rc') 14 | version_info.append(int(patch_version[0])) 15 | version_info.append(f'rc{patch_version[1]}') 16 | return tuple(version_info) 17 | 18 | 19 | version_info = parse_version_info(__version__) 20 | -------------------------------------------------------------------------------- /landmark/requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/build.txt 2 | -r requirements/runtime.txt 3 | -r requirements/tests.txt 4 | -r requirements/optional.txt 5 | -r requirements/poseval.txt 6 | -------------------------------------------------------------------------------- /landmark/requirements/albu.txt: -------------------------------------------------------------------------------- 1 | albumentations>=0.3.2 --no-binary qudida,albumentations 2 | -------------------------------------------------------------------------------- /landmark/requirements/build.txt: -------------------------------------------------------------------------------- 1 | # These must be installed before building mmpose 2 | numpy 3 | torch>=1.3 4 | -------------------------------------------------------------------------------- /landmark/requirements/docs.txt: -------------------------------------------------------------------------------- 1 | docutils==0.16.0 2 | markdown 3 | myst-parser 4 | -e git+https://github.com/gaotongxiao/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 5 | sphinx==4.0.2 6 | sphinx_copybutton 7 | sphinx_markdown_tables 8 | -------------------------------------------------------------------------------- /landmark/requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcv-full>=1.3.8 2 | mmdet>=2.14.0 3 | mmtrack>=0.6.0 4 | -------------------------------------------------------------------------------- /landmark/requirements/optional.txt: -------------------------------------------------------------------------------- 1 | onnx 2 | onnxruntime 3 | pyrender 4 | requests 5 | smplx>=0.1.28 6 | trimesh 7 | -------------------------------------------------------------------------------- /landmark/requirements/poseval.txt: -------------------------------------------------------------------------------- 1 | poseval@git+https://github.com/svenkreiss/poseval.git 2 | -------------------------------------------------------------------------------- /landmark/requirements/readthedocs.txt: -------------------------------------------------------------------------------- 1 | mmcv-full 2 | munkres 3 | poseval@git+https://github.com/svenkreiss/poseval.git 4 | regex 5 | scipy 6 | titlecase 7 | torch 8 | torchvision 9 | xtcocotools>=1.8 10 | -------------------------------------------------------------------------------- /landmark/requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | chumpy 2 | dataclasses; python_version == '3.6' 3 | json_tricks 4 | matplotlib 5 | munkres 6 | numpy 7 | opencv-python 8 | pillow 9 | scipy 10 | torchvision 11 | xtcocotools>=1.12 12 | -------------------------------------------------------------------------------- /landmark/requirements/tests.txt: -------------------------------------------------------------------------------- 1 | 
coverage 2 | flake8 3 | interrogate 4 | isort==4.3.21 5 | pytest 6 | pytest-runner 7 | smplx>=0.1.28 8 | xdoctest>=0.10.0 9 | yapf 10 | -------------------------------------------------------------------------------- /landmark/setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [tool:pytest] 8 | addopts=tests/ 9 | 10 | [yapf] 11 | based_on_style = pep8 12 | blank_line_before_nested_class_or_def = true 13 | split_before_expression_after_opening_paren = true 14 | split_penalty_import_names=0 15 | SPLIT_PENALTY_AFTER_OPENING_BRACKET=800 16 | 17 | [isort] 18 | line_length = 79 19 | multi_line_output = 0 20 | extra_standard_library = pkg_resources,setuptools 21 | known_first_party = mmpose 22 | known_third_party = PIL,cv2,h5py,json_tricks,matplotlib,mmcv,munkres,numpy,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,spacepy,titlecase,torch,torchvision,webcam_apis,xmltodict,xtcocotools 23 | no_lines_before = STDLIB,LOCALFOLDER 24 | default_section = THIRDPARTY 25 | -------------------------------------------------------------------------------- /landmark/tools/analysis/print_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | from mmcv import Config, DictAction 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description='Print the whole config') 9 | parser.add_argument('config', help='config file path') 10 | parser.add_argument( 11 | '--options', nargs='+', action=DictAction, help='arguments in dict') 12 | args = parser.parse_args() 13 | 14 | return args 15 | 16 | 17 | def main(): 18 | args = parse_args() 19 | 20 | cfg = Config.fromfile(args.config) 21 | if args.options is not None: 22 | cfg.merge_from_dict(args.options) 23 | print(f'Config:\n{cfg.pretty_text}') 24 | 25 | 26 | if __name__ == '__main__': 27 | main() 28 | -------------------------------------------------------------------------------- /landmark/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | 4 | CONFIG=$1 5 | CHECKPOINT=$2 6 | GPUS=$3 7 | NNODES=${NNODES:-1} 8 | NODE_RANK=${NODE_RANK:-0} 9 | PORT=${PORT:-29500} 10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 11 | 12 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 13 | python -m torch.distributed.launch \ 14 | --nnodes=$NNODES \ 15 | --node_rank=$NODE_RANK \ 16 | --master_addr=$MASTER_ADDR \ 17 | --nproc_per_node=$GPUS \ 18 | --master_port=$PORT \ 19 | $(dirname "$0")/test.py \ 20 | $CONFIG \ 21 | $CHECKPOINT \ 22 | --launcher pytorch \ 23 | ${@:4} 24 | -------------------------------------------------------------------------------- /landmark/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) OpenMMLab. All rights reserved. 
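# Usage sketch (derived from the positional arguments below; the config path
# is a placeholder, not part of the original file):
#   bash ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional args]
# e.g. single node, 8 GPUs:
#   bash ./tools/dist_train.sh configs/some_config.py 8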
3 | 4 | CONFIG=$1 5 | GPUS=$2 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-29500} 9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 12 | python -m torch.distributed.launch \ 13 | --nnodes=$NNODES \ 14 | --node_rank=$NODE_RANK \ 15 | --master_addr=$MASTER_ADDR \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/train.py \ 19 | $CONFIG \ 20 | --launcher pytorch ${@:3} 21 | -------------------------------------------------------------------------------- /landmark/tools/evaluate_wlfw.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CHECKPOINT=$2 3 | GPUS=$3 4 | 5 | for i in '' '_largepose' '_illumination' '_occlusion' '_blur' '_makeup' '_expression' 6 | do 7 | echo "Evaluate on $i subset" 8 | bash tools/dist_test.sh $CONFIG $CHECKPOINT $GPUS \ 9 | --eval NME \ 10 | --cfg-options data.test.ann_file=data/wflw/annotations/face_landmarks_wflw_test$i.json 11 | done 12 | 13 | 14 | -------------------------------------------------------------------------------- /landmark/tools/misc/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import subprocess 4 | from datetime import date 5 | 6 | import torch 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='Process a checkpoint to be published') 12 | parser.add_argument('in_file', help='input checkpoint filename') 13 | parser.add_argument('out_file', help='output checkpoint filename') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def process_checkpoint(in_file, out_file): 19 | checkpoint = torch.load(in_file, map_location='cpu') 20 | # remove optimizer for smaller file size 21 | if 'optimizer' in checkpoint: 22 | del checkpoint['optimizer'] 23 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 24 | # add the code here. 25 | torch.save(checkpoint, out_file) 26 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 27 | if out_file.endswith('.pth'): 28 | out_file_name = out_file[:-4] 29 | else: 30 | out_file_name = out_file 31 | 32 | date_now = date.today().strftime('%Y%m%d') 33 | final_file = out_file_name + f'-{sha[:8]}_{date_now}.pth' 34 | subprocess.Popen(['mv', out_file, final_file]) 35 | 36 | 37 | def main(): 38 | args = parse_args() 39 | process_checkpoint(args.in_file, args.out_file) 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /landmark/tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) OpenMMLab. All rights reserved. 
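# Usage sketch (derived from the positional arguments and env vars below; an
# editorial addition, not part of the original file):
#   GPUS=8 sh tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG} ${CHECKPOINT}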
3 | 4 | set -x 5 | 6 | PARTITION=$1 7 | JOB_NAME=$2 8 | CONFIG=$3 9 | CHECKPOINT=$4 10 | GPUS=${GPUS:-8} 11 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 12 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 13 | PY_ARGS=${@:5} 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | 16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 17 | srun -p ${PARTITION} \ 18 | --job-name=${JOB_NAME} \ 19 | --gres=gpu:${GPUS_PER_NODE} \ 20 | --ntasks=${GPUS} \ 21 | --ntasks-per-node=${GPUS_PER_NODE} \ 22 | --cpus-per-task=${CPUS_PER_TASK} \ 23 | --kill-on-bad-exit=1 \ 24 | ${SRUN_ARGS} \ 25 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 26 | -------------------------------------------------------------------------------- /landmark/tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | 4 | set -x 5 | 6 | PARTITION=$1 7 | JOB_NAME=$2 8 | CONFIG=$3 9 | WORK_DIR=$4 10 | GPUS=${GPUS:-8} 11 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 12 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | PY_ARGS=${@:5} 15 | 16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 17 | srun -p ${PARTITION} \ 18 | --job-name=${JOB_NAME} \ 19 | --gres=gpu:${GPUS_PER_NODE} \ 20 | --ntasks=${GPUS} \ 21 | --ntasks-per-node=${GPUS_PER_NODE} \ 22 | --cpus-per-task=${CPUS_PER_TASK} \ 23 | --kill-on-bad-exit=1 \ 24 | ${SRUN_ARGS} \ 25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 26 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/datasets/cityscapes_1024x1024.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | img_norm_cfg = dict( 3 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | crop_size = (1024, 1024) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations'), 8 | dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), 9 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 10 | dict(type='RandomFlip', prob=0.5), 11 | dict(type='PhotoMetricDistortion'), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(2048, 1024), 22 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 23 | flip=False, 24 | transforms=[ 25 | dict(type='Resize', keep_ratio=True), 26 | dict(type='RandomFlip'), 27 | dict(type='Normalize', **img_norm_cfg), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | train=dict(pipeline=train_pipeline), 34 | val=dict(pipeline=test_pipeline), 35 | test=dict(pipeline=test_pipeline)) 36 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/datasets/cityscapes_768x768.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | img_norm_cfg = dict( 3 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | crop_size = (768, 768) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations'), 8 | dict(type='Resize', img_scale=(2049, 
1025), ratio_range=(0.5, 2.0)), 9 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 10 | dict(type='RandomFlip', prob=0.5), 11 | dict(type='PhotoMetricDistortion'), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(2049, 1025), 22 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 23 | flip=False, 24 | transforms=[ 25 | dict(type='Resize', keep_ratio=True), 26 | dict(type='RandomFlip'), 27 | dict(type='Normalize', **img_norm_cfg), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | train=dict(pipeline=train_pipeline), 34 | val=dict(pipeline=test_pipeline), 35 | test=dict(pipeline=test_pipeline)) 36 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/datasets/cityscapes_769x769.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | img_norm_cfg = dict( 3 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | crop_size = (769, 769) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations'), 8 | dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), 9 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 10 | dict(type='RandomFlip', prob=0.5), 11 | dict(type='PhotoMetricDistortion'), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(2049, 1025), 22 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 23 | flip=False, 24 | transforms=[ 25 | dict(type='Resize', keep_ratio=True), 26 | dict(type='RandomFlip'), 27 | dict(type='Normalize', **img_norm_cfg), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | train=dict(pipeline=train_pipeline), 34 | val=dict(pipeline=test_pipeline), 35 | test=dict(pipeline=test_pipeline)) 36 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/datasets/cityscapes_832x832.py: -------------------------------------------------------------------------------- 1 | _base_ = './cityscapes.py' 2 | img_norm_cfg = dict( 3 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | crop_size = (832, 832) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations'), 8 | dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), 9 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 10 | dict(type='RandomFlip', prob=0.5), 11 | dict(type='PhotoMetricDistortion'), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(2048, 1024), 22 | # 
img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 23 | flip=False, 24 | transforms=[ 25 | dict(type='Resize', keep_ratio=True), 26 | dict(type='RandomFlip'), 27 | dict(type='Normalize', **img_norm_cfg), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | train=dict(pipeline=train_pipeline), 34 | val=dict(pipeline=test_pipeline), 35 | test=dict(pipeline=test_pipeline)) 36 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/datasets/pascal_voc12_aug.py: -------------------------------------------------------------------------------- 1 | _base_ = './pascal_voc12.py' 2 | # dataset settings 3 | data = dict( 4 | train=dict( 5 | ann_dir=['SegmentationClass', 'SegmentationClassAug'], 6 | split=[ 7 | 'ImageSets/Segmentation/train.txt', 8 | 'ImageSets/Segmentation/aug.txt' 9 | ])) 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | # dict(type='PaviLoggerHook') # for internal services 8 | ]) 9 | # yapf:enable 10 | dist_params = dict(backend='nccl') 11 | log_level = 'INFO' 12 | load_from = None 13 | resume_from = None 14 | workflow = [('train', 1)] 15 | cudnn_benchmark = True 16 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/ann_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='ANNHead', 19 | in_channels=[1024, 2048], 20 | in_index=[2, 3], 21 | channels=512, 22 | project_channels=256, 23 | query_scales=(1, ), 24 | key_pool_scales=(1, 3, 6, 8), 25 | dropout_ratio=0.1, 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/apcnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 
| norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='APCHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | pool_scales=(1, 2, 3, 6), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=dict(type='SyncBN', requires_grad=True), 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/ccnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='CCHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | recurrence=2, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/cgnet.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | backbone=dict( 6 | type='CGNet', 7 | norm_cfg=norm_cfg, 8 | in_channels=3, 9 | num_channels=(32, 64, 128), 10 | num_blocks=(3, 21), 11 | dilations=(2, 4), 12 | reductions=(8, 16)), 13 | decode_head=dict( 14 | type='FCNHead', 15 | in_channels=256, 16 | in_index=2, 17 | channels=256, 18 | num_convs=0, 19 | concat_input=False, 20 | dropout_ratio=0, 21 | num_classes=19, 22 | norm_cfg=norm_cfg, 23 | loss_decode=dict( 24 | type='CrossEntropyLoss', 25 | use_sigmoid=False, 26 | loss_weight=1.0, 27 | class_weight=[ 28 | 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352, 29 | 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905, 30 | 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587, 31 | 10.396974, 10.055647 32 | ])), 33 | # model training and testing settings 34 | train_cfg=dict(sampler=None), 35 | test_cfg=dict(mode='whole')) 36 | 
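Note: the ResNet-d8 model bases in this directory all follow one recipe: an EncoderDecoder with a dilated ResNetV1c backbone (strides (1, 2, 1, 1), dilations (1, 1, 2, 4)), a method-specific decode_head, and an FCNHead auxiliary head on stage 3 with loss weight 0.4. A leaf config composes one model base with a dataset, runtime and schedule base, then overrides only what differs. A minimal sketch of that composition (the file name and the VOC class count are illustrative; this exact config does not ship with the repo):

# hypothetical leaf config, e.g. configs/<method>/pspnet_r50-d8_40k_voc12aug.py
_base_ = [
    '../_base_/models/pspnet_r50-d8.py',
    '../_base_/datasets/pascal_voc12_aug.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
# The model bases default to 19 (Cityscapes) classes; Pascal VOC has 20
# foreground classes plus background, so both heads are overridden to 21.
model = dict(
    decode_head=dict(num_classes=21),
    auxiliary_head=dict(num_classes=21))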
-------------------------------------------------------------------------------- /segmentation/configs/_base_/models/danet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='DAHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | pam_channels=64, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/deeplabv3_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='ASPPHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | dilations=(1, 12, 24, 36), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/dmnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='DMHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | filter_sizes=(1, 
3, 5, 7), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=dict(type='SyncBN', requires_grad=True), 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/dnl_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='DNLHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | dropout_ratio=0.1, 23 | reduction=2, 24 | use_scale=True, 25 | mode='embedded_gaussian', 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/dpt_vit-b16.py: -------------------------------------------------------------------------------- 1 | norm_cfg = dict(type='SyncBN', requires_grad=True) 2 | model = dict( 3 | type='EncoderDecoder', 4 | pretrained='pretrain/vit-b16_p16_224-80ecf9dd.pth', # noqa 5 | backbone=dict( 6 | type='VisionTransformer', 7 | img_size=224, 8 | embed_dims=768, 9 | num_layers=12, 10 | num_heads=12, 11 | out_indices=(2, 5, 8, 11), 12 | final_norm=False, 13 | with_cls_token=True, 14 | output_cls_token=True), 15 | decode_head=dict( 16 | type='DPTHead', 17 | in_channels=(768, 768, 768, 768), 18 | channels=256, 19 | embed_dims=768, 20 | post_process_channels=[96, 192, 384, 768], 21 | num_classes=150, 22 | readout_type='project', 23 | input_transform='multiple_select', 24 | in_index=(0, 1, 2, 3), 25 | norm_cfg=norm_cfg, 26 | loss_decode=dict( 27 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 28 | auxiliary_head=None, 29 | # model training and testing settings 30 | train_cfg=dict(), 31 | test_cfg=dict(mode='whole')) # yapf: disable 32 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/emanet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', 
requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='EMAHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=256, 22 | ema_channels=512, 23 | num_bases=64, 24 | num_stages=3, 25 | momentum=0.1, 26 | dropout_ratio=0.1, 27 | num_classes=19, 28 | norm_cfg=norm_cfg, 29 | align_corners=False, 30 | loss_decode=dict( 31 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 32 | auxiliary_head=dict( 33 | type='FCNHead', 34 | in_channels=1024, 35 | in_index=2, 36 | channels=256, 37 | num_convs=1, 38 | concat_input=False, 39 | dropout_ratio=0.1, 40 | num_classes=19, 41 | norm_cfg=norm_cfg, 42 | align_corners=False, 43 | loss_decode=dict( 44 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 45 | # model training and testing settings 46 | train_cfg=dict(), 47 | test_cfg=dict(mode='whole')) 48 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/erfnet_fcn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='ERFNet', 8 | in_channels=3, 9 | enc_downsample_channels=(16, 64, 128), 10 | enc_stage_non_bottlenecks=(5, 8), 11 | enc_non_bottleneck_dilations=(2, 4, 8, 16), 12 | enc_non_bottleneck_channels=(64, 128), 13 | dec_upsample_channels=(64, 16), 14 | dec_stages_non_bottleneck=(2, 2), 15 | dec_non_bottleneck_channels=(64, 16), 16 | dropout_ratio=0.1, 17 | init_cfg=None), 18 | decode_head=dict( 19 | type='FCNHead', 20 | in_channels=16, 21 | channels=128, 22 | num_convs=1, 23 | concat_input=False, 24 | dropout_ratio=0.1, 25 | num_classes=19, 26 | norm_cfg=norm_cfg, 27 | align_corners=False, 28 | loss_decode=dict( 29 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 30 | # model training and testing settings 31 | train_cfg=dict(), 32 | test_cfg=dict(mode='whole')) 33 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/fcn_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='FCNHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | num_convs=2, 23 | concat_input=True, 24 | dropout_ratio=0.1, 25 | num_classes=19, 26 | norm_cfg=norm_cfg, 27 | align_corners=False, 28 | loss_decode=dict( 29 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 30 | auxiliary_head=dict( 31 | type='FCNHead', 32 | in_channels=1024, 33 | in_index=2, 34 | channels=256, 35 | num_convs=1, 36 | concat_input=False, 37 | dropout_ratio=0.1, 38 | num_classes=19, 39 | norm_cfg=norm_cfg, 40 | 
align_corners=False, 41 | loss_decode=dict( 42 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 43 | # model training and testing settings 44 | train_cfg=dict(), 45 | test_cfg=dict(mode='whole')) 46 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/fpn_r50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 1, 1), 12 | strides=(1, 2, 2, 2), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | neck=dict( 18 | type='FPN', 19 | in_channels=[256, 512, 1024, 2048], 20 | out_channels=256, 21 | num_outs=4), 22 | decode_head=dict( 23 | type='FPNHead', 24 | in_channels=[256, 256, 256, 256], 25 | in_index=[0, 1, 2, 3], 26 | feature_strides=[4, 8, 16, 32], 27 | channels=128, 28 | dropout_ratio=0.1, 29 | num_classes=19, 30 | norm_cfg=norm_cfg, 31 | align_corners=False, 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 34 | # model training and testing settings 35 | train_cfg=dict(), 36 | test_cfg=dict(mode='whole')) 37 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/gcnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='GCHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | ratio=1 / 4., 23 | pooling_type='att', 24 | fusion_types=('channel_add', ), 25 | dropout_ratio=0.1, 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/isanet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='ISAHead', 
19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | isa_channels=256, 23 | down_factor=(8, 8), 24 | dropout_ratio=0.1, 25 | num_classes=19, 26 | norm_cfg=norm_cfg, 27 | align_corners=False, 28 | loss_decode=dict( 29 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 30 | auxiliary_head=dict( 31 | type='FCNHead', 32 | in_channels=1024, 33 | in_index=2, 34 | channels=256, 35 | num_convs=1, 36 | concat_input=False, 37 | dropout_ratio=0.1, 38 | num_classes=19, 39 | norm_cfg=norm_cfg, 40 | align_corners=False, 41 | loss_decode=dict( 42 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 43 | # model training and testing settings 44 | train_cfg=dict(), 45 | test_cfg=dict(mode='whole')) 46 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/lraspp_m-v3-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | backbone=dict( 6 | type='MobileNetV3', 7 | arch='large', 8 | out_indices=(1, 3, 16), 9 | norm_cfg=norm_cfg), 10 | decode_head=dict( 11 | type='LRASPPHead', 12 | in_channels=(16, 24, 960), 13 | in_index=(0, 1, 2), 14 | channels=128, 15 | input_transform='multiple_select', 16 | dropout_ratio=0.1, 17 | num_classes=19, 18 | norm_cfg=norm_cfg, 19 | act_cfg=dict(type='ReLU'), 20 | align_corners=False, 21 | loss_decode=dict( 22 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 23 | # model training and testing settings 24 | train_cfg=dict(), 25 | test_cfg=dict(mode='whole')) 26 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/nonlocal_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='NLHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | dropout_ratio=0.1, 23 | reduction=2, 24 | use_scale=True, 25 | mode='embedded_gaussian', 26 | num_classes=19, 27 | norm_cfg=norm_cfg, 28 | align_corners=False, 29 | loss_decode=dict( 30 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 31 | auxiliary_head=dict( 32 | type='FCNHead', 33 | in_channels=1024, 34 | in_index=2, 35 | channels=256, 36 | num_convs=1, 37 | concat_input=False, 38 | dropout_ratio=0.1, 39 | num_classes=19, 40 | norm_cfg=norm_cfg, 41 | align_corners=False, 42 | loss_decode=dict( 43 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 44 | # model training and testing settings 45 | train_cfg=dict(), 46 | test_cfg=dict(mode='whole')) 47 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/pspnet_r50-d8.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | 
depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='PSPHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | pool_scales=(1, 2, 3, 6), 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/segformer_mit-b0.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='MixVisionTransformer', 8 | in_channels=3, 9 | embed_dims=32, 10 | num_stages=4, 11 | num_layers=[2, 2, 2, 2], 12 | num_heads=[1, 2, 5, 8], 13 | patch_sizes=[7, 3, 3, 3], 14 | sr_ratios=[8, 4, 2, 1], 15 | out_indices=(0, 1, 2, 3), 16 | mlp_ratio=4, 17 | qkv_bias=True, 18 | drop_rate=0.0, 19 | attn_drop_rate=0.0, 20 | drop_path_rate=0.1), 21 | decode_head=dict( 22 | type='SegformerHead', 23 | in_channels=[32, 64, 160, 256], 24 | in_index=[0, 1, 2, 3], 25 | channels=256, 26 | dropout_ratio=0.1, 27 | num_classes=19, 28 | norm_cfg=norm_cfg, 29 | align_corners=False, 30 | loss_decode=dict( 31 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 32 | # model training and testing settings 33 | train_cfg=dict(), 34 | test_cfg=dict(mode='whole')) 35 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/segmenter_vit-b16_mask.py: -------------------------------------------------------------------------------- 1 | checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_base_p16_384_20220308-96dfe169.pth' # noqa 2 | # model settings 3 | backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) 4 | model = dict( 5 | type='EncoderDecoder', 6 | pretrained=checkpoint, 7 | backbone=dict( 8 | type='VisionTransformer', 9 | img_size=(512, 512), 10 | patch_size=16, 11 | in_channels=3, 12 | embed_dims=768, 13 | num_layers=12, 14 | num_heads=12, 15 | drop_path_rate=0.1, 16 | attn_drop_rate=0.0, 17 | drop_rate=0.0, 18 | final_norm=True, 19 | norm_cfg=backbone_norm_cfg, 20 | with_cls_token=True, 21 | interpolate_mode='bicubic', 22 | ), 23 | decode_head=dict( 24 | type='SegmenterMaskTransformerHead', 25 | in_channels=768, 26 | channels=768, 27 | num_classes=150, 28 | num_layers=2, 29 | num_heads=12, 30 | embed_dims=768, 31 | dropout_ratio=0.0, 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 34 | ), 35 | test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(480, 480)), 36 | ) 37 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/upernet_r50.py: 
-------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 1, 1), 12 | strides=(1, 2, 2, 2), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='UPerHead', 19 | in_channels=[256, 512, 1024, 2048], 20 | in_index=[0, 1, 2, 3], 21 | pool_scales=(1, 2, 3, 6), 22 | channels=512, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | loss_decode=dict( 28 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | auxiliary_head=dict( 30 | type='FCNHead', 31 | in_channels=1024, 32 | in_index=2, 33 | channels=256, 34 | num_convs=1, 35 | concat_input=False, 36 | dropout_ratio=0.1, 37 | num_classes=19, 38 | norm_cfg=norm_cfg, 39 | align_corners=False, 40 | loss_decode=dict( 41 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 42 | # model training and testing settings 43 | train_cfg=dict(), 44 | test_cfg=dict(mode='whole')) 45 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_160k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=16000) 9 | evaluation = dict(interval=16000, metric='mIoU', pre_eval=True) 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_20k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=20000) 8 | checkpoint_config = dict(by_epoch=False, interval=2000) 9 | evaluation = dict(interval=2000, metric='mIoU', pre_eval=True) 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_320k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=320000) 8 | checkpoint_config = dict(by_epoch=False, interval=32000) 9 | evaluation = dict(interval=32000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_40k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, 
weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU', pre_eval=True) 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_80k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=8000) 9 | evaluation = dict(interval=8000, metric='mIoU', pre_eval=True) 10 | -------------------------------------------------------------------------------- /segmentation/configs/celebahq_mask/bisenetv1_r18-d32_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_baseline.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/bisenetv1_r18-d32.py', 3 | '../_base_/datasets/celebahqmask.py', '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_160k.py' 5 | ] 6 | data = dict(samples_per_gpu=8) 7 | norm_cfg = dict(type='SyncBN', requires_grad=True) 8 | model = dict( 9 | decode_head=dict(num_classes=18), 10 | auxiliary_head=[ 11 | dict( 12 | type='FCNHead', 13 | in_channels=128, 14 | channels=64, 15 | num_convs=1, 16 | num_classes=18, 17 | in_index=1, 18 | norm_cfg=norm_cfg, 19 | concat_input=False, 20 | align_corners=False, 21 | loss_decode=dict( 22 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 23 | dict( 24 | type='FCNHead', 25 | in_channels=128, 26 | channels=64, 27 | num_convs=1, 28 | num_classes=18, 29 | in_index=2, 30 | norm_cfg=norm_cfg, 31 | concat_input=False, 32 | align_corners=False, 33 | loss_decode=dict( 34 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 35 | ]) 36 | lr_config = dict(warmup='linear', warmup_iters=1000) 37 | optimizer = dict(lr=0.005) 38 | -------------------------------------------------------------------------------- /segmentation/configs/celebahq_mask/bisenetv1_r18-d32_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_repfusion.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | _base_ = './bisenetv1_r18-d32_lr5e-3_1x16_448x448_160k_coco-celebahq_mask.py' 4 | model = dict( 5 | backbone=dict( 6 | backbone_cfg=dict( 7 | init_cfg=dict( 8 | type='Pretrained', 9 | checkpoint='',  # Put the distilled checkpoint here 10 | prefix='student.backbone.') 11 | ) 12 | ), 13 | ) 14 | -------------------------------------------------------------------------------- /segmentation/configs/celebahq_mask/bisenetv1_r50-d32_lr5e-3_2x8_448x448_160k_coco-celebahq_mask_repfussion.py: -------------------------------------------------------------------------------- 1 | 2 | _base_ = './bisenetv1_r50-d32_lr5e-3_4x4_448x448_160k_coco-celebahq_mask.py' 3 | model = dict( 4 | backbone=dict( 5 | backbone_cfg=dict( 6 | init_cfg=dict( 7 | type='Pretrained', 8 | checkpoint='',  # Put the distilled checkpoint here 9 | prefix='student.backbone.') 10 | ))) 11 | 12 |
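The two repfusion variants above change nothing but the backbone initialisation: init_cfg loads the checkpoint produced by the classification_distill stage, and prefix='student.backbone.' keeps only the student backbone weights from the distillation wrapper, stripping that prefix as they are loaded. A sketch of the filled-in field (the checkpoint path is hypothetical; use whatever path the distillation run actually saved):

# Hypothetical override -- the empty checkpoint field in the configs above
# is meant to be replaced with a real distilled checkpoint, for example:
model = dict(
    backbone=dict(
        backbone_cfg=dict(
            init_cfg=dict(
                type='Pretrained',
                checkpoint='work_dirs/repfusion_distill/latest.pth',  # hypothetical path
                prefix='student.backbone.'))))

Training is then launched through the usual mmseg entry points (assuming the standard launcher), e.g. bash tools/dist_train.sh <config> <num_gpus> from the segmentation directory.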
-------------------------------------------------------------------------------- /segmentation/mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot 3 | from .test import multi_gpu_test, single_gpu_test 4 | from .train import (get_root_logger, init_random_seed, set_random_seed, 5 | train_segmentor) 6 | 7 | __all__ = [ 8 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 9 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 10 | 'show_result_pyplot', 'init_random_seed' 11 | ] 12 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import (OPTIMIZER_BUILDERS, build_optimizer, 3 | build_optimizer_constructor) 4 | from .evaluation import * # noqa: F401, F403 5 | from .hook import * # noqa: F401, F403 6 | from .optimizers import * # noqa: F401, F403 7 | from .seg import * # noqa: F401, F403 8 | from .utils import * # noqa: F401, F403 9 | 10 | __all__ = [ 11 | 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor' 12 | ] 13 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | 4 | from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS 5 | from mmcv.utils import Registry, build_from_cfg 6 | 7 | OPTIMIZER_BUILDERS = Registry( 8 | 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) 9 | 10 | 11 | def build_optimizer_constructor(cfg): 12 | constructor_type = cfg.get('type') 13 | if constructor_type in OPTIMIZER_BUILDERS: 14 | return build_from_cfg(cfg, OPTIMIZER_BUILDERS) 15 | elif constructor_type in MMCV_OPTIMIZER_BUILDERS: 16 | return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) 17 | else: 18 | raise KeyError(f'{constructor_type} is not registered ' 19 | 'in the optimizer builder registry.') 20 | 21 | 22 | def build_optimizer(model, cfg): 23 | optimizer_cfg = copy.deepcopy(cfg) 24 | constructor_type = optimizer_cfg.pop('constructor', 25 | 'DefaultOptimizerConstructor') 26 | paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) 27 | optim_constructor = build_optimizer_constructor( 28 | dict( 29 | type=constructor_type, 30 | optimizer_cfg=optimizer_cfg, 31 | paramwise_cfg=paramwise_cfg)) 32 | optimizer = optim_constructor(model) 33 | return optimizer 34 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .class_names import get_classes, get_palette 3 | from .eval_hooks import DistEvalHook, EvalHook 4 | from .metrics import (eval_metrics, intersect_and_union, mean_dice, 5 | mean_fscore, mean_iou, pre_eval_to_metrics) 6 | 7 | __all__ = [ 8 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', 9 | 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics', 10 | 'intersect_and_union' 11 | ] 12 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/hook/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .wandblogger_hook import MMSegWandbHook 3 | 4 | __all__ = ['MMSegWandbHook'] 5 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .layer_decay_optimizer_constructor import ( 3 | LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) 4 | 5 | __all__ = [ 6 | 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' 7 | ] 8 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_pixel_sampler 3 | from .sampler import BasePixelSampler, OHEMPixelSampler 4 | 5 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | PIXEL_SAMPLERS = Registry('pixel sampler') 5 | 6 | 7 | def build_pixel_sampler(cfg, **default_args): 8 | """Build pixel sampler for segmentation map.""" 9 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 10 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_pixel_sampler import BasePixelSampler 3 | from .ohem_pixel_sampler import OHEMPixelSampler 4 | 5 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BasePixelSampler(metaclass=ABCMeta): 6 | """Base class of pixel sampler.""" 7 | 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | @abstractmethod 12 | def sample(self, seg_logit, seg_label): 13 | """Placeholder for sample function.""" 14 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .dist_util import check_dist_init, sync_random_seed 3 | from .misc import add_prefix 4 | 5 | __all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed'] 6 | -------------------------------------------------------------------------------- /segmentation/mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def add_prefix(inputs, prefix): 3 | """Add prefix for dict. 4 | 5 | Args: 6 | inputs (dict): The input dict with str keys. 7 | prefix (str): The prefix to add. 8 | 9 | Returns: 10 | 11 | dict: The dict with keys updated with ``prefix``. 12 | """ 13 | 14 | outputs = dict() 15 | for name, value in inputs.items(): 16 | outputs[f'{prefix}.{name}'] = value 17 | 18 | return outputs 19 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/celebahqmask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class CelebAHQMask(CustomDataset): 10 | """CelebA-HQ Mask dataset. 11 | 12 | Args: 13 | split (str): Split txt file for CelebA-HQ Mask. 14 | """ 15 | 16 | CLASSES = ('skin', 'l_brow', 'r_brow', 17 | 'l_eye', 'r_eye', 'eye_g', 18 | 'l_ear', 'r_ear', 'ear_r', 19 | 'nose', 'mouth', 'u_lip', 20 | 'l_lip', 'neck', 'neck_l', 21 | 'cloth', 'hair', 'hat') 22 | 23 | PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], 24 | [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], 25 | [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], 26 | [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], 27 | [0, 32, 192], [128, 128, 224]] 28 | 29 | def __init__(self, split, **kwargs): 30 | super(CelebAHQMask, self).__init__( 31 | img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) 32 | assert osp.exists(self.img_dir) and self.split is not None 33 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class ChaseDB1Dataset(CustomDataset): 9 | """Chase_db1 dataset. 10 | 11 | In segmentation map annotation for Chase_db1, 0 stands for background, 12 | which is included in 2 categories. ``reduce_zero_label`` is fixed to False. 13 | The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_1stHO.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(ChaseDB1Dataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='_1stHO.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert self.file_client.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/dark_zurich.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import DATASETS 3 | from .cityscapes import CityscapesDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class DarkZurichDataset(CityscapesDataset): 8 | """DarkZurichDataset dataset.""" 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__( 12 | img_suffix='_rgb_anon.png', 13 | seg_map_suffix='_gt_labelTrainIds.png', 14 | **kwargs) 15 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class DRIVEDataset(CustomDataset): 9 | """DRIVE dataset. 10 | 11 | In segmentation map annotation for DRIVE, 0 stands for background, which is 12 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_manual1.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(DRIVEDataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='_manual1.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert self.file_client.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/face.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class FaceOccludedDataset(CustomDataset): 10 | """Face Occluded dataset. 11 | 12 | Args: 13 | split (str): Split txt file for the Face Occluded dataset. 14 | """ 15 | 16 | CLASSES = ('background', 'face') 17 | 18 | PALETTE = [[0, 0, 0], [128, 0, 0]] 19 | 20 | def __init__(self, split, **kwargs): 21 | super(FaceOccludedDataset, self).__init__( 22 | img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) 23 | assert osp.exists(self.img_dir) and self.split is not None 24 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/hrf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class HRFDataset(CustomDataset): 9 | """HRF dataset. 10 | 11 | In segmentation map annotation for HRF, 0 stands for background, which is 12 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(HRFDataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert self.file_client.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/isprs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import DATASETS 3 | from .custom import CustomDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class ISPRSDataset(CustomDataset): 8 | """ISPRS dataset. 9 | 10 | In segmentation map annotation for ISPRS, 0 is the ignore index. 11 | ``reduce_zero_label`` should be set to True. The ``img_suffix`` and 12 | ``seg_map_suffix`` are both fixed to '.png'. 13 | """ 14 | CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', 15 | 'car', 'clutter') 16 | 17 | PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], 18 | [255, 255, 0], [255, 0, 0]] 19 | 20 | def __init__(self, **kwargs): 21 | super(ISPRSDataset, self).__init__( 22 | img_suffix='.png', 23 | seg_map_suffix='.png', 24 | reduce_zero_label=True, 25 | **kwargs) 26 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/night_driving.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import DATASETS 3 | from .cityscapes import CityscapesDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class NightDrivingDataset(CityscapesDataset): 8 | """NightDrivingDataset dataset.""" 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__( 12 | img_suffix='_leftImg8bit.png', 13 | seg_map_suffix='_gtCoarse_labelTrainIds.png', 14 | **kwargs) 15 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .compose import Compose 3 | from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor, 4 | Transpose, to_tensor) 5 | from .loading import LoadAnnotations, LoadImageFromFile 6 | from .test_time_aug import MultiScaleFlipAug 7 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, 8 | PhotoMetricDistortion, RandomCrop, RandomCutOut, 9 | RandomFlip, RandomMosaic, RandomRotate, Rerange, 10 | Resize, RGB2Gray, SegRescale) 11 | 12 | __all__ = [ 13 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 14 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 15 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 16 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 17 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', 18 | 'RandomMosaic' 19 | ] 20 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/pipelines/formating.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # flake8: noqa 3 | import warnings 4 | 5 | from .formatting import * 6 | 7 | warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be ' 8 | 'deprecated in 2021, please replace it with ' 9 | 'mmseg.datasets.pipelines.formatting.') 10 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/potsdam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import DATASETS 3 | from .custom import CustomDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class PotsdamDataset(CustomDataset): 8 | """ISPRS Potsdam dataset.
9 | 10 | In segmentation map annotation for Potsdam dataset, 0 is the ignore index. 11 | ``reduce_zero_label`` should be set to True. The ``img_suffix`` and 12 | ``seg_map_suffix`` are both fixed to '.png'. 13 | """ 14 | CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', 15 | 'car', 'clutter') 16 | 17 | PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], 18 | [255, 255, 0], [255, 0, 0]] 19 | 20 | def __init__(self, **kwargs): 21 | super(PotsdamDataset, self).__init__( 22 | img_suffix='.png', 23 | seg_map_suffix='.png', 24 | reduce_zero_label=True, 25 | **kwargs) 26 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/stare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class STAREDataset(CustomDataset): 10 | """STARE dataset. 11 | 12 | In segmentation map annotation for STARE, 0 stands for background, which is 13 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 14 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 15 | '.ah.png'. 16 | """ 17 | 18 | CLASSES = ('background', 'vessel') 19 | 20 | PALETTE = [[120, 120, 120], [6, 230, 230]] 21 | 22 | def __init__(self, **kwargs): 23 | super(STAREDataset, self).__init__( 24 | img_suffix='.png', 25 | seg_map_suffix='.ah.png', 26 | reduce_zero_label=False, 27 | **kwargs) 28 | assert osp.exists(self.img_dir) 29 | -------------------------------------------------------------------------------- /segmentation/mmseg/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class PascalVOCDataset(CustomDataset): 10 | """Pascal VOC dataset. 11 | 12 | Args: 13 | split (str): Split txt file for Pascal VOC. 
14 | """ 15 | 16 | CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 17 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 18 | 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 19 | 'train', 'tvmonitor') 20 | 21 | PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], 22 | [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], 23 | [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], 24 | [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], 25 | [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] 26 | 27 | def __init__(self, split, **kwargs): 28 | super(PascalVOCDataset, self).__init__( 29 | img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) 30 | assert osp.exists(self.img_dir) and self.split is not None 31 | -------------------------------------------------------------------------------- /segmentation/mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, 4 | build_head, build_loss, build_segmentor) 5 | from .decode_heads import * # noqa: F401,F403 6 | from .losses import * # noqa: F401,F403 7 | from .necks import * # noqa: F401,F403 8 | from .segmentors import * # noqa: F401,F403 9 | 10 | __all__ = [ 11 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', 12 | 'build_head', 'build_loss', 'build_segmentor' 13 | ] 14 | -------------------------------------------------------------------------------- /segmentation/mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .beit import BEiT 3 | from .bisenetv1 import BiSeNetV1 4 | from .bisenetv2 import BiSeNetV2 5 | from .cgnet import CGNet 6 | from .erfnet import ERFNet 7 | from .fast_scnn import FastSCNN 8 | from .hrnet import HRNet 9 | from .icnet import ICNet 10 | from .mae import MAE 11 | from .mit import MixVisionTransformer 12 | from .mobilenet_v2 import MobileNetV2 13 | from .mobilenet_v3 import MobileNetV3 14 | from .resnest import ResNeSt 15 | from .resnet import ResNet, ResNetV1c, ResNetV1d 16 | from .resnext import ResNeXt 17 | from .stdc import STDCContextPathNet, STDCNet 18 | from .swin import SwinTransformer 19 | from .timm_backbone import TIMMBackbone 20 | from .twins import PCPVT, SVT 21 | from .unet import UNet 22 | from .vit import VisionTransformer 23 | 24 | __all__ = [ 25 | 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 26 | 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 27 | 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 28 | 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', 29 | 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE' 30 | ] 31 | -------------------------------------------------------------------------------- /segmentation/mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings 3 | 4 | from mmcv.cnn import MODELS as MMCV_MODELS 5 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 6 | from mmcv.utils import Registry 7 | 8 | MODELS = Registry('models', parent=MMCV_MODELS) 9 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 10 | 11 | BACKBONES = MODELS 12 | NECKS = MODELS 13 | HEADS = MODELS 14 | LOSSES = MODELS 15 | SEGMENTORS = MODELS 16 | 17 | 18 | def build_backbone(cfg): 19 | """Build backbone.""" 20 | return BACKBONES.build(cfg) 21 | 22 | 23 | def build_neck(cfg): 24 | """Build neck.""" 25 | return NECKS.build(cfg) 26 | 27 | 28 | def build_head(cfg): 29 | """Build head.""" 30 | return HEADS.build(cfg) 31 | 32 | 33 | def build_loss(cfg): 34 | """Build loss.""" 35 | return LOSSES.build(cfg) 36 | 37 | 38 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 39 | """Build segmentor.""" 40 | if train_cfg is not None or test_cfg is not None: 41 | warnings.warn( 42 | 'train_cfg and test_cfg are deprecated, ' 43 | 'please specify them in model', UserWarning) 44 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 45 | 'train_cfg specified in both outer field and model field ' 46 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 47 | 'test_cfg specified in both outer field and model field ' 48 | return SEGMENTORS.build( 49 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 50 | -------------------------------------------------------------------------------- /segmentation/mmseg/models/decode_heads/cc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | try: 8 | from mmcv.ops import CrissCrossAttention 9 | except ModuleNotFoundError: 10 | CrissCrossAttention = None 11 | 12 | 13 | @HEADS.register_module() 14 | class CCHead(FCNHead): 15 | """CCNet: Criss-Cross Attention for Semantic Segmentation. 16 | 17 | This head is the implementation of `CCNet 18 | <https://arxiv.org/abs/1811.11721>`_. 19 | 20 | Args: 21 | recurrence (int): Number of recurrence of Criss Cross Attention 22 | module. Default: 2. 23 | """ 24 | 25 | def __init__(self, recurrence=2, **kwargs): 26 | if CrissCrossAttention is None: 27 | raise RuntimeError('Please install mmcv-full for ' 28 | 'CrissCrossAttention ops') 29 | super(CCHead, self).__init__(num_convs=2, **kwargs) 30 | self.recurrence = recurrence 31 | self.cca = CrissCrossAttention(self.channels) 32 | 33 | def forward(self, inputs): 34 | """Forward function.""" 35 | x = self._transform_inputs(inputs) 36 | output = self.convs[0](x) 37 | for _ in range(self.recurrence): 38 | output = self.cca(output) 39 | output = self.convs[1](output) 40 | if self.concat_input: 41 | output = self.conv_cat(torch.cat([x, output], dim=1)) 42 | output = self.cls_seg(output) 43 | return output 44 | -------------------------------------------------------------------------------- /segmentation/mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
--------------------------------------------------------------------------------
/segmentation/mmseg/models/losses/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .accuracy import Accuracy, accuracy
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                 cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
from .lovasz_loss import LovaszLoss
from .tversky_loss import TverskyLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss

__all__ = [
    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
    'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
    'FocalLoss', 'TverskyLoss'
]
--------------------------------------------------------------------------------
/segmentation/mmseg/models/necks/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .featurepyramid import Feature2Pyramid
from .fpn import FPN
from .ic_neck import ICNeck
from .jpu import JPU
from .mla_neck import MLANeck
from .multilevel_neck import MultiLevelNeck

__all__ = [
    'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid'
]
--------------------------------------------------------------------------------
/segmentation/mmseg/models/segmentors/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseSegmentor
from .cascade_encoder_decoder import CascadeEncoderDecoder
from .encoder_decoder import EncoderDecoder

__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
--------------------------------------------------------------------------------
/segmentation/mmseg/models/utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .embed import PatchEmbed
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .res_layer import ResLayer
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
                            nlc_to_nchw)
from .up_conv_block import UpConvBlock

__all__ = [
    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
    'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc'
]
--------------------------------------------------------------------------------
/segmentation/mmseg/models/utils/make_divisible.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Make divisible function.

    This function rounds the channel number to the nearest value that is
    divisible by the divisor. It is taken from the original tf repo and
    ensures that all layers have a channel number divisible by the divisor.
    It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int): The minimum value of the output channel.
            Default: None, which means the minimum value equals the divisor.
        min_ratio (float): The minimum ratio of the rounded channel number to
            the original channel number. Default: 0.9.

    Returns:
        int: The modified output channel number.
    """

    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than
    # (1 - min_ratio).
    if new_value < min_ratio * value:
        new_value += divisor
    return new_value
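A quick illustration of the rounding behaviour (the values are chosen for the example, not taken from any config): channel counts snap to the nearest multiple of the divisor, but never shrink below min_ratio times the original value.

# Worked examples for make_divisible; each assert holds for the code above.
assert make_divisible(32, 8) == 32   # already a multiple of 8
assert make_divisible(37, 8) == 40   # rounds to the nearest multiple
assert make_divisible(10, 8) == 16   # 8 would fall below 0.9 * 10, so bump up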
--------------------------------------------------------------------------------
/segmentation/mmseg/ops/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .encoding import Encoding
from .wrappers import Upsample, resize

__all__ = ['Upsample', 'resize', 'Encoding']
--------------------------------------------------------------------------------
/segmentation/mmseg/utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_root_logger
from .misc import find_latest_checkpoint
from .set_env import setup_multi_processes
from .util_distribution import build_ddp, build_dp, get_device

__all__ = [
    'get_root_logger', 'collect_env', 'find_latest_checkpoint',
    'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device'
]
--------------------------------------------------------------------------------
/segmentation/mmseg/utils/collect_env.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import collect_env as collect_base_env
from mmcv.utils import get_git_hash

import mmseg


def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'

    return env_info


if __name__ == '__main__':
    for name, val in collect_env().items():
        print('{}: {}'.format(name, val))
--------------------------------------------------------------------------------
/segmentation/mmseg/utils/logger.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import logging

from mmcv.utils import get_logger


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default
    a StreamHandler will be added. If `log_file` is specified, a FileHandler
    will also be added. The name of the root logger is the top-level package
    name, e.g., "mmseg".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """

    logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)

    return logger
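A minimal sketch of typical usage, e.g. near the top of a training script; the log-file path is illustrative, not one this repo creates:

from mmseg.utils import get_root_logger

logger = get_root_logger(log_file='work_dirs/example/run.log')
logger.info('Environment collected, starting training.')  # rank 0 only logs at INFO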
24 | """ 25 | 26 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 27 | 28 | return logger 29 | -------------------------------------------------------------------------------- /segmentation/mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.30.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /segmentation/requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/optional.txt 2 | -r requirements/runtime.txt 3 | -r requirements/tests.txt 4 | -------------------------------------------------------------------------------- /segmentation/requirements/docs.txt: -------------------------------------------------------------------------------- 1 | docutils==0.16.0 2 | myst-parser 3 | -e git+https://github.com/gaotongxiao/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 4 | sphinx==4.0.2 5 | sphinx_copybutton 6 | sphinx_markdown_tables 7 | -------------------------------------------------------------------------------- /segmentation/requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcls>=0.20.1 2 | mmcv-full>=1.4.4,<1.7.0 3 | -------------------------------------------------------------------------------- /segmentation/requirements/optional.txt: -------------------------------------------------------------------------------- 1 | cityscapesscripts 2 | -------------------------------------------------------------------------------- /segmentation/requirements/readthedocs.txt: -------------------------------------------------------------------------------- 1 | mmcv 2 | prettytable 3 | torch 4 | torchvision 5 | -------------------------------------------------------------------------------- /segmentation/requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | mmcls>=0.20.1 3 | numpy 4 | packaging 5 | prettytable 6 | -------------------------------------------------------------------------------- /segmentation/requirements/tests.txt: -------------------------------------------------------------------------------- 1 | codecov 2 | flake8 3 | interrogate 4 | pytest 5 | xdoctest>=0.10.0 6 | yapf 7 | -------------------------------------------------------------------------------- /segmentation/setup.cfg: -------------------------------------------------------------------------------- 1 | [yapf] 2 | based_on_style = pep8 3 | blank_line_before_nested_class_or_def = true 4 | split_before_expression_after_opening_paren = true 5 | 6 | [isort] 7 | line_length = 79 8 | multi_line_output = 0 9 | extra_standard_library = setuptools 10 | known_first_party = mmseg 11 | known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,packaging,prettytable,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,torch,ts 12 | no_lines_before = STDLIB,LOCALFOLDER 13 | default_section = THIRDPARTY 14 | 15 | # 
--------------------------------------------------------------------------------
/segmentation/requirements.txt:
--------------------------------------------------------------------------------
-r requirements/optional.txt
-r requirements/runtime.txt
-r requirements/tests.txt
--------------------------------------------------------------------------------
/segmentation/requirements/docs.txt:
--------------------------------------------------------------------------------
docutils==0.16.0
myst-parser
-e git+https://github.com/gaotongxiao/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==4.0.2
sphinx_copybutton
sphinx_markdown_tables
--------------------------------------------------------------------------------
/segmentation/requirements/mminstall.txt:
--------------------------------------------------------------------------------
mmcls>=0.20.1
mmcv-full>=1.4.4,<1.7.0
--------------------------------------------------------------------------------
/segmentation/requirements/optional.txt:
--------------------------------------------------------------------------------
cityscapesscripts
--------------------------------------------------------------------------------
/segmentation/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
mmcv
prettytable
torch
torchvision
--------------------------------------------------------------------------------
/segmentation/requirements/runtime.txt:
--------------------------------------------------------------------------------
matplotlib
mmcls>=0.20.1
numpy
packaging
prettytable
--------------------------------------------------------------------------------
/segmentation/requirements/tests.txt:
--------------------------------------------------------------------------------
codecov
flake8
interrogate
pytest
xdoctest>=0.10.0
yapf
--------------------------------------------------------------------------------
/segmentation/setup.cfg:
--------------------------------------------------------------------------------
[yapf]
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true

[isort]
line_length = 79
multi_line_output = 0
extra_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,packaging,prettytable,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,torch,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

# ignore-words-list needs to be in lowercase format. For example, if we want
# to ignore the word "BA", then we need to append "ba" to ignore-words-list
# rather than "BA".
[codespell]
skip = *.po,*.ts,*.ipynb
count =
quiet-level = 3
ignore-words-list = formating,sur,hist,dota,ba,warmup
--------------------------------------------------------------------------------
/segmentation/tools/dist_test.sh:
--------------------------------------------------------------------------------
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/test.py \
    $CONFIG \
    $CHECKPOINT \
    --launcher pytorch \
    ${@:4}
--------------------------------------------------------------------------------
/segmentation/tools/dist_train.sh:
--------------------------------------------------------------------------------
CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --launcher pytorch ${@:3}
--------------------------------------------------------------------------------
/segmentation/tools/publish_model.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import subprocess

import torch


def parse_args():
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    args = parser.parse_args()
    return args


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    # str.rstrip('.pth') strips trailing '.', 'p', 't', 'h' characters rather
    # than the literal suffix, so slice the suffix off instead.
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    final_file = out_file_name + '-{}.pth'.format(sha[:8])
    subprocess.Popen(['mv', out_file, final_file])


def main():
    args = parse_args()
    process_checkpoint(args.in_file, args.out_file)


if __name__ == '__main__':
    main()
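To see why the original rstrip call in process_checkpoint needed replacing: rstrip takes a set of characters, not a suffix string, so any filename ending in the letters '.', 'p', 't', 'h' gets over-trimmed. A two-line demonstration (the filename is made up for the example):

# Pitfall demo: both asserts hold in plain Python.
assert 'depth.pth'.rstrip('.pth') == 'de'  # intended 'depth', chars stripped
assert 'depth.pth'[:-4] == 'depth'         # suffix-safe slice used above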
--------------------------------------------------------------------------------
/segmentation/tools/slurm_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------
/segmentation/tools/slurm_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-4}
GPUS_PER_NODE=${GPUS_PER_NODE:-4}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------