├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── common └── __init__.py ├── config ├── classification │ ├── finetune_higher_res_in1k │ │ └── mobilevit_v2.yaml │ ├── finetune_in21k_to_1k │ │ └── mobilevit_v2.yaml │ ├── imagenet │ │ ├── mobilenet_v1.yaml │ │ ├── mobilenet_v2.yaml │ │ ├── mobilenet_v3.yaml │ │ ├── mobilevit.yaml │ │ ├── mobilevit_v2.yaml │ │ ├── resnet.yaml │ │ ├── resnet_adv.yaml │ │ ├── resnet_eval │ │ │ ├── imagenet_a.yaml │ │ │ ├── imagenet_r.yaml │ │ │ └── imagenet_sketch.yaml │ │ └── vit.yaml │ └── imagenet_21k │ │ └── mobilevit_v2.yaml ├── detection │ ├── mask_rcnn_coco │ │ └── resnet_fpn.yaml │ └── ssd_coco │ │ ├── mobilevit.yaml │ │ ├── mobilevit_v2.yaml │ │ └── resnet.yaml ├── distillation │ └── teacher_resnet101_student_mobilenet_v1.yaml ├── multi_modal_img_text │ └── clip_vit.yaml └── segmentation │ ├── ade20k │ ├── deeplabv3_mobilenetv2.yaml │ ├── deeplabv3_mobilevitv2.yaml │ ├── deeplabv3_resnet50.yaml │ └── pspnet_mobilevitv2.yaml │ └── pascal_voc │ ├── deeplabv3_mobilevit.yaml │ ├── deeplabv3_mobilevitv2.yaml │ └── pspnet_mobilevitv2.yaml ├── conftest.py ├── constraints.txt ├── cvnets ├── __init__.py ├── anchor_generator │ ├── __init__.py │ ├── base_anchor_generator.py │ └── ssd_anchor_generator.py ├── image_projection_layers │ ├── __init__.py │ ├── attention_pool_2d.py │ ├── base_image_projection.py │ ├── global_pool_2d.py │ └── simple_projection_head.py ├── layers │ ├── __init__.py │ ├── activation │ │ ├── __init__.py │ │ ├── gelu.py │ │ ├── hard_sigmoid.py │ │ ├── hard_swish.py │ │ ├── leaky_relu.py │ │ ├── prelu.py │ │ ├── relu.py │ │ ├── relu6.py │ │ ├── sigmoid.py │ │ ├── swish.py │ │ └── tanh.py │ ├── adaptive_pool.py │ ├── base_layer.py │ ├── conv_layer.py │ ├── dropout.py │ ├── embedding.py │ ├── flatten.py │ ├── global_pool.py │ ├── identity.py │ ├── linear_attention.py │ ├── linear_layer.py │ ├── multi_head_attention.py │ ├── normalization │ │ ├── __init__.py │ │ ├── batch_norm.py │ │ ├── group_norm.py │ │ ├── instance_norm.py │ │ ├── layer_norm.py │ │ └── sync_batch_norm.py │ ├── normalization_layers.py │ ├── pixel_shuffle.py │ ├── pooling.py │ ├── positional_embedding.py │ ├── positional_encoding.py │ ├── random_layers.py │ ├── single_head_attention.py │ ├── softmax.py │ ├── stochastic_depth.py │ ├── token_merging.py │ └── upsample.py ├── matcher_det │ ├── __init__.py │ ├── base_matcher.py │ └── ssd_matcher.py ├── misc │ ├── __init__.py │ ├── averaging_utils.py │ ├── box_utils.py │ ├── common.py │ ├── init_utils.py │ └── third_party │ │ ├── __init__.py │ │ └── ssd_utils.py ├── models │ ├── __init__.py │ ├── audio_classification │ │ ├── __init__.py │ │ ├── audio_byteformer.py │ │ └── base_audio_classification.py │ ├── base_model.py │ ├── classification │ │ ├── __init__.py │ │ ├── base_image_encoder.py │ │ ├── byteformer.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ ├── byteformer.py │ │ │ ├── efficientnet.py │ │ │ ├── fastvit.py │ │ │ ├── mobilenetv1.py │ │ │ ├── mobilenetv2.py │ │ │ ├── mobilenetv3.py │ │ │ ├── mobileone.py │ │ │ ├── mobilevit.py │ │ │ ├── mobilevit_v2.py │ │ │ ├── regnet.py │ │ │ ├── resnet.py │ │ │ ├── swin_transformer.py │ │ │ └── vit.py │ │ ├── efficientnet.py │ │ ├── fastvit.py │ │ ├── mobilenetv1.py │ │ ├── mobilenetv2.py │ │ ├── mobilenetv3.py │ │ ├── mobileone.py │ │ ├── mobilevit.py │ │ ├── mobilevit_v2.py │ │ ├── regnet.py │ │ ├── resnet.py │ │ ├── swin_transformer.py │ │ └── vit.py │ ├── detection │ │ ├── __init__.py │ │ ├── base_detection.py │ │ 
├── mask_rcnn.py │ │ ├── ssd.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── rcnn_utils.py │ ├── multi_modal_img_text │ │ ├── __init__.py │ │ ├── base_multi_modal_img_text.py │ │ └── clip.py │ └── segmentation │ │ ├── __init__.py │ │ ├── base_seg.py │ │ ├── enc_dec.py │ │ └── heads │ │ ├── __init__.py │ │ ├── base_seg_head.py │ │ ├── deeplabv3.py │ │ ├── pspnet.py │ │ └── simple_seg_head.py ├── modules │ ├── __init__.py │ ├── aspp_block.py │ ├── base_module.py │ ├── efficientnet.py │ ├── fastvit.py │ ├── feature_pyramid.py │ ├── mobilenetv2.py │ ├── mobileone_block.py │ ├── mobilevit_block.py │ ├── pspnet_module.py │ ├── regnet_modules.py │ ├── resnet_modules.py │ ├── squeeze_excitation.py │ ├── ssd_heads.py │ ├── swin_transformer_block.py │ ├── transformer.py │ └── windowed_transformer.py ├── neural_augmentor │ ├── __init__.py │ ├── neural_aug.py │ └── utils │ │ ├── __init__.py │ │ └── neural_aug_utils.py └── text_encoders │ ├── __init__.py │ ├── base_text_encoder.py │ └── transformer.py ├── data ├── __init__.py ├── collate_fns │ ├── __init__.py │ ├── byteformer_collate_functions.py │ └── collate_functions.py ├── data_loaders.py ├── datasets │ ├── __init__.py │ ├── audio_classification │ │ ├── __init__.py │ │ └── speech_commands_v2.py │ ├── classification │ │ ├── __init__.py │ │ ├── base_image_classification_dataset.py │ │ ├── base_imagenet_shift_dataset.py │ │ ├── imagenet.py │ │ ├── imagenet_a.py │ │ ├── imagenet_r.py │ │ ├── imagenet_sketch.py │ │ ├── imagenet_synsets.py │ │ ├── imagenet_v2.py │ │ └── places365.py │ ├── dataset_base.py │ ├── detection │ │ ├── __init__.py │ │ ├── base_detection.py │ │ ├── coco_base.py │ │ ├── coco_mask_rcnn.py │ │ └── coco_ssd.py │ ├── multi_modal_img_text │ │ ├── __init__.py │ │ ├── base_multi_modal_img_text.py │ │ ├── flickr.py │ │ ├── img_text_tar_dataset.py │ │ └── zero_shot │ │ │ ├── __init__.py │ │ │ ├── base_zero_shot.py │ │ │ ├── imagenet.py │ │ │ ├── imagenet_class_names.py │ │ │ └── templates.py │ ├── segmentation │ │ ├── __init__.py │ │ ├── ade20k.py │ │ ├── base_segmentation.py │ │ ├── coco_segmentation.py │ │ └── pascal_voc.py │ └── utils │ │ ├── __init__.py │ │ ├── common.py │ │ ├── text.py │ │ └── video.py ├── loader │ ├── __init__.py │ └── dataloader.py ├── sampler │ ├── __init__.py │ ├── base_sampler.py │ ├── batch_sampler.py │ ├── chain_sampler.py │ ├── multi_scale_sampler.py │ ├── utils.py │ └── variable_batch_sampler.py ├── text_tokenizer │ ├── __init__.py │ ├── base_tokenizer.py │ └── clip_tokenizer.py ├── transforms │ ├── __init__.py │ ├── audio.py │ ├── audio_aux │ │ ├── __init__.py │ │ └── mfccs.py │ ├── audio_bytes.py │ ├── base_transforms.py │ ├── common.py │ ├── image_bytes.py │ ├── image_pil.py │ ├── image_torch.py │ ├── utils.py │ └── video.py └── video_reader │ ├── __init__.py │ ├── base_av_reader.py │ ├── decord_reader.py │ └── pyav_reader.py ├── docs ├── .gitignore ├── .nojekyll ├── Makefile ├── __init__.py ├── make.bat └── source │ ├── conf.py │ ├── data_samplers.rst │ ├── en │ ├── general │ │ ├── README-config-files-intro.md │ │ ├── README-directory-structure.md │ │ ├── README-model-zoo.md │ │ ├── README-new-dataset.md │ │ └── README-pytorch-to-coreml.md │ └── models │ │ ├── classification │ │ ├── README-classification-tutorial.md │ │ ├── README-mobilenets.md │ │ ├── README-mobilevit-v2.md │ │ ├── README-mobilevit.md │ │ ├── README-resnet.md │ │ └── README-robustness-evaluations.md │ │ ├── detection │ │ ├── README-SSDLite-mobilevit-v2.md │ │ └── README-detection-SSD-tutorial.md │ │ └── segmentation │ │ ├── 
README-segmentation-deeplabv3-tutorial.md │ │ └── README-segmentation-mobilevit-v2.md │ ├── getting_started.rst │ ├── how_to.rst │ ├── index.rst │ ├── installation.rst │ ├── models.md │ ├── modules.rst │ └── sample_recipes.rst ├── engine ├── __init__.py ├── detection_utils │ ├── __init__.py │ └── coco_map.py ├── eval_detection.py ├── eval_segmentation.py ├── evaluation_engine.py ├── segmentation_utils │ ├── __init__.py │ └── cityscapes_iou.py ├── training_engine.py └── utils.py ├── examples ├── byteformer │ ├── README.md │ ├── imagenet_file_encodings │ │ ├── encoding_type=PNG.yaml │ │ ├── encoding_type=TIFF.yaml │ │ ├── encoding_type=fCHW.yaml │ │ └── encoding_type=fHWC.yaml │ ├── imagenet_jpeg_q100 │ │ ├── conv_kernel_size=16.yaml │ │ ├── conv_kernel_size=32.yaml │ │ └── conv_kernel_size=8.yaml │ ├── imagenet_jpeg_q60 │ │ ├── conv_kernel_size=16,window_sizes=[128].yaml │ │ ├── conv_kernel_size=16,window_sizes=[32].yaml │ │ ├── conv_kernel_size=32,window_sizes=[128].yaml │ │ ├── conv_kernel_size=32,window_sizes=[32].yaml │ │ ├── conv_kernel_size=4,window_sizes=[128].yaml │ │ ├── conv_kernel_size=4,window_sizes=[32].yaml │ │ ├── conv_kernel_size=8,window_sizes=[128].yaml │ │ └── conv_kernel_size=8,window_sizes=[32].yaml │ ├── imagenet_jpeg_shuffle_bytes │ │ ├── mode=cyclic_half_length.yaml │ │ ├── mode=random_shuffle.yaml │ │ ├── mode=reverse.yaml │ │ ├── mode=stride.yaml │ │ └── mode=window_shuffle.yaml │ ├── imagenet_obfuscation │ │ ├── width_range=[-10,10].yaml │ │ ├── width_range=[-20,20].yaml │ │ ├── width_range=[-5,5].yaml │ │ └── width_range=[0,0].yaml │ ├── imagenet_privacy_preserving_camera │ │ ├── keep_frac=0.03,conv_kernel_size=4.yaml │ │ ├── keep_frac=0.05,conv_kernel_size=4.yaml │ │ ├── keep_frac=0.1,conv_kernel_size=4.yaml │ │ ├── keep_frac=0.25,conv_kernel_size=8.yaml │ │ ├── keep_frac=0.5,conv_kernel_size=16.yaml │ │ └── keep_frac=0.75,conv_kernel_size=32.yaml │ ├── model_arch.png │ ├── speech_commands_mp3 │ │ ├── conv_kernel_size=4,window_size=[128].yaml │ │ ├── conv_kernel_size=4,window_size=[32].yaml │ │ ├── conv_kernel_size=8,window_size=[128].yaml │ │ └── conv_kernel_size=8,window_size=[32].yaml │ └── speech_commands_wav │ │ ├── encoding_dtype=float32,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=float32,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=int16,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=int16,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=int16,conv_kernel_size=8.yaml │ │ ├── encoding_dtype=int32,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=int32,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=uint8,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=uint8,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=uint8,conv_kernel_size=4.yaml │ │ └── encoding_dtype=uint8,conv_kernel_size=8.yaml ├── range_augment │ ├── README-classification.md │ ├── README-clip.md │ ├── README-distillation.md │ ├── README-object-detection.md │ ├── README-segmentation.md │ ├── README.md │ ├── classification │ │ ├── efficientnet_b0.yaml │ │ ├── efficientnet_b1.yaml │ │ ├── efficientnet_b2.yaml │ │ ├── efficientnet_b3.yaml │ │ ├── mobilenet_v1.yaml │ │ ├── mobilenet_v2.yaml │ │ ├── mobilenet_v3.yaml │ │ ├── mobilevit_v1.yaml │ │ ├── regnety_16gf.yaml │ │ ├── resnet_101.yaml │ │ ├── resnet_50.yaml │ │ ├── se_resnet_50.yaml │ │ ├── swin_transformer_small.yaml │ │ └── swin_transformer_tiny.yaml │ ├── clip │ │ ├── clip_vit_base.yaml │ │ └── clip_vit_huge.yaml │ ├── clip_finetune_imagenet │ │ ├── clip_vit_base.yaml │ │ └── clip_vit_huge.yaml │ ├── detection │ │ ├── 
maskrcnn_efficientnet_b3.yaml │ │ ├── maskrcnn_mobilenet_v1.yaml │ │ ├── maskrcnn_mobilenet_v2.yaml │ │ ├── maskrcnn_mobilenet_v3.yaml │ │ ├── maskrcnn_mobilevit.yaml │ │ ├── maskrcnn_resnet_101.yaml │ │ └── maskrcnn_resnet_50.yaml │ ├── distillation │ │ ├── teacher_resnet101_student_mobilenet_v1.yaml │ │ ├── teacher_resnet101_student_mobilenet_v2.yaml │ │ ├── teacher_resnet101_student_mobilenet_v3.yaml │ │ └── teacher_resnet101_student_mobilevit.yaml │ └── segmentation │ │ ├── ade20k │ │ ├── deeplabv3_efficientnet_b3.yaml │ │ ├── deeplabv3_mobilenet_v1.yaml │ │ ├── deeplabv3_mobilenet_v2.yaml │ │ ├── deeplabv3_mobilenet_v3.yaml │ │ ├── deeplabv3_mobilevit.yaml │ │ ├── deeplabv3_resnet_101.yaml │ │ └── deeplabv3_resnet_50.yaml │ │ └── pascal_voc │ │ ├── deeplabv3_efficientnet_b3.yaml │ │ ├── deeplabv3_mobilenet_v1.yaml │ │ ├── deeplabv3_mobilenet_v2.yaml │ │ ├── deeplabv3_mobilenet_v3.yaml │ │ ├── deeplabv3_resnet_101.yaml │ │ └── deeplabv3_resnet_50.yaml └── vit │ ├── README.md │ ├── classification │ └── vit_base.yaml │ ├── detection │ └── mask_rcnn_vit_base_clip.yaml │ └── segmentation │ └── ade20k │ ├── deeplabv3_vit_base_clip_os_16.yaml │ └── deeplabv3_vit_base_clip_os_8.yaml ├── loss_fn ├── __init__.py ├── base_criteria.py ├── classification │ ├── __init__.py │ ├── base_classification_criteria.py │ ├── binary_cross_entropy.py │ └── cross_entropy.py ├── composite_loss.py ├── detection │ ├── __init__.py │ ├── base_detection_criteria.py │ ├── mask_rcnn_loss.py │ └── ssd_multibox_loss.py ├── distillation │ ├── __init__.py │ ├── base_distillation.py │ ├── hard_distillation.py │ └── soft_kl_distillation.py ├── multi_modal_img_text │ ├── __init__.py │ ├── base_multi_modal_img_text_criteria.py │ └── contrastive_loss_clip.py ├── neural_augmentation.py ├── segmentation │ ├── __init__.py │ ├── base_segmentation_criteria.py │ └── cross_entropy.py └── utils │ ├── __init__.py │ ├── build_helper.py │ └── class_weighting.py ├── loss_landscape ├── __init__.py └── landscape_utils.py ├── main_benchmark.py ├── main_conversion.py ├── main_eval.py ├── main_loss_landscape.py ├── main_train.py ├── metrics ├── __init__.py ├── average_precision.py ├── coco_map.py ├── confusion_mat.py ├── image_text_retrieval.py ├── intersection_over_union.py ├── metric_base.py ├── metric_base_test.py ├── misc.py ├── probability_histograms.py ├── psnr.py ├── retrieval_cmc.py ├── stats.py └── topk_accuracy.py ├── optim ├── __init__.py ├── adam.py ├── adamw.py ├── base_optim.py ├── scheduler │ ├── __init__.py │ ├── base_scheduler.py │ ├── cosine.py │ ├── cyclic.py │ ├── fixed.py │ ├── multi_step.py │ └── polynomial.py └── sgd.py ├── options ├── __init__.py ├── errors.py ├── opts.py ├── parse_args.py └── utils.py ├── pyproject.toml ├── requirements.txt ├── requirements_docs.txt ├── setup.py ├── tests ├── __init__.py ├── configs.py ├── data │ ├── __init__.py │ ├── coco │ │ └── annotations │ │ │ └── instances_val2017.json │ ├── collate_fns │ │ └── test_collate_functions.py │ ├── datasets │ │ ├── __init__.py │ │ ├── audio_classification │ │ │ ├── __init__.py │ │ │ └── test_speech_commands_v2.py │ │ ├── classification │ │ │ ├── __init__.py │ │ │ ├── dummy_configs │ │ │ │ ├── image_classification_dataset.yaml │ │ │ │ ├── imagenet.yaml │ │ │ │ ├── imagenet_a.yaml │ │ │ │ ├── imagenet_r.yaml │ │ │ │ └── imagenet_sketch.yaml │ │ │ ├── mock_imagenet.py │ │ │ ├── test_base_image_classification_dataset.py │ │ │ └── test_mock_imagenet.py │ │ ├── multi_modal_img_text │ │ │ ├── __init__.py │ │ │ └── zero_shot │ │ │ │ ├── __init__.py │ │ │ │ ├── 
dummy_imagenet_config.yaml │ │ │ │ ├── mock_imagenet.py │ │ │ │ └── test_mock_imagenet.py │ │ ├── segmentation │ │ │ ├── __init__.py │ │ │ ├── dummy_ade20k_config.yaml │ │ │ ├── mock_ade20k.py │ │ │ └── test_mock_ade20k.py │ │ ├── test_dataset_base.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── test_common.py │ │ │ └── test_video.py │ ├── dummy_silent_video.mov │ ├── dummy_video.mov │ ├── samplers │ │ ├── __init__.py │ │ ├── test_batch_sampler_config.yaml │ │ ├── test_chain_sampler.py │ │ ├── test_chain_sampler_config.yaml │ │ ├── test_data_samplers.py │ │ ├── test_multi_scale_sampler_config.yaml │ │ └── test_variable_batch_sampler_config.yaml │ └── video_reader │ │ └── test_av_reader.py ├── dummy_datasets │ ├── __init__.py │ ├── classification.py │ ├── multi_modal_img_text.py │ ├── segmentation.py │ └── ssd_detection.py ├── dummy_loader.py ├── layers │ └── test_token_merging.py ├── loss_fns │ ├── __init__.py │ ├── test_class_weighting.py │ ├── test_classification_loss.py │ ├── test_composite_loss.py │ ├── test_contrastive_loss.py │ ├── test_detection_loss.py │ ├── test_distillation_loss.py │ ├── test_neural_aug.py │ ├── test_neural_aug_compatibility.py │ └── test_segmentation_loss.py ├── metrics │ ├── __init__.py │ ├── base.py │ ├── test_coco_map.py │ ├── test_image_text_retrieval_metrics.py │ ├── test_iou.py │ ├── test_misc.py │ ├── test_probability_histogram.py │ ├── test_psnr.py │ ├── test_retrieval_cmc_metrics.py │ └── test_topk_accuracy.py ├── misc │ ├── __init__.py │ └── test_common.py ├── models │ ├── __init__.py │ ├── audio_classification │ │ ├── test_base_audio_classification.py │ │ └── test_byteformer.py │ ├── classification │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── test_byteformer.py │ │ └── test_byteformer.py │ └── test_neural_aug_utils.py ├── modules │ ├── __init__.py │ ├── test_transformer.py │ └── test_windowed_transformer.py ├── options │ ├── __init__.py │ ├── test_parse_args.py │ └── test_utils.py ├── test_byteformer_collate_fn.py ├── test_conventions.py ├── test_image_pil.py ├── test_model.py ├── test_multi_head_attn.py ├── test_pos_embeddings.py ├── test_scheduler.py ├── test_tokenizer.py ├── test_training_engine.py ├── test_utils.py ├── transforms │ ├── __init__.py │ ├── test_audio.py │ ├── test_audio_bytes.py │ ├── test_image.py │ ├── test_image_bytes.py │ └── test_video.py └── utils │ ├── __init__.py │ ├── test_common_utils.py │ ├── test_dict_utils.py │ └── test_import_utils.py └── utils ├── __init__.py ├── checkpoint_utils.py ├── color_map.py ├── common_utils.py ├── ddp_utils.py ├── dict_utils.py ├── download_utils.py ├── download_utils_base.py ├── import_utils.py ├── logger.py ├── math_utils.py ├── object_utils.py ├── object_utils_test.py ├── pytorch_to_coreml.py ├── registry.py ├── registry_test.py ├── resources.py ├── tensor_utils.py ├── third_party ├── __init__.py └── ddp_functional_utils.py └── visualization_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .coverage 3 | *.pyc 4 | __pycache__ 5 | .DS_STORE 6 | .idea 7 | results* 8 | *.png 9 | *.jpg 10 | .idea 11 | *.swp 12 | .pytest_cache 13 | .mypy_cache 14 | 15 | build/ 16 | 17 | results* 18 | vision_datasets/ 19 | exp_results/ 20 | exp_results* 21 | results_* 22 | 23 | *.so 24 | model_zoo 25 | model_zoo/* 26 | pipeline.yaml 27 | 28 | cvnets.egg-info 29 | cvnets.egg-info/* 30 | 31 | venv/ 32 | 33 | trash 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | from pathlib import Path 7 | from typing import Any 8 | 9 | LIBRARY_ROOT = Path(__file__).parent.parent 10 | 11 | MIN_TORCH_VERSION = "1.11.0" 12 | 13 | SUPPORTED_IMAGE_EXTNS = [".png", ".jpg", ".jpeg"] # Add image formats here 14 | SUPPORTED_MODALITIES = ["image", "video"] 15 | SUPPORTED_VIDEO_CLIP_VOTING_FN = ["sum", "max"] 16 | SUPPORTED_VIDEO_READER = ["pyav", "decord"] 17 | 18 | DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE_HEIGHT = 256 19 | DEFAULT_IMAGE_CHANNELS = 3 20 | DEFAULT_VIDEO_FRAMES = 8 21 | DEFAULT_LOG_FREQ = 500 22 | 23 | DEFAULT_ITERATIONS = 300000 24 | DEFAULT_EPOCHS = 300 25 | DEFAULT_MAX_ITERATIONS = DEFAULT_MAX_EPOCHS = 10000000 26 | 27 | TMP_RES_FOLDER = "results_tmp" 28 | 29 | TMP_CACHE_LOC = "/tmp/cvnets" 30 | 31 | Path(TMP_CACHE_LOC).mkdir(parents=True, exist_ok=True) 32 | 33 | 34 | def is_test_env() -> bool: 35 | return "PYTEST_CURRENT_TEST" in os.environ 36 | 37 | 38 | def if_test_env(then: Any, otherwise: Any) -> Any: 39 | return then if "PYTEST_CURRENT_TEST" in os.environ else otherwise 40 | -------------------------------------------------------------------------------- /config/classification/finetune_higher_res_in1k/mobilevit_v2.yaml: -------------------------------------------------------------------------------- 1 | taskname: '+ MobileViTv2-2.0 384x384' 2 | common: 3 | run_label: "train" 4 | log_freq: 500 5 | auto_resume: true 6 | mixed_precision: true 7 | channels_last: true 8 | tensorboard_logging: false 9 | grad_clip: 10.0 10 | dataset: 11 | root_train: "/mnt/imagenet/training" 12 | root_val: "/mnt/imagenet/validation" 13 | name: "imagenet" 14 | category: "classification" 15 | train_batch_size0: 64 # effective batch size of 128 (64 x 2 GPUs) 16 | val_batch_size0: 50 17 | eval_batch_size0: 50 18 | workers: 8 19 | persistent_workers: false 20 | pin_memory: true 21 | image_augmentation: 22 | random_resized_crop: 23 | enable: true 24 | interpolation: "bicubic" 25 | random_horizontal_flip: 26 | enable: true 27 | resize: 28 | enable: true 29 | size: 384 # shorter size is 384 30 | interpolation: "bicubic" 31 | center_crop: 32 | enable: true 33 | size: 384 34 | sampler: 35 | name: "batch_sampler" 36 | bs: 37 | crop_size_width: 384 38 | crop_size_height: 384 39 | loss: 40 | category: "classification" 41 | classification: 42 | name: "cross_entropy" 43 | cross_entropy: 44 | label_smoothing: 0.1 45 |
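The `is_test_env` / `if_test_env` helpers in `common/__init__.py` above key off pytest's `PYTEST_CURRENT_TEST` environment variable. A minimal usage sketch (the shrunken values below are illustrative, not repo defaults):

```python
from common import if_test_env, is_test_env

# Under pytest, PYTEST_CURRENT_TEST is set, so the `then` branch is taken.
batch_size = if_test_env(then=2, otherwise=128)   # tiny batches inside tests
max_iterations = if_test_env(then=5, otherwise=300000)
assert is_test_env() == (batch_size == 2)
```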
optim: 46 | name: "sgd" 47 | weight_decay: 4.e-5 48 | no_decay_bn_filter_bias: true 49 | sgd: 50 | momentum: 0.9 51 | scheduler: 52 | name: "fixed" 53 | max_epochs: 10 54 | fixed: 55 | lr: 1.e-3 56 | model: 57 | classification: 58 | name: "mobilevit_v2" 59 | mitv2: 60 | width_multiplier: 2.0 61 | attn_norm_layer: "layer_norm_2d" 62 | activation: 63 | name: "swish" 64 | normalization: 65 | name: "batch_norm" 66 | momentum: 0.1 67 | activation: 68 | name: "swish" 69 | ema: 70 | enable: true 71 | momentum: 0.00005 72 | stats: 73 | val: [ "loss", "top1", "top5" ] 74 | train: ["loss"] 75 | checkpoint_metric: "top1" 76 | checkpoint_metric_max: true 77 | -------------------------------------------------------------------------------- /config/classification/imagenet/mobilenet_v3.yaml: -------------------------------------------------------------------------------- 1 | taskname: '+ MobileNetv3-Large' 2 | common: 3 | run_label: "train" 4 | log_freq: 500 5 | auto_resume: true 6 | mixed_precision: true 7 | channels_last: true 8 | dataset: 9 | root_train: "/mnt/imagenet/training" 10 | root_val: "/mnt/imagenet/validation" 11 | name: "imagenet" 12 | category: "classification" 13 | train_batch_size0: 512 # effective batch size is 2048 (512 * 4 GPUs) 14 | val_batch_size0: 100 15 | eval_batch_size0: 100 16 | workers: 8 17 | persistent_workers: true 18 | pin_memory: true 19 | image_augmentation: 20 | random_resized_crop: 21 | enable: true 22 | interpolation: "bilinear" 23 | random_horizontal_flip: 24 | enable: true 25 | resize: 26 | enable: true 27 | size: 256 # shorter size is 256 28 | interpolation: "bilinear" 29 | center_crop: 30 | enable: true 31 | size: 224 32 | sampler: 33 | name: "variable_batch_sampler" 34 | vbs: 35 | crop_size_width: 224 36 | crop_size_height: 224 37 | max_n_scales: 5 38 | min_crop_size_width: 128 39 | max_crop_size_width: 320 40 | min_crop_size_height: 128 41 | max_crop_size_height: 320 42 | check_scale: 32 43 | loss: 44 | category: "classification" 45 | classification: 46 | name: "cross_entropy" 47 | cross_entropy: 48 | label_smoothing: 0.1 49 | optim: 50 | name: "sgd" 51 | weight_decay: 4.e-5 52 | no_decay_bn_filter_bias: true 53 | sgd: 54 | momentum: 0.9 55 | nesterov: true 56 | scheduler: 57 | name: "cosine" 58 | is_iteration_based: false 59 | max_epochs: 300 60 | warmup_iterations: 3000 61 | warmup_init_lr: 0.1 62 | cosine: 63 | max_lr: 0.8 64 | min_lr: 4.e-4 65 | model: 66 | classification: 67 | name: "mobilenetv3" 68 | mobilenetv3: 69 | mode: "large" 70 | width_multiplier: 1.0 71 | normalization: 72 | name: "batch_norm" 73 | momentum: 0.1 74 | layer: 75 | global_pool: "mean" 76 | conv_init: "kaiming_normal" 77 | linear_init: "normal" 78 | ema: 79 | enable: true 80 | momentum: 0.0005 81 | stats: 82 | val: [ "loss", "top1", "top5" ] 83 | train: ["loss"] 84 | checkpoint_metric: "top1" 85 | checkpoint_metric_max: true 86 | -------------------------------------------------------------------------------- /config/classification/imagenet/resnet.yaml: -------------------------------------------------------------------------------- 1 | taskname: '+ ResNet-50' 2 | common: 3 | run_label: "train" 4 | log_freq: 500 5 | auto_resume: true 6 | mixed_precision: true 7 | channels_last: true 8 | dataset: 9 | root_train: "/mnt/imagenet/training" 10 | root_val: "/mnt/imagenet/validation" 11 | name: "imagenet" 12 | category: "classification" 13 | train_batch_size0: 128 # effective batch size is 1024 (128 * 8 GPUs) 14 | val_batch_size0: 100 15 | eval_batch_size0: 100 16 | workers: 8 17 | 
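The YAML recipes in `config/` are consumed as a flat argparse namespace with dot-separated keys; code elsewhere in the repo reads values via `getattr(opts, "dataset.name")`-style lookups. A minimal sketch of that flattening convention (`flatten` is a hypothetical helper, not the repo's loader):

```python
import argparse
from typing import Any, Dict

def flatten(cfg: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested config dicts into dotted keys, e.g. scheduler.cosine.max_lr."""
    flat: Dict[str, Any] = {}
    for key, value in cfg.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=f"{name}."))
        else:
            flat[name] = value
    return flat

opts = argparse.Namespace(**flatten({"scheduler": {"name": "cosine", "cosine": {"max_lr": 0.8}}}))
assert getattr(opts, "scheduler.cosine.max_lr") == 0.8
```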
persistent_workers: true 18 | pin_memory: true 19 | image_augmentation: 20 | random_resized_crop: 21 | enable: true 22 | interpolation: "bilinear" 23 | random_horizontal_flip: 24 | enable: true 25 | resize: 26 | enable: true 27 | size: 256 # shorter size is 256 28 | interpolation: "bilinear" 29 | center_crop: 30 | enable: true 31 | size: 224 32 | sampler: 33 | name: "variable_batch_sampler" 34 | vbs: 35 | crop_size_width: 224 36 | crop_size_height: 224 37 | max_n_scales: 5 38 | min_crop_size_width: 128 39 | max_crop_size_width: 320 40 | min_crop_size_height: 128 41 | max_crop_size_height: 320 42 | check_scale: 32 43 | loss: 44 | category: "classification" 45 | classification: 46 | name: "cross_entropy" 47 | cross_entropy: 48 | label_smoothing: 0.1 49 | optim: 50 | name: "sgd" 51 | weight_decay: 1.e-4 52 | no_decay_bn_filter_bias: true 53 | sgd: 54 | momentum: 0.9 55 | scheduler: 56 | name: "cosine" 57 | is_iteration_based: false 58 | max_epochs: 150 59 | warmup_iterations: 7500 60 | warmup_init_lr: 0.05 61 | cosine: 62 | max_lr: 0.4 63 | min_lr: 2.e-4 64 | model: 65 | classification: 66 | name: "resnet" 67 | activation: 68 | name: "relu" 69 | resnet: 70 | depth: 50 71 | normalization: 72 | name: "batch_norm" 73 | momentum: 0.1 74 | activation: 75 | name: "relu" 76 | inplace: true 77 | layer: 78 | global_pool: "mean" 79 | conv_init: "kaiming_normal" 80 | linear_init: "normal" 81 | ema: 82 | enable: true 83 | momentum: 0.0005 84 | stats: 85 | val: [ "loss", "top1", "top5" ] 86 | train: ["loss"] 87 | checkpoint_metric: "top1" 88 | checkpoint_metric_max: true 89 | -------------------------------------------------------------------------------- /config/classification/imagenet/resnet_eval/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | taskname: '+ ResNet-50' 2 | common: 3 | run_label: "val" 4 | log_freq: 500 5 | dataset: 6 | root_val: "/mnt/imagenet_a/" 7 | name: "imagenet_a" 8 | category: "classification" 9 | val_batch_size0: 100 10 | eval_batch_size0: 100 11 | workers: 8 12 | persistent_workers: false 13 | pin_memory: true 14 | image_augmentation: 15 | resize: 16 | enable: true 17 | size: 232 # shorter size is 232 18 | interpolation: "bilinear" 19 | center_crop: 20 | enable: true 21 | size: 224 22 | loss: 23 | category: "classification" 24 | classification: 25 | name: "cross_entropy" 26 | model: 27 | classification: 28 | n_classes: 1000 29 | name: "resnet" 30 | activation: 31 | name: "relu" 32 | resnet: 33 | depth: 50 34 | normalization: 35 | name: "batch_norm" 36 | activation: 37 | name: "relu" 38 | inplace: true 39 | layer: 40 | global_pool: "mean" 41 | stats: 42 | val: [ "loss", "top1", "top5" ] 43 | checkpoint_metric: "top1" 44 | checkpoint_metric_max: true 45 | -------------------------------------------------------------------------------- /config/classification/imagenet/resnet_eval/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | taskname: '+ ResNet-50' 2 | common: 3 | run_label: "val" 4 | log_freq: 500 5 | dataset: 6 | root_val: "/mnt/imagenet_r/" 7 | name: "imagenet_r" 8 | category: "classification" 9 | val_batch_size0: 100 10 | eval_batch_size0: 100 11 | workers: 8 12 | persistent_workers: false 13 | pin_memory: true 14 | image_augmentation: 15 | resize: 16 | enable: true 17 | size: 232 # shorter size is 232 18 | interpolation: "bilinear" 19 | center_crop: 20 | enable: true 21 | size: 224 22 | loss: 23 | category: "classification" 24 | classification: 25 | name: "cross_entropy"
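The eval-time pipeline in the `resnet_eval` configs (resize the shorter side to 232, then take a 224 center crop) maps onto standard torchvision transforms; a sketch of the equivalent, assuming the bilinear interpolation configured above:

```python
import torchvision.transforms as T

# Equivalent of the resize + center_crop blocks in the resnet_eval configs.
eval_transform = T.Compose(
    [
        T.Resize(232, interpolation=T.InterpolationMode.BILINEAR),  # shorter side -> 232
        T.CenterCrop(224),
        T.ToTensor(),
    ]
)
```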
26 | model: 27 | classification: 28 | n_classes: 1000 29 | name: "resnet" 30 | activation: 31 | name: "relu" 32 | resnet: 33 | depth: 50 34 | normalization: 35 | name: "batch_norm" 36 | activation: 37 | name: "relu" 38 | inplace: true 39 | layer: 40 | global_pool: "mean" 41 | stats: 42 | val: [ "loss", "top1", "top5" ] 43 | checkpoint_metric: "top1" 44 | checkpoint_metric_max: true 45 | -------------------------------------------------------------------------------- /config/classification/imagenet/resnet_eval/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | taskname: '+ ResNet-50' 2 | common: 3 | run_label: "val" 4 | log_freq: 500 5 | dataset: 6 | root_val: "/mnt/imagenet_sketch/" 7 | name: "imagenet_sketch" 8 | category: "classification" 9 | val_batch_size0: 100 10 | eval_batch_size0: 100 11 | workers: 8 12 | persistent_workers: false 13 | pin_memory: true 14 | image_augmentation: 15 | resize: 16 | enable: true 17 | size: 232 # shorter size is 232 18 | interpolation: "bilinear" 19 | center_crop: 20 | enable: true 21 | size: 224 22 | loss: 23 | category: "classification" 24 | classification: 25 | name: "cross_entropy" 26 | model: 27 | classification: 28 | n_classes: 1000 29 | name: "resnet" 30 | activation: 31 | name: "relu" 32 | resnet: 33 | depth: 50 34 | normalization: 35 | name: "batch_norm" 36 | activation: 37 | name: "relu" 38 | inplace: true 39 | layer: 40 | global_pool: "mean" 41 | stats: 42 | val: [ "loss", "top1", "top5" ] 43 | checkpoint_metric: "top1" 44 | checkpoint_metric_max: true 45 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import signal 8 | from types import FrameType 9 | from typing import Optional 10 | 11 | import pytest 12 | 13 | session_timed_out = False 14 | 15 | 16 | def handle_timeout(signum: int, frame: Optional[FrameType] = None) -> None: 17 | global session_timed_out 18 | session_timed_out = True 19 | # Call fail() to capture the output of the test. 20 | pytest.fail("timeout") 21 | 22 | 23 | def pytest_sessionstart(): 24 | timeout = os.environ.get("PYTEST_GLOBAL_TIMEOUT", "") 25 | if not timeout: 26 | return 27 | if timeout.endswith("s"): 28 | timeout = int(timeout[:-1]) 29 | elif timeout.endswith("m"): 30 | timeout = int(timeout[:-1]) * 60 31 | else: 32 | raise ValueError( 33 | f"Timeout value {timeout} should either end with 'm' (minutes) or 's' (seconds)."
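`pytest_sessionstart` above arms a SIGALRM-based wall-clock budget for the whole test session. A hedged usage sketch (the 90-second value is illustrative):

```python
import os
import subprocess

# PYTEST_GLOBAL_TIMEOUT accepts "<N>s" or "<N>m"; any other suffix raises the
# ValueError shown above.
env = {**os.environ, "PYTEST_GLOBAL_TIMEOUT": "90s"}
subprocess.run(["python", "-m", "pytest", "tests/"], env=env, check=False)
```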
34 | ) 35 | 36 | signal.signal(signal.SIGALRM, handle_timeout) 37 | signal.setitimer(signal.ITIMER_REAL, timeout) 38 | 39 | 40 | def pytest_runtest_logfinish(nodeid, location): 41 | if session_timed_out: 42 | pytest.exit("timeout") 43 | -------------------------------------------------------------------------------- /constraints.txt: -------------------------------------------------------------------------------- 1 | av==10.0.0 2 | black==22.10.0 3 | colorama==0.4.3 4 | coremltools==6.3.0 5 | coverage==7.2.3 6 | decord==0.6.0 7 | docutils==0.15.2 8 | fairscale==0.4.13 9 | fairseq==0.12.2 10 | ftfy==6.1.1 11 | fvcore==0.1.5.post20221221 12 | librosa==0.8.1 13 | matplotlib==3.7.1 14 | numba==0.56.4 15 | numpy==1.21.2 16 | omegaconf==2.0.6 17 | opencv-contrib-python==4.5.5.64 18 | opencv-python==4.5.5.64 19 | opencv-python-headless==4.5.5.64 20 | packaging==21.3 21 | Pillow==9.5.0 22 | Pillow-SIMD==9.0.0.post1 23 | psutil==5.9.4 24 | pyarrow==8.0.0 25 | pybase64==1.2.3 26 | pycocotools==2.0.6 27 | pycodestyle==2.10.0 28 | pyflakes==3.0.1 29 | pytest==7.2.2 30 | pytorchvideo==0.1.5 31 | scikit-learn==1.2.2 32 | scipy==1.10.1 33 | Sphinx==5.3.0 34 | sphinx-rtd-theme==1.2.0 35 | tensorboard==2.12.0 36 | tensorboard-logger==0.1.0 37 | tensorboardX==2.6 38 | torch==1.13.0 39 | torchaudio==0.13.0 40 | torchdata==0.5.0 41 | torchtext==0.14.0 42 | torchvision==0.14.0 43 | tqdm==4.65.0 44 | ujson==5.7.0 45 | -------------------------------------------------------------------------------- /cvnets/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from cvnets.anchor_generator import arguments_anchor_gen 9 | from cvnets.image_projection_layers import arguments_image_projection_head 10 | from cvnets.layers import arguments_nn_layers 11 | from cvnets.matcher_det import arguments_box_matcher 12 | from cvnets.misc.averaging_utils import EMA, arguments_ema 13 | from cvnets.misc.common import parameter_list 14 | from cvnets.models import arguments_model, get_model 15 | from cvnets.models.detection import DetectionPredTuple 16 | from cvnets.neural_augmentor import arguments_neural_augmentor 17 | from cvnets.text_encoders import arguments_text_encoder 18 | from options.utils import extend_selected_args_with_prefix 19 | 20 | 21 | def modeling_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | # text encoder arguments (usually for multi-modal tasks) 23 | parser = arguments_text_encoder(parser) 24 | # image projection head arguments (usually for multi-modal tasks) 25 | parser = arguments_image_projection_head(parser) 26 | # model arguments 27 | parser = arguments_model(parser) 28 | # neural network layer arguments 29 | parser = arguments_nn_layers(parser) 30 | # EMA arguments 31 | parser = arguments_ema(parser) 32 | # anchor generator arguments (for object detection) 33 | parser = arguments_anchor_gen(parser) 34 | # box matcher arguments (for object detection) 35 | parser = arguments_box_matcher(parser) 36 | # neural aug arguments 37 | parser = arguments_neural_augmentor(parser) 38 | 39 | # Add teacher as a prefix to enable distillation tasks 40 | # keep it as the last entry 41 | parser = extend_selected_args_with_prefix( 42 | parser, match_prefix="--model.", additional_prefix="--teacher.model."
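`extend_selected_args_with_prefix` duplicates every `--model.*` option under `--teacher.model.*`, so a distillation teacher can be configured independently of the student. A toy argparse illustration of the resulting interface (not the repo's implementation):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model.classification.name", type=str)
# After prefix extension, the same option also exists under the teacher prefix:
parser.add_argument("--teacher.model.classification.name", type=str)

opts = parser.parse_args(
    ["--model.classification.name", "mobilenet_v1",
     "--teacher.model.classification.name", "resnet"]
)
assert getattr(opts, "teacher.model.classification.name") == "resnet"
```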
43 | ) 44 | 45 | return parser 46 | -------------------------------------------------------------------------------- /cvnets/anchor_generator/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from cvnets.anchor_generator.base_anchor_generator import BaseAnchorGenerator 9 | from utils import logger 10 | from utils.registry import Registry 11 | 12 | # register anchor generator 13 | ANCHOR_GEN_REGISTRY = Registry( 14 | "anchor_gen", 15 | base_class=BaseAnchorGenerator, 16 | lazy_load_dirs=["cvnets/anchor_generator"], 17 | internal_dirs=["internal", "internal/projects/*"], 18 | ) 19 | 20 | 21 | def arguments_anchor_gen(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | """Arguments related to anchor generator for object detection""" 23 | group = parser.add_argument_group("Anchor generator", "Anchor generator") 24 | group.add_argument( 25 | "--anchor-generator.name", type=str, help="Name of the anchor generator" 26 | ) 27 | 28 | # add class specific arguments 29 | parser = ANCHOR_GEN_REGISTRY.all_arguments(parser) 30 | return parser 31 | 32 | 33 | def build_anchor_generator(opts, *args, **kwargs): 34 | """Build anchor generator for object detection""" 35 | anchor_gen_name = getattr(opts, "anchor_generator.name") 36 | 37 | # We registered the base class using a special `name` (i.e., `__base__`) 38 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 39 | # be used. Therefore, we raise an error for such cases 40 | if anchor_gen_name == "__base__": 41 | logger.error("__base__ can't be used as an anchor generator name. Please check.") 42 | 43 | anchor_gen = ANCHOR_GEN_REGISTRY[anchor_gen_name](opts, *args, **kwargs) 44 | return anchor_gen 45 | -------------------------------------------------------------------------------- /cvnets/image_projection_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | 7 | import argparse 8 | 9 | from cvnets.image_projection_layers.base_image_projection import BaseImageProjectionHead 10 | from utils import logger 11 | from utils.registry import Registry 12 | 13 | IMAGE_PROJECTION_HEAD_REGISTRY = Registry( 14 | "image_projection_head", 15 | base_class=BaseImageProjectionHead, 16 | lazy_load_dirs=["cvnets/image_projection_layers"], 17 | internal_dirs=["internal", "internal/projects/*"], 18 | ) 19 | 20 | 21 | def arguments_image_projection_head( 22 | parser: argparse.ArgumentParser, 23 | ) -> argparse.ArgumentParser: 24 | """Register arguments of all image projection heads.""" 25 | # add arguments for base image projection layer 26 | parser = BaseImageProjectionHead.add_arguments(parser) 27 | 28 | # add class specific arguments 29 | parser = IMAGE_PROJECTION_HEAD_REGISTRY.all_arguments(parser) 30 | return parser 31 | 32 | 33 | def build_image_projection_head( 34 | opts: argparse.Namespace, in_dim: int, out_dim: int, *args, **kwargs 35 | ) -> BaseImageProjectionHead: 36 | """Helper function to build an image projection head from command-line arguments. 37 | 38 | Args: 39 | opts: Command-line arguments 40 | in_dim: Input dimension to the projection head. 41 | out_dim: Output dimension of the projection head.
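Both registries above follow the same pattern: construct with a base class plus lazy-load directories, then index by name to build an instance. A sketch of plugging in a custom anchor generator, assuming `utils.registry.Registry` exposes a `register` decorator (the class and name below are hypothetical):

```python
from cvnets.anchor_generator import ANCHOR_GEN_REGISTRY
from cvnets.anchor_generator.base_anchor_generator import BaseAnchorGenerator

@ANCHOR_GEN_REGISTRY.register(name="my_anchor_gen")  # hypothetical name
class MyAnchorGenerator(BaseAnchorGenerator):
    ...

# Selected at build time via: --anchor-generator.name my_anchor_gen
```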
42 | 43 | Returns: 44 | Image projection head module. 45 | """ 46 | 47 | # Get the name of image projection head 48 | image_projection_head_name = getattr(opts, "model.image_projection_head.name") 49 | 50 | # We registered the base class using a special `name` (i.e., `__base__`) 51 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 52 | # be used. Therefore, we raise an error for such cases 53 | if image_projection_head_name == "__base__": 54 | logger.error("__base__ can't be used as a projection name. Please check.") 55 | 56 | image_projection_head = IMAGE_PROJECTION_HEAD_REGISTRY[image_projection_head_name]( 57 | opts, in_dim, out_dim, *args, **kwargs 58 | ) 59 | return image_projection_head 60 | -------------------------------------------------------------------------------- /cvnets/layers/activation/gelu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from cvnets.layers.activation import register_act_fn 9 | 10 | 11 | @register_act_fn(name="gelu") 12 | class GELU(nn.GELU): 13 | """ 14 | Applies the `Gaussian Error Linear Units `_ function 15 | """ 16 | 17 | def __init__(self, *args, **kwargs) -> None: 18 | super().__init__() 19 | -------------------------------------------------------------------------------- /cvnets/layers/activation/hard_sigmoid.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | from torch.nn import functional as F 10 | 11 | from cvnets.layers.activation import register_act_fn 12 | 13 | 14 | @register_act_fn(name="hard_sigmoid") 15 | class Hardsigmoid(nn.Hardsigmoid): 16 | """ 17 | Applies the `Hard Sigmoid `_ function 18 | """ 19 | 20 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 21 | super().__init__(inplace=inplace) 22 | 23 | def forward(self, input: Tensor, *args, **kwargs) -> Tensor: 24 | if hasattr(F, "hardsigmoid"): 25 | return F.hardsigmoid(input, self.inplace) 26 | else: 27 | return F.relu6(input + 3) / 6 28 | -------------------------------------------------------------------------------- /cvnets/layers/activation/hard_swish.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
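The activation modules in `cvnets/layers/activation/` are thin wrappers over `torch.nn` registered through `register_act_fn`. A sketch of adding a custom activation the same way (the name and the math are illustrative):

```python
import torch
from torch import Tensor, nn

from cvnets.layers.activation import register_act_fn

@register_act_fn(name="squared_relu")  # hypothetical activation
class SquaredReLU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return torch.relu(x) ** 2
```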
4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | from torch.nn import functional as F 10 | 11 | from cvnets.layers.activation import register_act_fn 12 | 13 | 14 | @register_act_fn(name="hard_swish") 15 | class Hardswish(nn.Hardswish): 16 | """ 17 | Applies the HardSwish function, as described in the paper 18 | `Searching for MobileNetv3 `_ 19 | """ 20 | 21 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 22 | super().__init__(inplace=inplace) 23 | 24 | def forward(self, input: Tensor, *args, **kwargs) -> Tensor: 25 | if hasattr(F, "hardswish"): 26 | return F.hardswish(input, self.inplace) 27 | else: 28 | x_hard_sig = F.relu6(input + 3) / 6 29 | return input * x_hard_sig 30 | -------------------------------------------------------------------------------- /cvnets/layers/activation/leaky_relu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from cvnets.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="leaky_relu") 14 | class LeakyReLU(nn.LeakyReLU): 15 | """ 16 | Applies a leaky relu function. See `Rectifier Nonlinearities Improve Neural Network Acoustic Models` 17 | for more details. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | negative_slope: Optional[float] = 1e-2, 23 | inplace: Optional[bool] = False, 24 | *args, 25 | **kwargs 26 | ) -> None: 27 | super().__init__(negative_slope=negative_slope, inplace=inplace) 28 | -------------------------------------------------------------------------------- /cvnets/layers/activation/prelu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from cvnets.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="prelu") 14 | class PReLU(nn.PReLU): 15 | """ 16 | Applies the `Parametric Rectified Linear Unit `_ function 17 | """ 18 | 19 | def __init__( 20 | self, 21 | num_parameters: Optional[int] = 1, 22 | init: Optional[float] = 0.25, 23 | *args, 24 | **kwargs 25 | ) -> None: 26 | super().__init__(num_parameters=num_parameters, init=init) 27 | -------------------------------------------------------------------------------- /cvnets/layers/activation/relu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from cvnets.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="relu") 14 | class ReLU(nn.ReLU): 15 | """ 16 | Applies Rectified Linear Unit function 17 | """ 18 | 19 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 20 | super().__init__(inplace=inplace) 21 | -------------------------------------------------------------------------------- /cvnets/layers/activation/relu6.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
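The manual fallbacks in `Hardsigmoid`/`Hardswish` above reduce to `relu6(x + 3) / 6`; a quick numeric sanity check (not repo code) that this matches the fused PyTorch ops:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-6.0, 6.0, steps=25)
manual = F.relu6(x + 3.0) / 6.0  # clamps to [0, 1] on both sides
assert torch.allclose(manual, F.hardsigmoid(x))
assert torch.allclose(x * manual, F.hardswish(x))
```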
4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from cvnets.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="relu6") 14 | class ReLU6(nn.ReLU6): 15 | """ 16 | Applies the ReLU6 function 17 | """ 18 | 19 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 20 | super().__init__(inplace=inplace) 21 | -------------------------------------------------------------------------------- /cvnets/layers/activation/sigmoid.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from cvnets.layers.activation import register_act_fn 9 | 10 | 11 | @register_act_fn(name="sigmoid") 12 | class Sigmoid(nn.Sigmoid): 13 | """ 14 | Applies the sigmoid function 15 | """ 16 | 17 | def __init__(self, *args, **kwargs) -> None: 18 | super().__init__() 19 | -------------------------------------------------------------------------------- /cvnets/layers/activation/swish.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from cvnets.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="swish") 14 | class Swish(nn.SiLU): 15 | """ 16 | Applies the `Swish (also known as SiLU) `_ function. 17 | """ 18 | 19 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 20 | super().__init__(inplace=inplace) 21 | -------------------------------------------------------------------------------- /cvnets/layers/activation/tanh.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from cvnets.layers.activation import register_act_fn 9 | 10 | 11 | @register_act_fn(name="tanh") 12 | class Tanh(nn.Tanh): 13 | """ 14 | Applies Tanh function 15 | """ 16 | 17 | def __init__(self, *args, **kwargs) -> None: 18 | super().__init__() 19 | -------------------------------------------------------------------------------- /cvnets/layers/adaptive_pool.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Tuple, Union 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d): 12 | """ 13 | Applies a 2D adaptive average pooling over an input tensor. 14 | 15 | Args: 16 | output_size (Optional, int or Tuple[int, int]): The target output size. If a single int :math:`h` is passed, 17 | then a square output of size :math:`hxh` is produced. If a tuple of size :math:`hxw` is passed, then an 18 | output of size `hxw` is produced. Default is 1. 
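A shape sketch for the `AdaptiveAvgPool2d` wrapper above (it is a thin alias of `nn.AdaptiveAvgPool2d`):

```python
import torch
from torch import nn

x = torch.randn(2, 32, 7, 9)
assert nn.AdaptiveAvgPool2d(1)(x).shape == (2, 32, 1, 1)       # global average pool
assert nn.AdaptiveAvgPool2d((4, 5))(x).shape == (2, 32, 4, 5)  # rectangular target
```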
19 | Shape: 20 | - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels, 21 | :math:`H` is the input height, and :math:`W` is the input width 22 | - Output: :math:`(N, C, h, h)` or :math:`(N, C, h, w)` 23 | """ 24 | 25 | def __init__( 26 | self, output_size: Union[int, Tuple[int, int]] = 1, *args, **kwargs 27 | ) -> None: 28 | super().__init__(output_size=output_size) 29 | -------------------------------------------------------------------------------- /cvnets/layers/base_layer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | from typing import Any, Dict, List, Optional, Tuple 8 | 9 | from torch import nn 10 | 11 | from cvnets.misc.common import parameter_list 12 | 13 | 14 | class BaseLayer(nn.Module): 15 | """ 16 | Base class for neural network layers. Subclass must implement `forward` function. 17 | """ 18 | 19 | def __init__(self, *args, **kwargs) -> None: 20 | super().__init__() 21 | 22 | @classmethod 23 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 24 | """Add layer specific arguments""" 25 | return parser 26 | 27 | def get_trainable_parameters( 28 | self, 29 | weight_decay: Optional[float] = 0.0, 30 | no_decay_bn_filter_bias: Optional[bool] = False, 31 | *args, 32 | **kwargs 33 | ) -> Tuple[List[Dict], List[float]]: 34 | """ 35 | Get parameters for training along with the learning rate. 36 | 37 | Args: 38 | weight_decay: weight decay 39 | no_decay_bn_filter_bias: Do not decay BN and biases. Defaults to False. 40 | 41 | Returns: 42 | Returns a tuple of length 2. The first entry is a list of dictionary with three keys 43 | (params, weight_decay, param_names). The second entry is a list of floats containing 44 | learning rate for each parameter. 45 | 46 | Note: 47 | Learning rate multiplier is set to 1.0 here as it is handled inside the Central Model. 48 | """ 49 | param_list = parameter_list( 50 | named_parameters=self.named_parameters, 51 | weight_decay=weight_decay, 52 | no_decay_bn_filter_bias=no_decay_bn_filter_bias, 53 | *args, 54 | **kwargs 55 | ) 56 | return param_list, [1.0] * len(param_list) 57 | 58 | def forward(self, *args, **kwargs) -> Any: 59 | """Forward function.""" 60 | raise NotImplementedError("Sub-classes should implement forward method") 61 | 62 | def __repr__(self): 63 | return "{}".format(self.__class__.__name__) 64 | -------------------------------------------------------------------------------- /cvnets/layers/dropout.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class Dropout(nn.Dropout): 12 | """ 13 | This layer, during training, randomly zeroes some of the elements of the input tensor with probability `p` 14 | using samples from a Bernoulli distribution. 15 | 16 | Args: 17 | p: probability of an element to be zeroed. Default: 0.5 18 | inplace: If set to ``True``, will do this operation in-place. 
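`BaseLayer` above only requires a `forward` override; weight-decay grouping comes for free through `get_trainable_parameters`. A minimal subclass sketch (the layer itself is illustrative):

```python
from torch import Tensor

from cvnets.layers.base_layer import BaseLayer

class Scale(BaseLayer):
    """Multiplies its input by a fixed constant."""

    def __init__(self, gamma: float = 2.0, *args, **kwargs) -> None:
        super().__init__()
        self.gamma = gamma

    def forward(self, x: Tensor) -> Tensor:
        return x * self.gamma
```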
Default: ``False`` 19 | 20 | Shape: 21 | - Input: :math:`(N, *)` where :math:`N` is the batch size 22 | - Output: same as the input 23 | 24 | """ 25 | 26 | def __init__( 27 | self, p: Optional[float] = 0.5, inplace: Optional[bool] = False, *args, **kwargs 28 | ) -> None: 29 | super().__init__(p=p, inplace=inplace) 30 | 31 | 32 | class Dropout2d(nn.Dropout2d): 33 | """ 34 | This layer, during training, randomly zeroes some of the elements of the 4D input tensor with probability `p` 35 | using samples from a Bernoulli distribution. 36 | 37 | Args: 38 | p: probability of an element to be zeroed. Default: 0.5 39 | inplace: If set to ``True``, will do this operation in-place. Default: ``False`` 40 | 41 | Shape: 42 | - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the input channels, 43 | :math:`H` is the input tensor height, and :math:`W` is the input tensor width 44 | - Output: same as the input 45 | 46 | """ 47 | 48 | def __init__(self, p: float = 0.5, inplace: bool = False): 49 | super().__init__(p=p, inplace=inplace) 50 | -------------------------------------------------------------------------------- /cvnets/layers/embedding.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | from cvnets.layers.base_layer import BaseLayer 12 | from cvnets.layers.normalization_layers import get_normalization_layer 13 | from utils import logger 14 | 15 | 16 | class Embedding(nn.Embedding): 17 | r"""A lookup table that stores embeddings of a fixed dictionary and size. 18 | 19 | Args: 20 | num_embeddings (int): size of the dictionary of embeddings 21 | embedding_dim (int): the size of each embedding vector 22 | padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; 23 | therefore, the embedding vector at :attr:`padding_idx` is not updated during training, 24 | i.e. it remains as a fixed "pad". For a newly constructed Embedding, 25 | the embedding vector at :attr:`padding_idx` will default to all zeros, 26 | but can be updated to another value to be used as the padding vector. 27 | 28 | Shape: 29 | - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract 30 | - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` 31 | """ 32 | 33 | def __init__( 34 | self, 35 | opts, 36 | num_embeddings: int, 37 | embedding_dim: int, 38 | padding_idx: Optional[int] = None, 39 | *args, 40 | **kwargs 41 | ): 42 | super().__init__( 43 | num_embeddings=num_embeddings, 44 | embedding_dim=embedding_dim, 45 | padding_idx=padding_idx, 46 | ) 47 | 48 | def reset_parameters(self) -> None: 49 | nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5) 50 | if self.padding_idx is not None: 51 | nn.init.constant_(self.weight[self.padding_idx], 0) 52 | -------------------------------------------------------------------------------- /cvnets/layers/flatten.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
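`Embedding.reset_parameters` above draws weights from a normal distribution with std `embedding_dim ** -0.5` and zeroes the `padding_idx` row. A quick check (`opts` is unused by the constructor, so `None` is passed purely for illustration):

```python
import torch

from cvnets.layers.embedding import Embedding

emb = Embedding(opts=None, num_embeddings=10, embedding_dim=4, padding_idx=0)
assert torch.all(emb.weight[0] == 0)  # the pad row is zero-initialized
```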
4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class Flatten(nn.Flatten): 12 | r""" 13 | This layer flattens a contiguous range of dimensions into a tensor. 14 | 15 | Args: 16 | start_dim (Optional[int]): first dim to flatten. Default: 1 17 | end_dim (Optional[int]): last dim to flatten. Default: -1 18 | 19 | Shape: 20 | - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' 21 | where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any 22 | number of dimensions including none. 23 | - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. 24 | """ 25 | 26 | def __init__(self, start_dim: Optional[int] = 1, end_dim: Optional[int] = -1): 27 | super(Flatten, self).__init__(start_dim=start_dim, end_dim=end_dim) 28 | -------------------------------------------------------------------------------- /cvnets/layers/identity.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor 7 | 8 | from cvnets.layers.base_layer import BaseLayer 9 | 10 | 11 | class Identity(BaseLayer): 12 | """ 13 | This is a place-holder and returns the same tensor. 14 | """ 15 | 16 | def __init__(self): 17 | super(Identity, self).__init__() 18 | 19 | def forward(self, x: Tensor) -> Tensor: 20 | return x 21 | -------------------------------------------------------------------------------- /cvnets/layers/normalization/group_norm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from cvnets.layers.normalization import register_norm_fn 11 | 12 | 13 | @register_norm_fn(name="group_norm") 14 | class GroupNorm(nn.GroupNorm): 15 | """ 16 | Applies a `Group Normalization `_ over an input tensor 17 | 18 | Args: 19 | num_groups (int): number of groups to separate the input channels into 20 | num_features (int): :math:`C` from an expected input of size :math:`(N, C, *)` 21 | eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 22 | affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` 23 | 24 | Shape: 25 | - Input: :math:`(N, C, *)` where :math:`N` is the batch size, :math:`C` is the number of input channels, 26 | and :math:`*` is the remaining dimensions of the input tensor 27 | - Output: same shape as the input 28 | 29 | .. note:: 30 | GroupNorm is the same as LayerNorm when `num_groups=1` and it is the same as InstanceNorm when 31 | `num_groups=C`. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | num_groups: int, 37 | num_features: int, 38 | eps: Optional[float] = 1e-5, 39 | affine: Optional[bool] = True, 40 | *args, 41 | **kwargs 42 | ) -> None: 43 | super().__init__( 44 | num_groups=num_groups, num_channels=num_features, eps=eps, affine=affine 45 | ) 46 | -------------------------------------------------------------------------------- /cvnets/layers/pixel_shuffle.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
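The note in `GroupNorm`'s docstring above can be verified numerically; with affine parameters disabled, only the normalization statistics are compared:

```python
import torch
from torch import nn

x = torch.randn(2, 8, 4, 4)
# num_groups=1 -> same statistics as LayerNorm over (C, H, W) ...
assert torch.allclose(
    nn.GroupNorm(1, 8, affine=False)(x),
    nn.LayerNorm([8, 4, 4], elementwise_affine=False)(x),
    atol=1e-5,
)
# ... and num_groups=C -> same statistics as InstanceNorm.
assert torch.allclose(nn.GroupNorm(8, 8, affine=False)(x), nn.InstanceNorm2d(8)(x), atol=1e-5)
```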
4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | 9 | class PixelShuffle(nn.PixelShuffle): 10 | r""" 11 | Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` 12 | to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. 13 | 14 | Args: 15 | upscale_factor (int): factor to increase spatial resolution by 16 | 17 | Shape: 18 | - Input: :math:`(*, C \times r^2, H, W)`, where * is zero or more dimensions 19 | - Output: :math:`(*, C, H \times r, W \times r)` 20 | """ 21 | 22 | def __init__(self, upscale_factor: int, *args, **kwargs) -> None: 23 | super(PixelShuffle, self).__init__(upscale_factor=upscale_factor) 24 | 25 | def __repr__(self): 26 | return "{}(upscale_factor={})".format( 27 | self.__class__.__name__, self.upscale_factor 28 | ) 29 | -------------------------------------------------------------------------------- /cvnets/layers/random_layers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import random 7 | from typing import List, Optional 8 | 9 | from torch import Tensor 10 | 11 | from cvnets.layers.base_layer import BaseLayer 12 | from utils.math_utils import bound_fn 13 | 14 | 15 | class RandomApply(BaseLayer): 16 | """ 17 | This layer randomly applies a list of modules during training. 18 | 19 | Args: 20 | module_list (List): List of modules 21 | keep_p (Optional[float]): Fraction of the modules to keep during training. Default: 0.8 (i.e., keep 80%) 22 | """ 23 | 24 | def __init__( 25 | self, module_list: List, keep_p: Optional[float] = 0.8, *args, **kwargs 26 | ) -> None: 27 | super().__init__() 28 | n_modules = len(module_list) 29 | self.module_list = module_list 30 | 31 | self.module_indexes = [i for i in range(1, n_modules)]  # index 0 is excluded: the first module is always applied 32 | k = int(round(n_modules * keep_p)) 33 | self.keep_k = bound_fn(min_val=1, max_val=n_modules, value=k) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | if self.training: 37 | indexes = [0] + sorted(random.sample(self.module_indexes, k=self.keep_k))  # the first module plus keep_k sampled ones, in their original order 38 | for idx in indexes: 39 | x = self.module_list[idx](x) 40 | else: 41 | for layer in self.module_list:  # at inference time, all modules are applied 42 | x = layer(x) 43 | return x 44 | 45 | def __repr__(self): 46 | format_string = "{}(apply_k (N={})={}, ".format( 47 | self.__class__.__name__, len(self.module_list), self.keep_k 48 | ) 49 | for layer in self.module_list: 50 | format_string += "\n\t {}".format(layer) 51 | format_string += "\n)" 52 | return format_string 53 | -------------------------------------------------------------------------------- /cvnets/layers/softmax.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class Softmax(nn.Softmax): 12 | """ 13 | Applies the Softmax function to an input tensor along the specified dimension. 14 | 15 | Args: 16 | dim (int): Dimension along which softmax is to be applied.
Default: -1 17 | 18 | Shape: 19 | - Input: :math:`(*)` where :math:`*` is one or more dimensions 20 | - Output: same shape as the input 21 | """ 22 | 23 | def __init__(self, dim: Optional[int] = -1, *args, **kwargs): 24 | super().__init__(dim=dim) 25 | -------------------------------------------------------------------------------- /cvnets/layers/stochastic_depth.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor 7 | from torchvision.ops import StochasticDepth as StochasticDepthTorch 8 | 9 | 10 | class StochasticDepth(StochasticDepthTorch): 11 | """ 12 | Implements Stochastic Depth `"Deep Networks with Stochastic Depth" 13 | <https://arxiv.org/abs/1603.09382>`_, used for randomly dropping residual 14 | branches of residual architectures. 15 | """ 16 | 17 | def __init__(self, p: float, mode: str) -> None: 18 | super().__init__(p=p, mode=mode) 19 | -------------------------------------------------------------------------------- /cvnets/layers/upsample.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | 12 | class UpSample(nn.Upsample): 13 | """ 14 | This layer upsamples a given input tensor. 15 | 16 | Args: 17 | size (Optional[Union[int, Tuple[int, ...]]]): Output spatial size. Default: None 18 | scale_factor (Optional[float]): Scale each spatial dimension of the input by this factor. Default: None 19 | mode (Optional[str]): Upsampling algorithm (``'nearest'``, ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``). Default: ``'nearest'`` 20 | align_corners (Optional[bool]): if ``True``, the corner pixels of the input and output tensors are aligned, thus preserving the values at 21 | those pixels. This only has effect when :attr:`mode` is ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``. 22 | Default: ``None`` 23 | 24 | Shape: 25 | - Input: :math:`(N, C, W_{in})` or :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` 26 | - Output: :math:`(N, C, W_{out})` or :math:`(N, C, H_{out}, W_{out})` or :math:`(N, C, D_{out}, H_{out}, W_{out})` 27 | """ 28 | 29 | def __init__( 30 | self, 31 | size: Optional[Union[int, Tuple[int, ...]]] = None, 32 | scale_factor: Optional[float] = None, 33 | mode: Optional[str] = "nearest", 34 | align_corners: Optional[bool] = None, 35 | *args, 36 | **kwargs 37 | ) -> None: 38 | super().__init__( 39 | size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners 40 | ) 41 | -------------------------------------------------------------------------------- /cvnets/matcher_det/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
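The `UpSample` wrapper above is a thin alias of `nn.Upsample`; a quick illustrative check (assuming the repository root is on `PYTHONPATH`):

```
import torch

from cvnets.layers.upsample import UpSample

up = UpSample(scale_factor=2.0, mode="bilinear", align_corners=False)
x = torch.randn(1, 3, 16, 16)
y = up(x)
assert y.shape == (1, 3, 32, 32)  # each spatial dimension scaled by 2
```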
4 | # 5 | 6 | import argparse 7 | 8 | from cvnets.matcher_det.base_matcher import BaseMatcher 9 | from utils import logger 10 | from utils.registry import Registry 11 | 12 | # register BOX Matcher 13 | MATCHER_REGISTRY = Registry( 14 | "matcher", 15 | base_class=BaseMatcher, 16 | lazy_load_dirs=["cvnets/matcher_det"], 17 | internal_dirs=["internal", "internal/projects/*"], 18 | ) 19 | 20 | 21 | def arguments_box_matcher(parser: argparse.ArgumentParser): 22 | group = parser.add_argument_group("Matcher", "Matcher") 23 | group.add_argument( 24 | "--matcher.name", 25 | type=str, 26 | help="Name of the matcher. Matcher matches anchors with GT box coordinates", 27 | ) 28 | 29 | # add matcher-specific arguments 30 | parser = MATCHER_REGISTRY.all_arguments(parser) 31 | return parser 32 | 33 | 34 | def build_matcher(opts, *args, **kwargs): 35 | matcher_name = getattr(opts, "matcher.name", None) 36 | # We registered the base class using a special `name` (i.e., `__base__`) 37 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 38 | # be used. Therefore, we raise an error for such cases 39 | if matcher_name == "__base__": 40 | logger.error("__base__ can't be used as a matcher name. Please check.") 41 | 42 | matcher = MATCHER_REGISTRY[matcher_name](opts, *args, **kwargs) 43 | return matcher 44 | -------------------------------------------------------------------------------- /cvnets/matcher_det/base_matcher.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | 9 | class BaseMatcher(object): 10 | """ 11 | Base class for matching anchor boxes and labels for the task of object detection 12 | """ 13 | 14 | def __init__(self, opts, *args, **kwargs) -> None: 15 | super(BaseMatcher, self).__init__() 16 | self.opts = opts 17 | 18 | @classmethod 19 | def add_arguments(cls, parser: argparse.ArgumentParser): 20 | """Add class-specific arguments""" 21 | return parser 22 | 23 | def __call__(self, *args, **kwargs): 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /cvnets/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/cvnets/misc/__init__.py -------------------------------------------------------------------------------- /cvnets/misc/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/cvnets/misc/third_party/__init__.py -------------------------------------------------------------------------------- /cvnets/models/audio_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /cvnets/models/audio_classification/audio_byteformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
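New matchers plug into `MATCHER_REGISTRY` and are selected at runtime via `--matcher.name`. A skeleton of that pattern (the class and its registered `name` are hypothetical; the registration call mirrors the registries shown elsewhere in this repository):

```
import argparse

from cvnets.matcher_det import MATCHER_REGISTRY
from cvnets.matcher_det.base_matcher import BaseMatcher


@MATCHER_REGISTRY.register(name="my_matcher")  # hypothetical name
class MyMatcher(BaseMatcher):
    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)

    def __call__(self, gt_boxes, *args, **kwargs):
        # assign each anchor to the ground-truth box it overlaps most, etc.
        raise NotImplementedError
```

Passing `--matcher.name my_matcher` on the command line would then route `build_matcher` to this class.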
4 | # 5 | 6 | import argparse 7 | from typing import Dict, Union 8 | 9 | from torch import Tensor 10 | 11 | from cvnets.models import MODEL_REGISTRY 12 | from cvnets.models.audio_classification.base_audio_classification import ( 13 | BaseAudioClassification, 14 | ) 15 | from cvnets.models.classification.byteformer import ByteFormer 16 | 17 | 18 | @MODEL_REGISTRY.register(name="byteformer", type="audio_classification") 19 | class AudioByteFormer(ByteFormer, BaseAudioClassification): 20 | """Identical to byteformer.ByteFormer, but registered as an audio classification 21 | model.""" 22 | 23 | def forward(self, x: Dict[str, Tensor], *args, **kwargs) -> Tensor: 24 | """ 25 | Perform a forward pass on input bytes. The input is a dictionary 26 | containing the input tensor. The tensor is stored as an integer tensor 27 | of shape [batch_size, sequence_length]. Integer tensors are used because 28 | the tensor usually contains mask tokens. 29 | 30 | Args: 31 | x: A dictionary containing {"audio": audio_bytes}. 32 | 33 | Returns: 34 | The output logits. 35 | """ 36 | return super().forward(x["audio"], *args, **kwargs) 37 | 38 | def dummy_input_and_label(self, batch_size: int) -> Dict: 39 | """ 40 | Get a dummy input and label that could be passed to the model. 41 | 42 | Args: 43 | batch_size: The batch size to use for the generated inputs. 44 | 45 | Returns: 46 | A dict with 47 | { 48 | "samples": {"audio": tensor of shape [batch_size, sequence_length]}, 49 | "targets": tensor of shape [batch_size], 50 | } 51 | """ 52 | input_and_label = super().dummy_input_and_label(batch_size) 53 | 54 | ret = { 55 | "samples": {"audio": input_and_label["samples"]}, 56 | "targets": input_and_label["targets"], 57 | } 58 | return ret 59 | -------------------------------------------------------------------------------- /cvnets/models/audio_classification/base_audio_classification.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel 9 | 10 | 11 | @MODEL_REGISTRY.register(name="__base__", type="audio_classification") 12 | class BaseAudioClassification(BaseAnyNNModel): 13 | """Base class for audio classification. 14 | 15 | Args: 16 | opts: Command-line arguments 17 | """ 18 | 19 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 20 | super().__init__(opts, *args, **kwargs) 21 | 22 | @classmethod 23 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 24 | """Add model specific arguments""" 25 | if cls != BaseAudioClassification: 26 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 27 | return parser 28 | group = parser.add_argument_group(title=cls.__name__) 29 | group.add_argument( 30 | "--model.audio-classification.name", 31 | type=str, 32 | default=None, 33 | help="Name of the audio classification model. Defaults to None.", 34 | ) 35 | group.add_argument( 36 | "--model.audio-classification.pretrained", 37 | type=str, 38 | default=None, 39 | help="Path of the pretrained backbone. Defaults to None.", 40 | ) 41 | return parser 42 | -------------------------------------------------------------------------------- /cvnets/models/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 
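Since `add_arguments` above guards against re-registration in subclasses, the base class owns the shared flags. A sketch of how those flags surface on the command line (illustrative only; assumes the repository root is on `PYTHONPATH`):

```
import argparse

from cvnets.models.audio_classification.base_audio_classification import (
    BaseAudioClassification,
)

parser = argparse.ArgumentParser()
parser = BaseAudioClassification.add_arguments(parser)
opts = parser.parse_args(["--model.audio-classification.name", "byteformer"])
# argparse turns the dashes after the leading "--" into underscores in the attribute name
assert getattr(opts, "model.audio_classification.name") == "byteformer"
```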
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /cvnets/models/classification/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/cvnets/models/classification/config/__init__.py -------------------------------------------------------------------------------- /cvnets/models/classification/config/mobilenetv1.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import math 7 | from typing import Dict 8 | 9 | from utils.math_utils import make_divisible 10 | 11 | 12 | def get_configuration(opts) -> Dict: 13 | width_mult = getattr(opts, "model.classification.mobilenetv1.width_multiplier", 1.0) 14 | 15 | def scale_channels(in_channels): 16 | return make_divisible(int(math.ceil(in_channels * width_mult)), 16) 17 | 18 | config = { 19 | "conv1_out": scale_channels(32), 20 | "layer1": {"out_channels": scale_channels(64), "stride": 1, "repeat": 1}, 21 | "layer2": { 22 | "out_channels": scale_channels(128), 23 | "stride": 2, 24 | "repeat": 1, 25 | }, 26 | "layer3": { 27 | "out_channels": scale_channels(256), 28 | "stride": 2, 29 | "repeat": 1, 30 | }, 31 | "layer4": { 32 | "out_channels": scale_channels(512), 33 | "stride": 2, 34 | "repeat": 5, 35 | }, 36 | "layer5": { 37 | "out_channels": scale_channels(1024), 38 | "stride": 2, 39 | "repeat": 1, 40 | }, 41 | } 42 | return config 43 | -------------------------------------------------------------------------------- /cvnets/models/classification/config/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Dict 7 | 8 | 9 | def get_configuration(opts) -> Dict: 10 | 11 | mobilenetv2_config = { 12 | "layer1": { 13 | "expansion_ratio": 1, 14 | "out_channels": 16, 15 | "num_blocks": 1, 16 | "stride": 1, 17 | }, 18 | "layer2": { 19 | "expansion_ratio": 6, 20 | "out_channels": 24, 21 | "num_blocks": 2, 22 | "stride": 2, 23 | }, 24 | "layer3": { 25 | "expansion_ratio": 6, 26 | "out_channels": 32, 27 | "num_blocks": 3, 28 | "stride": 2, 29 | }, 30 | "layer4": { 31 | "expansion_ratio": 6, 32 | "out_channels": 64, 33 | "num_blocks": 4, 34 | "stride": 2, 35 | }, 36 | "layer4_a": { 37 | "expansion_ratio": 6, 38 | "out_channels": 96, 39 | "num_blocks": 3, 40 | "stride": 1, 41 | }, 42 | "layer5": { 43 | "expansion_ratio": 6, 44 | "out_channels": 160, 45 | "num_blocks": 3, 46 | "stride": 2, 47 | }, 48 | "layer5_a": { 49 | "expansion_ratio": 6, 50 | "out_channels": 320, 51 | "num_blocks": 1, 52 | "stride": 1, 53 | }, 54 | } 55 | return mobilenetv2_config 56 | -------------------------------------------------------------------------------- /cvnets/models/classification/config/mobileone.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
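A sketch of how the MobileNetV1 width multiplier shapes the configuration returned above: every stage's channel count is scaled and then rounded to a multiple of 16 by `make_divisible` (illustrative; assumes the repository root is on `PYTHONPATH`):

```
import argparse

from cvnets.models.classification.config.mobilenetv1 import get_configuration

opts = argparse.Namespace()
setattr(opts, "model.classification.mobilenetv1.width_multiplier", 0.5)
cfg = get_configuration(opts)
# 32 * 0.5 -> 16 and 1024 * 0.5 -> 512, both already multiples of 16
assert cfg["conv1_out"] == 16
assert cfg["layer5"]["out_channels"] == 512
```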
4 | # 5 | import argparse 6 | from typing import Dict 7 | 8 | from utils import logger 9 | 10 | 11 | def get_configuration(opts: argparse.Namespace) -> Dict: 12 | """Get configuration of MobileOne models.""" 13 | variant = getattr(opts, "model.classification.mobileone.variant") 14 | config = dict() 15 | 16 | if variant == "s0": 17 | config = { 18 | "num_blocks_per_stage": [2, 8, 10, 1], 19 | "width_multipliers": (0.75, 1.0, 1.0, 2.0), 20 | "num_conv_branches": 4, 21 | "use_se": False, 22 | } 23 | elif variant == "s1": 24 | config = { 25 | "num_blocks_per_stage": [2, 8, 10, 1], 26 | "width_multipliers": (1.5, 1.5, 2.0, 2.5), 27 | "num_conv_branches": 1, 28 | "use_se": False, 29 | } 30 | elif variant == "s2": 31 | config = { 32 | "num_blocks_per_stage": [2, 8, 10, 1], 33 | "width_multipliers": (1.5, 2.0, 2.5, 4.0), 34 | "num_conv_branches": 1, 35 | "use_se": False, 36 | } 37 | elif variant == "s3": 38 | config = { 39 | "num_blocks_per_stage": [2, 8, 10, 1], 40 | "width_multipliers": (2.0, 2.5, 3.0, 4.0), 41 | "num_conv_branches": 1, 42 | "use_se": False, 43 | } 44 | elif variant == "s4": 45 | config = { 46 | "num_blocks_per_stage": [2, 8, 10, 1], 47 | "width_multipliers": (3.0, 3.5, 3.5, 4.0), 48 | "num_conv_branches": 1, 49 | "use_se": True, 50 | } 51 | else: 52 | logger.error( 53 | "MobileOne supported variants: `s0`, `s1`, `s2`, `s3` and `s4`. Please specify variant using " 54 | "--model.classification.mobileone.variant flag. Got: {}".format(variant) 55 | ) 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /cvnets/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from collections import namedtuple 7 | 8 | DetectionPredTuple = namedtuple( 9 | typename="DetectionPredTuple", 10 | field_names=("labels", "scores", "boxes", "masks"), 11 | defaults=(None, None, None, None), 12 | ) 13 | -------------------------------------------------------------------------------- /cvnets/models/detection/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/cvnets/models/detection/utils/__init__.py -------------------------------------------------------------------------------- /cvnets/models/multi_modal_img_text/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /cvnets/models/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /cvnets/models/segmentation/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
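A quick illustrative lookup against the MobileOne variant table above (assumes the repository root is on `PYTHONPATH`):

```
import argparse

from cvnets.models.classification.config.mobileone import get_configuration

opts = argparse.Namespace()
setattr(opts, "model.classification.mobileone.variant", "s1")
cfg = get_configuration(opts)
assert cfg["width_multipliers"] == (1.5, 1.5, 2.0, 2.5)
assert cfg["num_conv_branches"] == 1  # only the s0 variant uses 4 branches
```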
4 | # 5 | -------------------------------------------------------------------------------- /cvnets/models/segmentation/heads/simple_seg_head.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | from typing import Dict, Optional 8 | 9 | from torch import Tensor 10 | 11 | from cvnets.layers import ConvLayer2d 12 | from cvnets.models import MODEL_REGISTRY 13 | from cvnets.models.segmentation.heads.base_seg_head import BaseSegHead 14 | 15 | 16 | @MODEL_REGISTRY.register(name="simple_seg_head", type="segmentation_head") 17 | class SimpleSegHead(BaseSegHead): 18 | """ 19 | This class defines a simple segmentation head with only a classification layer. This is useful for performing 20 | linear probing on segmentation tasks. 21 | Args: 22 | opts: command-line arguments 23 | enc_conf (Dict): Encoder input-output configuration at each spatial level 24 | use_l5_exp (Optional[bool]): Use features from expansion layer in Level5 in the encoder 25 | """ 26 | 27 | def __init__( 28 | self, opts, enc_conf: Dict, use_l5_exp: Optional[bool] = False, *args, **kwargs 29 | ) -> None: 30 | 31 | super().__init__(opts=opts, enc_conf=enc_conf, use_l5_exp=use_l5_exp) 32 | 33 | in_channels = ( 34 | self.enc_l5_channels if not self.use_l5_exp else self.enc_l5_exp_channels 35 | ) 36 | 37 | self.classifier = ConvLayer2d( 38 | opts=opts, 39 | in_channels=in_channels, 40 | out_channels=self.n_seg_classes, 41 | kernel_size=1, 42 | stride=1, 43 | use_norm=False, 44 | use_act=False, 45 | bias=True, 46 | ) 47 | 48 | self.reset_head_parameters(opts=opts) 49 | 50 | @classmethod 51 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 52 | return parser 53 | 54 | def forward_seg_head(self, enc_out: Dict) -> Tensor: 55 | x = enc_out["out_l5_exp"] if self.use_l5_exp else enc_out["out_l5"] 56 | # classify 57 | x = self.classifier(x) 58 | return x 59 | -------------------------------------------------------------------------------- /cvnets/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
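Stripped of the registry and base-class plumbing, the head above is just a 1x1 convolution that turns the deepest encoder feature map into per-pixel class logits. A framework-free sketch of that idea (all shapes below are hypothetical):

```
import torch
from torch import nn

n_seg_classes, enc_l5_channels = 21, 1280  # hypothetical values
classifier = nn.Conv2d(enc_l5_channels, n_seg_classes, kernel_size=1, bias=True)
enc_out = {"out_l5": torch.randn(1, enc_l5_channels, 16, 16)}
logits = classifier(enc_out["out_l5"])
assert logits.shape == (1, n_seg_classes, 16, 16)  # per-pixel class scores
```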
4 | # 5 | # isort: skip_file 6 | from cvnets.modules.base_module import BaseModule 7 | from cvnets.modules.squeeze_excitation import SqueezeExcitation 8 | from cvnets.modules.mobilenetv2 import InvertedResidual, InvertedResidualSE 9 | from cvnets.modules.resnet_modules import ( 10 | BasicResNetBlock, 11 | BottleneckResNetBlock, 12 | ) 13 | from cvnets.modules.aspp_block import ASPP 14 | from cvnets.modules.transformer import TransformerEncoder 15 | from cvnets.modules.windowed_transformer import WindowedTransformerEncoder 16 | from cvnets.modules.pspnet_module import PSP 17 | from cvnets.modules.mobilevit_block import MobileViTBlock, MobileViTBlockv2 18 | from cvnets.modules.feature_pyramid import FeaturePyramidNetwork 19 | from cvnets.modules.ssd_heads import SSDHead, SSDInstanceHead 20 | from cvnets.modules.efficientnet import EfficientNetBlock 21 | from cvnets.modules.mobileone_block import MobileOneBlock, RepLKBlock 22 | from cvnets.modules.swin_transformer_block import ( 23 | SwinTransformerBlock, 24 | PatchMerging, 25 | Permute, 26 | ) 27 | from cvnets.modules.regnet_modules import XRegNetBlock, AnyRegNetStage 28 | 29 | 30 | __all__ = [ 31 | "InvertedResidual", 32 | "InvertedResidualSE", 33 | "BasicResNetBlock", 34 | "BottleneckResNetBlock", 35 | "ASPP", 36 | "TransformerEncoder", 37 | "WindowedTransformerEncoder", 38 | "SqueezeExcitation", 39 | "PSP", 40 | "MobileViTBlock", 41 | "MobileViTBlockv2", 42 | "MobileOneBlock", 43 | "RepLKBlock", 44 | "FeaturePyramidNetwork", 45 | "SSDHead", 46 | "SSDInstanceHead", 47 | "EfficientNetBlock", 48 | "SwinTransformerBlock", 49 | "PatchMerging", 50 | "Permute", 51 | "XRegNetBlock", 52 | "AnyRegNetStage", 53 | ] 54 | -------------------------------------------------------------------------------- /cvnets/modules/base_module.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Any 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | 12 | class BaseModule(nn.Module): 13 | """Base class for all modules""" 14 | 15 | def __init__(self, *args, **kwargs): 16 | super(BaseModule, self).__init__() 17 | 18 | def forward(self, x: Any, *args, **kwargs) -> Any: 19 | raise NotImplementedError 20 | 21 | def __repr__(self): 22 | return "{}".format(self.__class__.__name__) 23 | -------------------------------------------------------------------------------- /cvnets/modules/efficientnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from cvnets.layers import StochasticDepth 9 | from cvnets.modules import InvertedResidualSE 10 | 11 | 12 | class EfficientNetBlock(InvertedResidualSE): 13 | """ 14 | This class implements a variant of the inverted residual block with squeeze-excitation unit, 15 | as described in the `MobileNetv3 <https://arxiv.org/abs/1905.02244>`_ paper. This variant 16 | includes stochastic depth, as used in the `EfficientNet <https://arxiv.org/abs/1905.11946>`_ paper. 17 | 18 | Args: 19 | stochastic_depth_prob (float): probability of dropping the residual branch during training (stochastic depth). 20 | For other arguments, refer to the parent class.
21 | 22 | Shape: 23 | - Input: :math:`(N, C_{in}, H_{in}, W_{in})` 24 | - Output: :math:`(N, C_{out}, H_{out}, W_{out})` 25 | """ 26 | 27 | def __init__(self, stochastic_depth_prob: float, *args, **kwargs) -> None: 28 | super().__init__(*args, **kwargs) 29 | self.stochastic_depth = StochasticDepth(p=stochastic_depth_prob, mode="row") 30 | 31 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: 32 | y = self.block(x) 33 | if self.use_res_connect: 34 | # Pass the output through the stochastic layer module, potentially zeroing it. 35 | y = self.stochastic_depth(y) 36 | # residual connection 37 | y = y + x 38 | return y 39 | 40 | def __repr__(self) -> str: 41 | return ( 42 | super().__repr__()[:-1] 43 | + f", stochastic_depth_prob={self.stochastic_depth.p})" 44 | ) 45 | -------------------------------------------------------------------------------- /cvnets/neural_augmentor/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | # 4 | 5 | import argparse 6 | 7 | from cvnets.neural_augmentor.neural_aug import ( 8 | BaseNeuralAugmentor, 9 | build_neural_augmentor, 10 | ) 11 | 12 | 13 | def arguments_neural_augmentor( 14 | parser: argparse.ArgumentParser, 15 | ) -> argparse.ArgumentParser: 16 | return BaseNeuralAugmentor.add_arguments(parser=parser) 17 | -------------------------------------------------------------------------------- /cvnets/neural_augmentor/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/cvnets/neural_augmentor/utils/__init__.py -------------------------------------------------------------------------------- /cvnets/text_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from cvnets.text_encoders.base_text_encoder import BaseTextEncoder 9 | from utils import logger 10 | from utils.registry import Registry 11 | 12 | TEXT_ENCODER_REGISTRY = Registry( 13 | "text_encoder", 14 | base_class=BaseTextEncoder, 15 | lazy_load_dirs=["cvnets/text_encoders"], 16 | internal_dirs=["internal", "internal/projects/*"], 17 | ) 18 | 19 | 20 | def arguments_text_encoder(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 21 | """Register arguments of all text encoders.""" 22 | # add arguments for text_encoder 23 | parser = BaseTextEncoder.add_arguments(parser) 24 | 25 | # add class specific arguments 26 | parser = TEXT_ENCODER_REGISTRY.all_arguments(parser) 27 | return parser 28 | 29 | 30 | def build_text_encoder(opts, projection_dim: int, *args, **kwargs) -> BaseTextEncoder: 31 | """Helper function to build the text encoder from command-line arguments. 32 | 33 | Args: 34 | opts: Command-line arguments 35 | projection_dim: The dimensionality of the projection head after text encoder. 36 | 37 | Returns: 38 | Text encoder module. 39 | """ 40 | text_encoder_name = getattr(opts, "model.text.name") 41 | 42 | # We registered the base class using a special `name` (i.e., `__base__`) 43 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 44 | # be used. 
Therefore, we raise an error for such cases 45 | if text_encoder_name == "__base__": 46 | logger.error("__base__ can't be used as a text encoder name. Please check.") 47 | 48 | text_encoder = TEXT_ENCODER_REGISTRY[text_encoder_name]( 49 | opts, projection_dim, *args, **kwargs 50 | ) 51 | return text_encoder 52 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from data.data_loaders import create_test_loader, create_train_val_loader 7 | -------------------------------------------------------------------------------- /data/datasets/audio_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /data/datasets/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /data/datasets/classification/imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from data.datasets import DATASET_REGISTRY 9 | from data.datasets.classification.base_image_classification_dataset import ( 10 | BaseImageClassificationDataset, 11 | ) 12 | 13 | 14 | @DATASET_REGISTRY.register(name="imagenet", type="classification") 15 | class ImageNetDataset(BaseImageClassificationDataset): 16 | """ 17 | ImageNet dataset that follows the structure of ImageClassificationDataset. 18 | 19 | "ImageNet: A large-scale hierarchical image database" 20 | Jia Deng; Wei Dong; Richard Socher; Li-Jia Li; Kai Li; Li Fei-Fei 21 | 2009 IEEE Conference on Computer Vision and Pattern Recognition 22 | """ 23 | 24 | def __init__( 25 | self, 26 | opts: argparse.Namespace, 27 | *args, 28 | **kwargs, 29 | ) -> None: 30 | BaseImageClassificationDataset.__init__( 31 | self, 32 | opts=opts, 33 | *args, 34 | **kwargs, 35 | ) 36 | -------------------------------------------------------------------------------- /data/datasets/classification/imagenet_a.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
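Datasets follow the same registry pattern: subclass a base dataset and register it under a `name`/`type` pair, after which `--dataset.name` can select it. A skeleton of that pattern (the class and its registered name are hypothetical):

```
import argparse

from data.datasets import DATASET_REGISTRY
from data.datasets.classification.base_image_classification_dataset import (
    BaseImageClassificationDataset,
)


@DATASET_REGISTRY.register(name="my_imagenet_variant", type="classification")
class MyImageNetVariant(BaseImageClassificationDataset):
    """Hypothetical dataset used only to illustrate the registration pattern."""

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        BaseImageClassificationDataset.__init__(self, opts=opts, *args, **kwargs)
```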
4 | # 5 | """ImageNetA dataset, a distribution shift of ImageNet.""" 6 | import argparse 7 | 8 | from data.datasets import DATASET_REGISTRY 9 | from data.datasets.classification.base_imagenet_shift_dataset import ( 10 | BaseImageNetShiftDataset, 11 | ) 12 | from data.datasets.classification.imagenet_synsets import ( 13 | IMAGENET_A_SYNSETS, 14 | IMAGENET_SYNSETS, 15 | ) 16 | 17 | IMAGENET_A_CLASS_SUBLIST = [ 18 | IMAGENET_SYNSETS.index(IMAGENET_A_SYNSETS[synset]) 19 | for synset in range(len(IMAGENET_A_SYNSETS)) 20 | ] 21 | 22 | 23 | @DATASET_REGISTRY.register(name="imagenet_a", type="classification") 24 | class ImageNetADataset(BaseImageNetShiftDataset): 25 | """ImageNetA dataset, a distribution shift of ImageNet. 26 | 27 | ImageNet-A contains real-world, unmodified natural images that cause model accuracy 28 | to substantially degrade. 29 | 30 | @article{hendrycks2021nae, 31 | title={Natural Adversarial Examples}, 32 | author={Dan Hendrycks and Kevin Zhao and Steven Basart and Jacob Steinhardt and Dawn 33 | Song}, 34 | journal={CVPR}, 35 | year={2021} 36 | } 37 | """ 38 | 39 | def __init__( 40 | self, 41 | opts: argparse.Namespace, 42 | *args, 43 | **kwargs, 44 | ) -> None: 45 | """Initialize ImageNetA.""" 46 | BaseImageNetShiftDataset.__init__(self, opts=opts, *args, **kwargs) 47 | 48 | @staticmethod 49 | def class_id_to_imagenet_class_id(class_id: int) -> int: 50 | """Return the mapped class index using precomputed mapping.""" 51 | return IMAGENET_A_CLASS_SUBLIST[class_id] 52 | -------------------------------------------------------------------------------- /data/datasets/classification/imagenet_r.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | """ImageNetR dataset, a distribution shift of ImageNet.""" 6 | import argparse 7 | 8 | from data.datasets import DATASET_REGISTRY 9 | from data.datasets.classification.base_imagenet_shift_dataset import ( 10 | BaseImageNetShiftDataset, 11 | ) 12 | from data.datasets.classification.imagenet_synsets import ( 13 | IMAGENET_R_SYNSETS, 14 | IMAGENET_SYNSETS, 15 | ) 16 | 17 | IMAGENET_R_CLASS_SUBLIST = [ 18 | IMAGENET_SYNSETS.index(IMAGENET_R_SYNSETS[synset]) 19 | for synset in range(len(IMAGENET_R_SYNSETS)) 20 | ] 21 | 22 | 23 | @DATASET_REGISTRY.register(name="imagenet_r", type="classification") 24 | class ImageNetRDataset(BaseImageNetShiftDataset): 25 | """ImageNetR dataset, a distribution shift of ImageNet. 26 | 27 | ImageNet-R(endition) contains art, cartoons, deviantart, graffiti, embroidery, 28 | graphics, origami, paintings, patterns, plastic objects, plush objects, sculptures, 29 | sketches, tattoos, toys, and video game renditions of ImageNet classes. 
30 | 31 | @article{hendrycks2021many, 32 | title={The Many Faces of Robustness: A Critical Analysis of Out-of-Distribution 33 | Generalization}, 34 | author={Dan Hendrycks and Steven Basart and Norman Mu and Saurav Kadavath and Frank 35 | Wang and Evan Dorundo and Rahul Desai and Tyler Zhu and Samyak Parajuli and Mike Guo 36 | and Dawn Song and Jacob Steinhardt and Justin Gilmer}, 37 | journal={ICCV}, 38 | year={2021} 39 | } 40 | 41 | """ 42 | 43 | def __init__( 44 | self, 45 | opts: argparse.Namespace, 46 | *args, 47 | **kwargs, 48 | ) -> None: 49 | """Initialize ImageNetR.""" 50 | BaseImageNetShiftDataset.__init__(self, opts=opts, *args, **kwargs) 51 | 52 | @staticmethod 53 | def class_id_to_imagenet_class_id(class_id: int) -> int: 54 | """Return the mapped class index using precomputed mapping.""" 55 | return IMAGENET_R_CLASS_SUBLIST[class_id] 56 | -------------------------------------------------------------------------------- /data/datasets/classification/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | """ImageNetSketch dataset, a distribution shift of ImageNet.""" 6 | import argparse 7 | 8 | from data.datasets import DATASET_REGISTRY 9 | from data.datasets.classification.base_imagenet_shift_dataset import ( 10 | BaseImageNetShiftDataset, 11 | ) 12 | 13 | 14 | @DATASET_REGISTRY.register(name="imagenet_sketch", type="classification") 15 | class ImageNetSketchDataset(BaseImageNetShiftDataset): 16 | """ImageNetSketch dataset, a distribution shift of ImageNet. 17 | 18 | Data set is created from Google Image queries "sketch of __", where __ is the 19 | standard class name. Search is only within the "black and white" color scheme. 20 | 21 | @inproceedings{wang2019learning, 22 | title={Learning Robust Global Representations by Penalizing Local Predictive 23 | Power}, 24 | author={Wang, Haohan and Ge, Songwei and Lipton, Zachary and Xing, Eric P}, 25 | booktitle={Advances in Neural Information Processing Systems}, 26 | pages={10506--10518}, 27 | year={2019} 28 | } 29 | """ 30 | 31 | def __init__( 32 | self, 33 | opts: argparse.Namespace, 34 | *args, 35 | **kwargs, 36 | ) -> None: 37 | """Initialize ImageNetSketchDataset.""" 38 | BaseImageNetShiftDataset.__init__(self, opts=opts, *args, **kwargs) 39 | 40 | @staticmethod 41 | def class_id_to_imagenet_class_id(class_id: int) -> int: 42 | """Return `class_id` as the ImageNet Sketch classes are the same as ImageNet.""" 43 | return class_id 44 | -------------------------------------------------------------------------------- /data/datasets/classification/places365.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from data.datasets import DATASET_REGISTRY 9 | from data.datasets.classification.base_image_classification_dataset import ( 10 | BaseImageClassificationDataset, 11 | ) 12 | 13 | 14 | @DATASET_REGISTRY.register(name="places365", type="classification") 15 | class Places365Dataset(BaseImageClassificationDataset): 16 | """ 17 | Places365 dataset that follows the structure of ImageClassificationDataset. 18 | 19 | "Places: A 10 million Image Database for Scene Recognition" 20 | B. Zhou, A. Lapedriza, A. Khosla, A. Oliva, and A. 
Torralba 21 | IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017 22 | """ 23 | 24 | def __init__( 25 | self, 26 | opts: argparse.Namespace, 27 | *args, 28 | **kwargs, 29 | ) -> None: 30 | BaseImageClassificationDataset.__init__( 31 | self, 32 | opts=opts, 33 | *args, 34 | **kwargs, 35 | ) 36 | -------------------------------------------------------------------------------- /data/datasets/detection/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /data/datasets/multi_modal_img_text/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from data.datasets.multi_modal_img_text.base_multi_modal_img_text import ( 9 | BaseMultiModalImgText, 10 | ) 11 | from data.datasets.multi_modal_img_text.zero_shot import arguments_zero_shot_dataset 12 | 13 | 14 | def arguments_multi_modal_img_text( 15 | parser: argparse.ArgumentParser, 16 | ) -> argparse.ArgumentParser: 17 | 18 | parser = arguments_zero_shot_dataset(parser) 19 | parser = BaseMultiModalImgText.add_arguments(parser) 20 | return parser 21 | -------------------------------------------------------------------------------- /data/datasets/multi_modal_img_text/zero_shot/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import os 8 | 9 | from data.datasets.multi_modal_img_text.zero_shot.base_zero_shot import ( 10 | BaseZeroShotDataset, 11 | ) 12 | from utils.registry import Registry 13 | 14 | ZERO_SHOT_DATASET_REGISTRY = Registry( 15 | registry_name="zero_shot_datasets", 16 | base_class=BaseZeroShotDataset, 17 | lazy_load_dirs=["data/datasets/multi_modal_img_text/zero_shot"], 18 | internal_dirs=["internal", "internal/projects/*"], 19 | ) 20 | 21 | 22 | def arguments_zero_shot_dataset( 23 | parser: argparse.ArgumentParser, 24 | ) -> argparse.ArgumentParser: 25 | """Helper function to get zero-shot dataset arguments""" 26 | parser = BaseZeroShotDataset.add_arguments(parser=parser) 27 | parser = ZERO_SHOT_DATASET_REGISTRY.all_arguments(parser) 28 | return parser 29 | 30 | 31 | def build_zero_shot_dataset(opts, *args, **kwargs) -> BaseZeroShotDataset: 32 | """Helper function to build the zero shot datasets""" 33 | zero_shot_dataset_name = getattr( 34 | opts, "dataset.multi_modal_img_text.zero_shot.name" 35 | ) 36 | return ZERO_SHOT_DATASET_REGISTRY[zero_shot_dataset_name](opts, *args, **kwargs) 37 | -------------------------------------------------------------------------------- /data/datasets/multi_modal_img_text/zero_shot/imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
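The helper above resolves the zero-shot dataset class by name at runtime. A sketch of the lookup (illustrative; full construction would additionally need dataset options such as the data root on `opts`):

```
import argparse

from data.datasets.multi_modal_img_text.zero_shot import (
    ZERO_SHOT_DATASET_REGISTRY,
    build_zero_shot_dataset,
)

opts = argparse.Namespace()
setattr(opts, "dataset.multi_modal_img_text.zero_shot.name", "imagenet")
dataset_cls = ZERO_SHOT_DATASET_REGISTRY["imagenet"]  # the class, not an instance
# dataset = build_zero_shot_dataset(opts)  # needs the remaining dataset options
```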
4 | # 5 | 6 | import argparse 7 | from typing import List 8 | 9 | from torchvision.datasets import ImageFolder 10 | 11 | from data.datasets.multi_modal_img_text.zero_shot import ( 12 | ZERO_SHOT_DATASET_REGISTRY, 13 | BaseZeroShotDataset, 14 | ) 15 | from data.datasets.multi_modal_img_text.zero_shot.imagenet_class_names import ( 16 | IMAGENET_CLASS_NAMES, 17 | ) 18 | from data.datasets.multi_modal_img_text.zero_shot.templates import ( 19 | generate_text_prompts_clip, 20 | ) 21 | 22 | 23 | @ZERO_SHOT_DATASET_REGISTRY.register(name="imagenet") 24 | class ImageNetDatasetZeroShot(BaseZeroShotDataset, ImageFolder): 25 | """ImageNet dataset for zero-shot evaluation of image-text models. 26 | 27 | Args: 28 | opts: Command-line arguments 29 | """ 30 | 31 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 32 | BaseZeroShotDataset.__init__(self, opts=opts, *args, **kwargs) 33 | root = self.root 34 | ImageFolder.__init__( 35 | self, root=root, transform=None, target_transform=None, is_valid_file=None 36 | ) 37 | 38 | # TODO: Refactor BaseZeroShotDataset to inherit from 39 | # BaseImageClassificationDataset then inherit from ImageNetDataset instead of 40 | # ImageFolder. Rename the base class to BaseZeroShotClassificationDataset. 41 | assert len(list(self.class_to_idx.keys())) == len(self.class_names()), ( 42 | "Number of classes from ImageFolder does not match the number of ImageNet" 43 | " classes." 44 | ) 45 | 46 | @classmethod 47 | def class_names(cls) -> List[str]: 48 | """Return the name of the classes present in the dataset.""" 49 | return IMAGENET_CLASS_NAMES 50 | 51 | @staticmethod 52 | def generate_text_prompts(class_name: str) -> List[str]: 53 | """Return a list of prompts for the given class name.""" 54 | return generate_text_prompts_clip(class_name) 55 | 56 | def __len__(self) -> int: 57 | """Return the number of samples in the dataset.""" 58 | return super(ImageFolder, self).__len__() 59 | -------------------------------------------------------------------------------- /data/datasets/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/data/datasets/segmentation/__init__.py -------------------------------------------------------------------------------- /data/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /data/datasets/utils/text.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import re 7 | import urllib.parse 8 | 9 | import ftfy 10 | 11 | 12 | def caption_preprocessing(caption: str) -> str: 13 | """Removes unwanted tokens (e.g., HTML tokens, newlines, extra spaces) from 14 | the text.""" 15 | # captions may contain HTML tokens.
Remove them. 16 | html_re = re.compile("<.*?>") 17 | caption = urllib.parse.unquote(str(caption)) 18 | caption = caption.replace("+", " ") 19 | caption = re.sub(html_re, "", str(caption)) 20 | # remove newline characters 21 | caption = caption.strip("\n") 22 | # remove unwanted spaces 23 | caption = re.sub(" +", " ", caption) 24 | 25 | caption = ftfy.fix_text(caption) 26 | return caption.strip().lower() 27 | -------------------------------------------------------------------------------- /data/loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/data/loader/__init__.py -------------------------------------------------------------------------------- /data/loader/dataloader.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Callable, List, Optional 7 | 8 | from torch.utils.data import DataLoader 9 | 10 | from data.datasets.dataset_base import BaseDataset 11 | from data.sampler import Sampler 12 | 13 | 14 | class CVNetsDataLoader(DataLoader): 15 | """This class extends PyTorch's Dataloader""" 16 | 17 | def __init__( 18 | self, 19 | dataset: BaseDataset, 20 | batch_size: int, 21 | batch_sampler: Sampler, 22 | num_workers: Optional[int] = 1, 23 | pin_memory: Optional[bool] = False, 24 | persistent_workers: Optional[bool] = False, 25 | collate_fn: Optional[Callable] = None, 26 | prefetch_factor: Optional[int] = 2, 27 | *args, 28 | **kwargs 29 | ): 30 | super(CVNetsDataLoader, self).__init__( 31 | dataset=dataset, 32 | batch_size=batch_size, 33 | batch_sampler=batch_sampler, 34 | num_workers=num_workers, 35 | pin_memory=pin_memory, 36 | persistent_workers=persistent_workers, 37 | collate_fn=collate_fn, 38 | prefetch_factor=prefetch_factor, 39 | *args, 40 | **kwargs 41 | ) 42 | 43 | def update_indices(self, new_indices: List, *args, **kwargs): 44 | """Update indices in the dataset class""" 45 | if hasattr(self.batch_sampler, "img_indices") and hasattr( 46 | self.batch_sampler, "update_indices" 47 | ): 48 | self.batch_sampler.update_indices(new_indices) 49 | 50 | def samples_in_dataset(self): 51 | """Number of samples in the dataset""" 52 | return len(self.dataset)  # note: len(self) would give the number of batches 53 | 54 | def get_sample_indices(self) -> List: 55 | """Sample IDs""" 56 | return self.batch_sampler.img_indices 57 | -------------------------------------------------------------------------------- /data/text_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
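A worked example of the caption cleanup above: percent-encoded HTML and `+` separators come out as plain lower-case text (assumes `ftfy` is installed and the repository root is on `PYTHONPATH`):

```
from data.datasets.utils.text import caption_preprocessing

raw = "A+%3Cb%3EPhoto%3C%2Fb%3E++of+a+cat%0A"
print(caption_preprocessing(raw))  # -> "a photo of a cat"
```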
4 | # 5 | 6 | import argparse 7 | 8 | from data.text_tokenizer.base_tokenizer import BaseTokenizer 9 | from utils import logger 10 | from utils.registry import Registry 11 | 12 | TOKENIZER_REGISTRY = Registry( 13 | "tokenizer", 14 | base_class=BaseTokenizer, 15 | lazy_load_dirs=["data/text_tokenizer"], 16 | internal_dirs=["internal", "internal/projects/*"], 17 | ) 18 | 19 | 20 | def arguments_tokenizer(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 21 | # add arguments for text_tokenizer 22 | parser = BaseTokenizer.add_arguments(parser) 23 | 24 | # add class specific arguments 25 | parser = TOKENIZER_REGISTRY.all_arguments(parser) 26 | return parser 27 | 28 | 29 | def build_tokenizer(opts, *args, **kwargs) -> BaseTokenizer: 30 | """Helper function to build the text tokenizer from command-line arguments. 31 | 32 | Args: 33 | opts: Command-line arguments 34 | 35 | Returns: 36 | Text tokenizer module. 37 | """ 38 | tokenizer_name = getattr(opts, "text_tokenizer.name", None) 39 | 40 | # We registered the base class using a special `name` (i.e., `__base__`) 41 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 42 | # be used. Therefore, we raise an error for such cases 43 | if tokenizer_name == "__base__": 44 | logger.error("__base__ can't be used as a tokenizer name. Please check.") 45 | 46 | tokenizer = TOKENIZER_REGISTRY[tokenizer_name](opts, *args, **kwargs) 47 | return tokenizer 48 | -------------------------------------------------------------------------------- /data/text_tokenizer/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | from typing import Any 8 | 9 | from torch import nn 10 | 11 | 12 | class BaseTokenizer(nn.Module): 13 | def __init__(self, opts, *args, **kwargs): 14 | super().__init__() 15 | self.opts = opts 16 | 17 | @classmethod 18 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 19 | group = parser.add_argument_group(title=cls.__name__) 20 | 21 | group.add_argument( 22 | "--text-tokenizer.name", 23 | type=str, 24 | default=None, 25 | help="Name of the text tokenizer.", 26 | ) 27 | 28 | return parser 29 | 30 | def get_vocab_size(self): 31 | raise NotImplementedError 32 | 33 | def get_eot_token(self): 34 | raise NotImplementedError 35 | 36 | def get_sot_token(self): 37 | raise NotImplementedError 38 | 39 | def get_encodings(self): 40 | raise NotImplementedError 41 | 42 | def forward(self, input_sentence: Any, *args, **kwargs) -> Any: 43 | raise NotImplementedError 44 | -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
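A sketch of the tokenizer flag wiring above (illustrative; "clip" is assumed to be the name under which the CLIP tokenizer in `data/text_tokenizer/clip_tokenizer.py` registers itself):

```
import argparse

from data.text_tokenizer.base_tokenizer import BaseTokenizer

parser = argparse.ArgumentParser()
parser = BaseTokenizer.add_arguments(parser)
opts = parser.parse_args(["--text-tokenizer.name", "clip"])
assert getattr(opts, "text_tokenizer.name") == "clip"
# build_tokenizer(opts) would now look "clip" up in TOKENIZER_REGISTRY
```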
4 | # 5 | 6 | import argparse 7 | 8 | from data.transforms.base_transforms import BaseTransformation 9 | from utils.registry import Registry 10 | 11 | TRANSFORMATIONS_REGISTRY = Registry( 12 | "transformation", 13 | base_class=BaseTransformation, 14 | lazy_load_dirs=["data/transforms"], 15 | internal_dirs=["internal", "internal/projects/*"], 16 | ) 17 | 18 | 19 | def arguments_augmentation(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 20 | # add arguments for the base transformation class 21 | parser = BaseTransformation.add_arguments(parser) 22 | 23 | # add augmentation specific arguments 24 | parser = TRANSFORMATIONS_REGISTRY.all_arguments(parser) 25 | return parser 26 | -------------------------------------------------------------------------------- /data/transforms/audio_aux/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/data/transforms/audio_aux/__init__.py -------------------------------------------------------------------------------- /data/transforms/base_transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | from typing import Dict 8 | 9 | 10 | class BaseTransformation(object): 11 | """ 12 | Base class for augmentation methods 13 | """ 14 | 15 | def __init__(self, opts, *args, **kwargs) -> None: 16 | self.opts = opts 17 | 18 | def __call__(self, data: Dict) -> Dict: 19 | raise NotImplementedError 20 | 21 | def __repr__(self) -> str: 22 | return "{}()".format(self.__class__.__name__) 23 | 24 | @classmethod 25 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 26 | return parser 27 | -------------------------------------------------------------------------------- /data/transforms/common.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Dict, List 6 | 7 | from data.transforms import TRANSFORMATIONS_REGISTRY, BaseTransformation 8 | 9 | 10 | @TRANSFORMATIONS_REGISTRY.register(name="compose", type="common") 11 | class Compose(BaseTransformation): 12 | """ 13 | This class applies a list of transforms in a sequential fashion. 14 | """ 15 | 16 | def __init__(self, opts, img_transforms: List, *args, **kwargs) -> None: 17 | super().__init__(opts=opts) 18 | self.img_transforms = img_transforms 19 | 20 | def __call__(self, data: Dict) -> Dict: 21 | for t in self.img_transforms: 22 | data = t(data) 23 | return data 24 | 25 | def __repr__(self) -> str: 26 | transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.img_transforms) 27 | repr_str = "{}({}\n\t\t)".format(self.__class__.__name__, transform_str) 28 | return repr_str 29 | -------------------------------------------------------------------------------- /data/transforms/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
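`Compose` simply threads the `data` dict through each transform in order. A self-contained sketch with toy transforms (illustrative only):

```
import argparse
from typing import Dict

from data.transforms.base_transforms import BaseTransformation
from data.transforms.common import Compose


class AddConstant(BaseTransformation):
    """Toy transform used only for this illustration."""

    def __init__(self, opts, key: str, value: int) -> None:
        super().__init__(opts=opts)
        self.key, self.value = key, value

    def __call__(self, data: Dict) -> Dict:
        data[self.key] = data.get(self.key, 0) + self.value
        return data


opts = argparse.Namespace()
pipeline = Compose(
    opts, img_transforms=[AddConstant(opts, "x", 1), AddConstant(opts, "x", 2)]
)
print(pipeline({}))  # {'x': 3}: transforms are applied sequentially
```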
4 | # 5 | 6 | from typing import Any 7 | 8 | import numpy as np 9 | 10 | 11 | def setup_size(size: Any, error_msg="Need a tuple of length 2"): 12 | if size is None: 13 | raise ValueError("Size can't be None") 14 | 15 | if isinstance(size, int): 16 | return size, size 17 | elif isinstance(size, (list, tuple)) and len(size) == 1: 18 | return size[0], size[0] 19 | 20 | if len(size) != 2: 21 | raise ValueError(error_msg) 22 | 23 | return size 24 | 25 | 26 | def intersect(box_a, box_b): 27 | """Computes the intersection areas between the boxes in `box_a` and a single box `box_b`""" 28 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 29 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 30 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 31 | return inter[:, 0] * inter[:, 1] 32 | 33 | 34 | def jaccard_numpy(box_a: np.ndarray, box_b: np.ndarray): 35 | """ 36 | Computes the intersection over union (IoU) between a set of boxes and a single box. 37 | Args: 38 | box_a (np.ndarray): Boxes of shape [Num_boxes_A, 4] 39 | box_b (np.ndarray): A single box of shape [4] 40 | 41 | Returns: 42 | intersection over union scores. Shape is [box_a.shape[0]] 43 | """ 44 | inter = intersect(box_a, box_b) 45 | area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) # [A] 46 | area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) # scalar 47 | union = area_a + area_b - inter 48 | return inter / union # [A] 49 | -------------------------------------------------------------------------------- /data/video_reader/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from data.video_reader.base_av_reader import BaseAVReader 9 | from utils import logger 10 | from utils.ddp_utils import is_master 11 | from utils.registry import Registry 12 | 13 | VIDEO_READER_REGISTRY = Registry( 14 | "video_reader", 15 | base_class=BaseAVReader, 16 | lazy_load_dirs=["data/video_reader"], 17 | internal_dirs=["internal", "internal/projects/*"], 18 | ) 19 | 20 | 21 | def arguments_video_reader(parser: argparse.ArgumentParser): 22 | parser = BaseAVReader.add_arguments(parser=parser) 23 | 24 | # add video reader specific arguments 25 | parser = VIDEO_READER_REGISTRY.all_arguments(parser) 26 | return parser 27 | 28 | 29 | def get_video_reader(opts, *args, **kwargs) -> BaseAVReader: 30 | """Helper function to build the video reader from command-line arguments. 31 | 32 | Args: 33 | opts: Command-line arguments 34 | 35 | 36 | Returns: 37 | Video reader module. 38 | """ 39 | 40 | video_reader_name = getattr(opts, "video_reader.name") 41 | 42 | # We registered the base class using a special `name` (i.e., `__base__`) 43 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 44 | # be used. Therefore, we raise an error for such cases 45 | if video_reader_name == "__base__": 46 | logger.error("__base__ can't be used as a video reader name. 
Please check.") 47 | 48 | video_reader = VIDEO_READER_REGISTRY[video_reader_name](opts, *args, **kwargs) 49 | 50 | is_master_node = is_master(opts) 51 | if is_master_node: 52 | logger.log("Video reader details: ") 53 | print("{}".format(video_reader)) 54 | return video_reader 55 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | _build/ 3 | source/autogen/ 4 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/docs/.nojekyll -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS = 7 | SPHINXBUILD = python -msphinx 8 | SPHINXPROJ = cvnets 9 | SOURCEDIR = source 10 | BUILDDIR = _build 11 | AUTOGENDIR = $(SOURCEDIR)/autogen 12 | 13 | 14 | # Put it first so that "make" without argument is like "make help". 15 | help: 16 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 17 | 18 | .PHONY: help Makefile 19 | 20 | # Catch-all target: route all unknown targets to Sphinx using the new 21 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 22 | %: Makefile 23 | echo '$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)' 24 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | 26 | rst: 27 | sphinx-apidoc -f -o $(AUTOGENDIR) ../data/ 28 | sphinx-apidoc -f -o $(AUTOGENDIR) ../cvnets/ 29 | sphinx-apidoc -f -o $(AUTOGENDIR) ../engine/ 30 | sphinx-apidoc -f -o $(AUTOGENDIR) ../loss_fn/ 31 | sphinx-apidoc -f -o $(AUTOGENDIR) ../loss_landscape/ 32 | sphinx-apidoc -f -o $(AUTOGENDIR) ../optim/ 33 | sphinx-apidoc -f -o $(AUTOGENDIR) ../metrics/ 34 | sphinx-apidoc -f -o $(AUTOGENDIR) ../options/ 35 | sphinx-apidoc -f -o $(AUTOGENDIR) ../utils/ 36 | rm $(AUTOGENDIR)/modules.rst 37 | 38 | github_pages: 39 | @make rst 40 | @make clean html 41 | @cp -a ./_build/html/ . 42 | -------------------------------------------------------------------------------- /docs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/docs/__init__.py -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source_static 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/en/models/classification/README-mobilevit.md: -------------------------------------------------------------------------------- 1 | # Training MobileViT on the ImageNet dataset 2 | 3 | Single node 8-GPU training of `MobileViT-S` can be done using the command below: 4 | 5 | ``` 6 | export CFG_FILE="config/classification/imagenet/mobilevit.yaml" 7 | cvnets-train --common.config-file $CFG_FILE --common.results-loc classification_results 8 | ``` 9 | 10 | ***Note***: Do not forget to change the training and validation dataset locations in the configuration files. 11 | 12 | ## Citation 13 | 14 | ``` 15 | @inproceedings{mehta2022mobilevit, 16 | title={MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer}, 17 | author={Sachin Mehta and Mohammad Rastegari}, 18 | booktitle={International Conference on Learning Representations}, 19 | year={2022}, 20 | url={https://openreview.net/forum?id=vh-0sUt8HlG} 21 | } 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/source/en/models/classification/README-resnet.md: -------------------------------------------------------------------------------- 1 | # Training ResNets on the ImageNet dataset 2 | 3 | Single node 8-GPU training of `ResNet-50` with `simple training recipe` can be done using the command below: 4 | 5 | ``` 6 | export CFG_FILE="config/classification/imagenet/resnet.yaml" 7 | cvnets-train --common.config-file $CFG_FILE --common.results-loc classification_results 8 | ``` 9 | 10 | For an advanced training recipe, see [this](../../../../../config/classification/imagenet/resnet_adv.yaml) configuration file. 11 | 12 | ***Note***: Do not forget to change the training and validation dataset locations in the configuration files. 13 | 14 |
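The dataset locations live under the `dataset` section of each configuration file. A minimal sketch of the relevant keys, following the layout used by the dataset configurations elsewhere in this repository (the paths are placeholders; point them at your local copies):

```
dataset:
  # folder containing the training images
  root_train: "/mnt/imagenet/training"
  # folder containing the validation images
  root_val: "/mnt/imagenet/validation"
```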
15 | <details> 16 | <summary> Single node 8-GPU training of ResNet-101 with simple training recipe </summary> 17 | 18 | 19 | ``` 20 | export CFG_FILE="config/classification/imagenet/resnet.yaml" 21 | cvnets-train --common.config-file $CFG_FILE --common.results-loc classification_results --common.override-kwargs model.classification.resnet.depth=101 22 | ``` 23 | </details>
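In the recipes above and below, `--common.override-kwargs` overrides individual configuration entries from the command line without editing the YAML file: each override is a dot-separated key path into the configuration, followed by `=` and the new value (here, `model.classification.resnet.depth`).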
24 | 25 | 26 |
26 | 27 | <details> 28 | <summary> Single node 8-GPU training of ResNet-34 with simple training recipe </summary> 29 | 30 | 31 | ``` 32 | export CFG_FILE="config/classification/imagenet/resnet.yaml" 33 | cvnets-train --common.config-file $CFG_FILE --common.results-loc classification_results --common.override-kwargs model.classification.resnet.depth=34 34 | ``` 35 | </details>
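Once training finishes, the resulting checkpoint can be evaluated with `cvnets-eval`. A minimal sketch, mirroring the evaluation command from the robustness-evaluation recipe in these docs (the weights path is a placeholder):

```
export CFG_FILE="config/classification/imagenet/resnet.yaml"
export MODEL_WEIGHTS="PATH_TO_MODEL_WEIGHTS_FILE"
CUDA_VISIBLE_DEVICES=0 cvnets-eval --common.config-file $CFG_FILE --common.results-loc classification_results --model.classification.pretrained $MODEL_WEIGHTS
```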
36 | 37 | ## Citation 38 | 39 | ``` 40 | @inproceedings{he2016deep, 41 | title={Deep residual learning for image recognition}, 42 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, 43 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 44 | pages={770--778}, 45 | year={2016} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /docs/source/en/models/classification/README-robustness-evaluations.md: -------------------------------------------------------------------------------- 1 | # Evaluating classification models on ImageNet distribution shift datasets 2 | 3 | We support evaluating on the ImageNet-A/R/Sketch datasets. For evaluation, download 4 | these datasets and set the `root_val` path to the root directory of the dataset. 5 | 6 | ## Downloading datasets 7 | Please follow the instructions for each dataset at the following links: 8 | - [ImageNet-A](https://github.com/hendrycks/natural-adv-examples) 9 | - [ImageNet-R](https://github.com/hendrycks/imagenet-r) 10 | - [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch) 11 | 12 | 13 | ## Evaluating a classification model 14 | 15 | Evaluation can be done using the command below. Please set the `root_val` in 16 | the configuration to the dataset path. 17 | 18 | ``` 19 | export CFG_FILE="PATH_TO_MODEL_CONFIGURATION_FILE" 20 | export MODEL_WEIGHTS="PATH_TO_MODEL_WEIGHTS_FILE" 21 | CUDA_VISIBLE_DEVICES=0 cvnets-eval --common.config-file $CFG_FILE --common.results-loc classification_results --model.classification.pretrained $MODEL_WEIGHTS 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/source/getting_started.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | installation 8 | en/general/README-config-files-intro 9 | en/general/README-directory-structure 10 | en/models/classification/README-classification-tutorial.md 11 | en/models/detection/README-detection-SSD-tutorial.md 12 | en/models/segmentation/README-segmentation-deeplabv3-tutorial.md 13 | -------------------------------------------------------------------------------- /docs/source/how_to.rst: -------------------------------------------------------------------------------- 1 | How To 2 | =============== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | en/general/README-pytorch-to-coreml 8 | en/general/README-new-dataset 9 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. ml-cvnets documentation master file, created by 2 | sphinx-quickstart on Mon Dec 6 10:41:36 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to CVNets documentation! 7 | ===================================== 8 | 9 | CVNets is an open-source library for training deep neural networks for visual recognition tasks, 10 | including classification, detection, and segmentation. 11 | 12 | CVNets supports image and video understanding tools, including data loading, data transformations, novel data sampling methods, 13 | and implementations of several state-of-the-art networks. 14 | 15 | Our source code is available on `Github <https://github.com/apple/ml-cvnets>`_ . 16 | 17 | 18 | ..
toctree:: 19 | :maxdepth: 2 20 | :caption: Table of Contents 21 | 22 | getting_started 23 | sample_recipes 24 | how_to 25 | data_samplers 26 | en/general/README-model-zoo 27 | 28 | 29 | Citation 30 | ======== 31 | 32 | If you find CVNets useful, please cite the following papers: 33 | 34 | .. code-block:: 35 | 36 | @inproceedings{mehta2022mobilevit, 37 | title={MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer}, 38 | author={Sachin Mehta and Mohammad Rastegari}, 39 | booktitle={International Conference on Learning Representations}, 40 | year={2022} 41 | } 42 | 43 | @inproceedings{mehta2022cvnets, 44 | author = {Mehta, Sachin and Abdolhosseini, Farzad and Rastegari, Mohammad}, 45 | title = {CVNets: High Performance Library for Computer Vision}, 46 | year = {2022}, 47 | booktitle = {Proceedings of the 30th ACM International Conference on Multimedia}, 48 | series = {MM '22} 49 | } 50 | 51 | Indices and tables 52 | ================== 53 | 54 | * :ref:`genindex` 55 | * :ref:`modindex` 56 | .. * :ref:`search` 57 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | =============== 3 | 4 | CVNets can be installed in a local Python environment using the commands below: 5 | 6 | .. code-block:: bash 7 | 8 | git clone git@github.com:apple/ml-cvnets.git 9 | cd ml-cvnets 10 | pip install -r requirements.txt 11 | pip install --editable . 12 | 13 | 14 | We recommend using Python 3.6+ and PyTorch (version >= v1.8.0) in a conda environment. For setting up a Python environment with conda, see `here `_. -------------------------------------------------------------------------------- /docs/source/models.md: -------------------------------------------------------------------------------- 1 | # Available Models 2 | 3 | 4 | ## ResNet 5 | 6 | ## MobileNetv1 7 | ## MobileNetv2 8 | ## MobileNetv3 9 | 10 | ## ViT 11 | ## MobileViT 12 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | data 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | data 8 | cvnets 9 | engine 10 | loss_fn 11 | loss_landscape 12 | optim 13 | metrics 14 | options 15 | utils 16 | -------------------------------------------------------------------------------- /docs/source/sample_recipes.rst: -------------------------------------------------------------------------------- 1 | Sample Recipes 2 | =============== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | en/models/classification/README-resnet.md 8 | en/models/classification/README-mobilenets.md 9 | en/models/classification/README-mobilevit.md 10 | en/models/classification/README-mobilevit-v2.md 11 | en/models/classification/README-vit.md 12 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | from engine.evaluation_engine import Evaluator 7 | from engine.training_engine import Trainer 8 | -------------------------------------------------------------------------------- /engine/detection_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/engine/detection_utils/__init__.py -------------------------------------------------------------------------------- /engine/segmentation_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/engine/segmentation_utils/__init__.py -------------------------------------------------------------------------------- /engine/segmentation_utils/cityscapes_iou.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import glob 7 | import os 8 | 9 | import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_semseg_eval 10 | 11 | from utils import logger 12 | 13 | 14 | def eval_cityscapes(pred_dir: str, gt_dir: str) -> None: 15 | """Utility to evaluate on cityscapes dataset""" 16 | cityscapes_semseg_eval.args.predictionPath = pred_dir 17 | cityscapes_semseg_eval.args.predictionWalk = None 18 | cityscapes_semseg_eval.args.JSONOutput = False 19 | cityscapes_semseg_eval.args.colorized = False 20 | 21 | gt_img_list = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png")) 22 | if len(gt_img_list) == 0: 23 | logger.error("Cannot find ground truth images at: {}".format(gt_dir)) 24 | 25 | pred_img_list = [] 26 | for gt in gt_img_list: 27 | pred_img_list.append( 28 | cityscapes_semseg_eval.getPrediction(cityscapes_semseg_eval.args, gt) 29 | ) 30 | 31 | results = cityscapes_semseg_eval.evaluateImgLists( 32 | pred_img_list, gt_img_list, cityscapes_semseg_eval.args 33 | ) 34 | 35 | logger.info("Evaluation results summary") 36 | eval_res_str = "\n\t IoU_cls: {:.2f} \n\t iIOU_cls: {:.2f} \n\t IoU_cat: {:.2f} \n\t iIOU_cat: {:.2f}".format( 37 | 100.0 * results["averageScoreClasses"], 38 | 100.0 * results["averageScoreInstClasses"], 39 | 100.0 * results["averageScoreCategories"], 40 | 100.0 * results["averageScoreInstCategories"], 41 | ) 42 | print(eval_res_str) 43 | -------------------------------------------------------------------------------- /examples/byteformer/model_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/examples/byteformer/model_arch.png -------------------------------------------------------------------------------- /examples/byteformer/speech_commands_mp3/conv_kernel_size=4,window_size=[128].yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | common: 3 | run_label: train 4 | log_freq: 500 5 | auto_resume: true 6 | mixed_precision: true 7 | tensorboard_logging: false 8 | accum_freq: 2 9 | dataset: 10 | root_train: /mnt/audio_datasets/google_speech_recognition_v2 11 | root_val: /mnt/audio_datasets/google_speech_recognition_v2 12 | name: speech_commands_v2 13 | category: audio_classification 14 | train_batch_size0: 48 15 | val_batch_size0: 48 16 | eval_batch_size0: 48 17 | workers: 10 18 | 
persistent_workers: false 19 | pin_memory: true 20 | collate_fn_name_train: byteformer_audio_collate_fn 21 | collate_fn_name_val: byteformer_audio_collate_fn 22 | collate_fn_name_test: byteformer_audio_collate_fn 23 | speech_commands_v2: 24 | mixup: true 25 | audio_augmentation: 26 | noise: 27 | enable: true 28 | levels: 29 | - -50 30 | refresh_freq: 100 31 | roll: 32 | enable: true 33 | window: 0.1 34 | torchaudio_save: 35 | enable: true 36 | format: mp3 37 | sampler: 38 | name: batch_sampler 39 | loss: 40 | category: classification 41 | classification: 42 | name: cross_entropy 43 | cross_entropy: 44 | label_smoothing: 0.1 45 | optim: 46 | name: adamw 47 | weight_decay: 0.05 48 | no_decay_bn_filter_bias: true 49 | adamw: 50 | beta1: 0.9 51 | beta2: 0.999 52 | scheduler: 53 | name: cosine 54 | is_iteration_based: false 55 | max_epochs: 300 56 | warmup_iterations: 500 57 | warmup_init_lr: 1.0e-06 58 | cosine: 59 | max_lr: 0.001 60 | min_lr: 2.0e-05 61 | model: 62 | audio_classification: 63 | name: byteformer 64 | classification: 65 | name: byteformer 66 | byteformer: 67 | mode: tiny 68 | max_num_tokens: 50000 69 | conv_kernel_size: 4 70 | window_sizes: 71 | - 128 72 | n_classes: 12 73 | activation: 74 | name: gelu 75 | layer: 76 | global_pool: mean 77 | conv_init: kaiming_uniform 78 | linear_init: trunc_normal 79 | linear_init_std_dev: 0.02 80 | ema: 81 | enable: true 82 | momentum: 0.0001 83 | stats: 84 | val: 85 | - loss 86 | - top1 87 | - top5 88 | train: 89 | - loss 90 | checkpoint_metric: top1 91 | checkpoint_metric_max: true 92 | -------------------------------------------------------------------------------- /examples/byteformer/speech_commands_mp3/conv_kernel_size=4,window_size=[32].yaml: -------------------------------------------------------------------------------- 1 | common: 2 | run_label: train 3 | log_freq: 500 4 | auto_resume: true 5 | mixed_precision: true 6 | tensorboard_logging: false 7 | accum_freq: 2 8 | dataset: 9 | root_train: /mnt/audio_datasets/google_speech_recognition_v2 10 | root_val: /mnt/audio_datasets/google_speech_recognition_v2 11 | name: speech_commands_v2 12 | category: audio_classification 13 | train_batch_size0: 48 14 | val_batch_size0: 48 15 | eval_batch_size0: 48 16 | workers: 10 17 | persistent_workers: false 18 | pin_memory: true 19 | collate_fn_name_train: byteformer_audio_collate_fn 20 | collate_fn_name_val: byteformer_audio_collate_fn 21 | collate_fn_name_test: byteformer_audio_collate_fn 22 | speech_commands_v2: 23 | mixup: true 24 | audio_augmentation: 25 | noise: 26 | enable: true 27 | levels: 28 | - -50 29 | refresh_freq: 100 30 | roll: 31 | enable: true 32 | window: 0.1 33 | torchaudio_save: 34 | enable: true 35 | format: mp3 36 | sampler: 37 | name: batch_sampler 38 | loss: 39 | category: classification 40 | classification: 41 | name: cross_entropy 42 | cross_entropy: 43 | label_smoothing: 0.1 44 | optim: 45 | name: adamw 46 | weight_decay: 0.05 47 | no_decay_bn_filter_bias: true 48 | adamw: 49 | beta1: 0.9 50 | beta2: 0.999 51 | scheduler: 52 | name: cosine 53 | is_iteration_based: false 54 | max_epochs: 300 55 | warmup_iterations: 500 56 | warmup_init_lr: 1.0e-06 57 | cosine: 58 | max_lr: 0.001 59 | min_lr: 2.0e-05 60 | model: 61 | audio_classification: 62 | name: byteformer 63 | classification: 64 | name: byteformer 65 | byteformer: 66 | mode: tiny 67 | max_num_tokens: 50000 68 | conv_kernel_size: 4 69 | window_sizes: 70 | - 32 71 | n_classes: 12 72 | activation: 73 | name: gelu 74 | layer: 75 | global_pool: mean 76 | conv_init: 
kaiming_uniform 77 | linear_init: trunc_normal 78 | linear_init_std_dev: 0.02 79 | ema: 80 | enable: true 81 | momentum: 0.0001 82 | stats: 83 | val: 84 | - loss 85 | - top1 86 | - top5 87 | train: 88 | - loss 89 | checkpoint_metric: top1 90 | checkpoint_metric_max: true 91 | -------------------------------------------------------------------------------- /examples/byteformer/speech_commands_mp3/conv_kernel_size=8,window_size=[128].yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | common: 3 | run_label: train 4 | log_freq: 500 5 | auto_resume: true 6 | mixed_precision: true 7 | tensorboard_logging: false 8 | accum_freq: 2 9 | dataset: 10 | root_train: /mnt/audio_datasets/google_speech_recognition_v2 11 | root_val: /mnt/audio_datasets/google_speech_recognition_v2 12 | name: speech_commands_v2 13 | category: audio_classification 14 | train_batch_size0: 48 15 | val_batch_size0: 48 16 | eval_batch_size0: 48 17 | workers: 10 18 | persistent_workers: false 19 | pin_memory: true 20 | collate_fn_name_train: byteformer_audio_collate_fn 21 | collate_fn_name_val: byteformer_audio_collate_fn 22 | collate_fn_name_test: byteformer_audio_collate_fn 23 | speech_commands_v2: 24 | mixup: true 25 | audio_augmentation: 26 | noise: 27 | enable: true 28 | levels: 29 | - -50 30 | refresh_freq: 100 31 | roll: 32 | enable: true 33 | window: 0.1 34 | torchaudio_save: 35 | enable: true 36 | format: mp3 37 | sampler: 38 | name: batch_sampler 39 | loss: 40 | category: classification 41 | classification: 42 | name: cross_entropy 43 | cross_entropy: 44 | label_smoothing: 0.1 45 | optim: 46 | name: adamw 47 | weight_decay: 0.05 48 | no_decay_bn_filter_bias: true 49 | adamw: 50 | beta1: 0.9 51 | beta2: 0.999 52 | scheduler: 53 | name: cosine 54 | is_iteration_based: false 55 | max_epochs: 300 56 | warmup_iterations: 500 57 | warmup_init_lr: 1.0e-06 58 | cosine: 59 | max_lr: 0.001 60 | min_lr: 2.0e-05 61 | model: 62 | audio_classification: 63 | name: byteformer 64 | classification: 65 | name: byteformer 66 | byteformer: 67 | mode: tiny 68 | max_num_tokens: 50000 69 | conv_kernel_size: 8 70 | window_sizes: 71 | - 128 72 | n_classes: 12 73 | activation: 74 | name: gelu 75 | layer: 76 | global_pool: mean 77 | conv_init: kaiming_uniform 78 | linear_init: trunc_normal 79 | linear_init_std_dev: 0.02 80 | ema: 81 | enable: true 82 | momentum: 0.0001 83 | stats: 84 | val: 85 | - loss 86 | - top1 87 | - top5 88 | train: 89 | - loss 90 | checkpoint_metric: top1 91 | checkpoint_metric_max: true 92 | -------------------------------------------------------------------------------- /examples/byteformer/speech_commands_mp3/conv_kernel_size=8,window_size=[32].yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | common: 3 | run_label: train 4 | log_freq: 500 5 | auto_resume: true 6 | mixed_precision: true 7 | tensorboard_logging: false 8 | accum_freq: 2 9 | dataset: 10 | root_train: /mnt/audio_datasets/google_speech_recognition_v2 11 | root_val: /mnt/audio_datasets/google_speech_recognition_v2 12 | name: speech_commands_v2 13 | category: audio_classification 14 | train_batch_size0: 48 15 | val_batch_size0: 48 16 | eval_batch_size0: 48 17 | workers: 10 18 | persistent_workers: false 19 | pin_memory: true 20 | collate_fn_name_train: byteformer_audio_collate_fn 21 | collate_fn_name_val: byteformer_audio_collate_fn 22 | collate_fn_name_test: byteformer_audio_collate_fn 23 | speech_commands_v2: 24 | mixup: true 25 | 
audio_augmentation: 26 | noise: 27 | enable: true 28 | levels: 29 | - -50 30 | refresh_freq: 100 31 | roll: 32 | enable: true 33 | window: 0.1 34 | torchaudio_save: 35 | enable: true 36 | format: mp3 37 | sampler: 38 | name: batch_sampler 39 | loss: 40 | category: classification 41 | classification: 42 | name: cross_entropy 43 | cross_entropy: 44 | label_smoothing: 0.1 45 | optim: 46 | name: adamw 47 | weight_decay: 0.05 48 | no_decay_bn_filter_bias: true 49 | adamw: 50 | beta1: 0.9 51 | beta2: 0.999 52 | scheduler: 53 | name: cosine 54 | is_iteration_based: false 55 | max_epochs: 300 56 | warmup_iterations: 500 57 | warmup_init_lr: 1.0e-06 58 | cosine: 59 | max_lr: 0.001 60 | min_lr: 2.0e-05 61 | model: 62 | audio_classification: 63 | name: byteformer 64 | classification: 65 | name: byteformer 66 | byteformer: 67 | mode: tiny 68 | max_num_tokens: 50000 69 | conv_kernel_size: 8 70 | window_sizes: 71 | - 32 72 | n_classes: 12 73 | activation: 74 | name: gelu 75 | layer: 76 | global_pool: mean 77 | conv_init: kaiming_uniform 78 | linear_init: trunc_normal 79 | linear_init_std_dev: 0.02 80 | ema: 81 | enable: true 82 | momentum: 0.0001 83 | stats: 84 | val: 85 | - loss 86 | - top1 87 | - top5 88 | train: 89 | - loss 90 | checkpoint_metric: top1 91 | checkpoint_metric_max: true 92 | -------------------------------------------------------------------------------- /examples/range_augment/README.md: -------------------------------------------------------------------------------- 1 | # RangeAugment: Efficient Online Augmentation with Range Learning 2 | 3 | [RangeAugment](https://arxiv.org/abs/2212.10553) is an automatic augmentation method that allows us to learn the `model- and task-specific` magnitude range of each augmentation operation. 4 | 5 | We provide training and evaluation code along with pretrained models and configuration files for the following tasks: 6 | 7 | 1. [Image Classification on the ImageNet dataset](./README-classification.md) 8 | 2. [Semantic segmentation on the ADE20k and the PASCAL VOC datasets](./README-segmentation.md) 9 | 3. [Object detection on the MS-COCO dataset](./README-object-detection.md) 10 | 4. [Contrastive Learning using Image-Text pairs](./README-clip.md) 11 | 5. [Distillation on the ImageNet dataset](./README-distillation.md) 12 | 13 | ***Note***: In the [codebase](../../cvnets/neural_augmentor), we refer to RangeAugment as Neural Augmentor (or NA). 14 | 15 | 16 | ## Citation 17 | 18 | If you find our work useful, please cite: 19 | 20 | ``` 21 | @article{mehta2022rangeaugment, 22 | title={RangeAugment: Efficient Online Augmentation with Range Learning}, 23 | author = {Mehta, Sachin and Naderiparizi, Saeid and Faghri, Fartash and Horton, Maxwell and Chen, Lailin and Farhadi, Ali and Tuzel, Oncel and Rastegari, Mohammad}, 24 | journal={arXiv preprint arXiv:2212.10553}, 25 | year={2022}, 26 | url={https://arxiv.org/abs/2212.10553}, 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /loss_fn/base_criteria.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import abc 7 | import argparse 8 | from typing import Any 9 | 10 | from torch import nn 11 | 12 | from utils import logger 13 | 14 | 15 | class BaseCriteria(nn.Module, abc.ABC): 16 | """Base class for defining loss functions. Sub-classes must implement the forward function.
17 | 18 | Args: 19 | opts: command line arguments 20 | """ 21 | 22 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 23 | super(BaseCriteria, self).__init__() 24 | self.opts = opts 25 | # small value for numerical stability purposes that sub-classes may want to use. 26 | self.eps = 1e-7 27 | 28 | @classmethod 29 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 30 | """Add criterion-specific arguments to the parser.""" 31 | if cls != BaseCriteria: 32 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 33 | return parser 34 | group = parser.add_argument_group(cls.__name__) 35 | 36 | group.add_argument( 37 | "--loss.category", 38 | type=str, 39 | default=None, 40 | help="Loss function category (e.g., classification). Defaults to None.", 41 | ) 42 | return parser 43 | 44 | @abc.abstractmethod 45 | def forward( 46 | self, input_sample: Any, prediction: Any, target: Any, *args, **kwargs 47 | ) -> Any: 48 | """Compute the loss. 49 | 50 | Args: 51 | input_sample: Input to the model. 52 | prediction: Model's output 53 | target: Ground truth labels 54 | """ 55 | raise NotImplementedError 56 | 57 | def extra_repr(self) -> str: 58 | return "" 59 | 60 | def __repr__(self) -> str: 61 | return "{}({}\n)".format(self.__class__.__name__, self.extra_repr()) 62 | -------------------------------------------------------------------------------- /loss_fn/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /loss_fn/detection/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /loss_fn/detection/base_detection_criteria.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from loss_fn import LOSS_REGISTRY, BaseCriteria 7 | 8 | 9 | @LOSS_REGISTRY.register(name="__base__", type="detection") 10 | class BaseDetectionCriteria(BaseCriteria): 11 | """Base class for defining detection loss functions. Sub-classes must implement forward function. 12 | 13 | Args: 14 | opts: command line arguments 15 | """ 16 | 17 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 18 | super().__init__(opts, *args, **kwargs) 19 | 20 | @classmethod 21 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | if cls != BaseDetectionCriteria: 23 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 24 | return parser 25 | 26 | group = parser.add_argument_group(cls.__name__) 27 | group.add_argument( 28 | "--loss.detection.name", 29 | type=str, 30 | default=None, 31 | help=f"Name of the loss function in {cls.__name__}. 
Defaults to None.", 32 | ) 33 | return parser 34 | -------------------------------------------------------------------------------- /loss_fn/distillation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /loss_fn/multi_modal_img_text/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /loss_fn/multi_modal_img_text/base_multi_modal_img_text_criteria.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from loss_fn import LOSS_REGISTRY, BaseCriteria 7 | 8 | 9 | @LOSS_REGISTRY.register(name="__base__", type="multi_modal_image_text") 10 | class BaseMultiModalImageTextCriteria(BaseCriteria): 11 | """Base class for defining multi-modal image-text loss functions. Sub-classes must implement forward function. 12 | 13 | Args: 14 | opts: command line arguments 15 | """ 16 | 17 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 18 | super().__init__(opts, *args, **kwargs) 19 | 20 | @classmethod 21 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | if cls != BaseMultiModalImageTextCriteria: 23 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 24 | return parser 25 | 26 | group = parser.add_argument_group(cls.__name__) 27 | group.add_argument( 28 | "--loss.multi-modal-image-text.name", 29 | type=str, 30 | default=None, 31 | help="Name of the loss function. Defaults to None.", 32 | ) 33 | return parser 34 | -------------------------------------------------------------------------------- /loss_fn/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /loss_fn/segmentation/base_segmentation_criteria.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from loss_fn import LOSS_REGISTRY, BaseCriteria 7 | 8 | 9 | @LOSS_REGISTRY.register(name="__base__", type="segmentation") 10 | class BaseSegmentationCriteria(BaseCriteria): 11 | """Base class for defining segmentation loss functions. Sub-classes must implement forward function. 12 | 13 | Args: 14 | opts: command line arguments 15 | """ 16 | 17 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 18 | super().__init__(opts, *args, **kwargs) 19 | 20 | @classmethod 21 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | if cls != BaseSegmentationCriteria: 23 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 
24 | return parser 25 | 26 | group = parser.add_argument_group(cls.__name__) 27 | group.add_argument( 28 | "--loss.segmentation.name", 29 | type=str, 30 | default=None, 31 | help="Name of the loss function. Defaults to None.", 32 | ) 33 | return parser 34 | -------------------------------------------------------------------------------- /loss_fn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/loss_fn/utils/__init__.py -------------------------------------------------------------------------------- /loss_fn/utils/build_helper.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from torch import nn 7 | 8 | from common import is_test_env 9 | from cvnets.models import get_model 10 | from options.utils import extract_opts_with_prefix_replacement 11 | from utils import logger 12 | 13 | 14 | def build_cls_teacher_from_opts(opts: argparse.Namespace) -> nn.Module: 15 | """Helper function to build a classification teacher model from command-line arguments 16 | 17 | Args: 18 | opts: command-line arguments 19 | 20 | Returns: 21 | A teacher model 22 | """ 23 | pretrained_model = getattr(opts, "teacher.model.classification.pretrained") 24 | 25 | pytest_env = is_test_env() 26 | if not pytest_env and pretrained_model is None: 27 | logger.error( 28 | "For distillation, please specify teacher weights using teacher.model.classification.pretrained" 29 | ) 30 | teacher_opts = extract_opts_with_prefix_replacement( 31 | opts, "teacher.model.", "model." 32 | ) 33 | 34 | # build teacher model 35 | return get_model(teacher_opts, category="classification") 36 | -------------------------------------------------------------------------------- /loss_fn/utils/class_weighting.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | 10 | def compute_class_weights( 11 | target: Tensor, n_classes: int, norm_val: float = 1.1 12 | ) -> Tensor: 13 | """Implementation of a class-weighting scheme, as defined in Section 5.2 14 | of `ENet `_ paper. 15 | 16 | Args: 17 | target: Tensor of shape [Batch_size, *] containing values in the range `[0, C)`. 18 | n_classes: Integer specifying the number of classes :math:`C` 19 | norm_val: Normalization value. Defaults to 1.1. This value is decided based on the 20 | `ESPNetv2 paper `_. 21 | Link: https://github.com/sacmehta/ESPNetv2/blob/b78e323039908f31347d8ca17f49d5502ef1a594/segmentation/loadData.py#L16 22 | 23 | Returns: 24 | A :math:`C`-dimensional tensor containing class weights 25 | """ 26 | 27 | class_hist = torch.histc(target.float(), bins=n_classes, min=0, max=n_classes - 1) 28 | 29 | mask_indices = class_hist == 0 30 | 31 | # normalize between 0 and 1 by dividing by the sum 32 | norm_hist = torch.div(class_hist, class_hist.sum()) 33 | 34 | norm_hist = torch.add(norm_hist, norm_val) 35 | 36 | # compute class weights.
37 | # samples with more frequency will have less weight and vice-versa 38 | class_wts = torch.div(torch.ones_like(class_hist), torch.log(norm_hist)) 39 | 40 | # mask the classes which do not have samples in the current batch 41 | class_wts[mask_indices] = 0.0 42 | 43 | return class_wts.to(device=target.device) 44 | -------------------------------------------------------------------------------- /loss_landscape/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/loss_landscape/__init__.py -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from utils.registry import Registry 9 | 10 | METRICS_REGISTRY = Registry( 11 | "metrics", 12 | lazy_load_dirs=["metrics"], 13 | internal_dirs=["internal", "internal/projects/*"], 14 | ) 15 | 16 | 17 | def arguments_stats(parser: argparse.ArgumentParser): 18 | group = parser.add_argument_group(title="Statistics", description="Statistics") 19 | group.add_argument( 20 | "--stats.val", type=str, default=["loss"], nargs="+", help="Name of statistics" 21 | ) 22 | group.add_argument( 23 | "--stats.train", 24 | type=str, 25 | default=["loss"], 26 | nargs="+", 27 | help="Name of statistics", 28 | ) 29 | group.add_argument( 30 | "--stats.checkpoint-metric", 31 | type=str, 32 | default="loss", 33 | help="Metric to use for saving checkpoints", 34 | ) 35 | group.add_argument( 36 | "--stats.checkpoint-metric-max", 37 | action="store_true", 38 | default=False, 39 | help="Maximize checkpoint metric", 40 | ) 41 | group.add_argument( 42 | "--stats.coco-map.iou-types", 43 | type=str, 44 | default=["bbox"], 45 | nargs="+", 46 | choices=("bbox", "segm"), 47 | help="Types of IOU to compute for MSCoco.", 48 | ) 49 | 50 | return parser 51 | -------------------------------------------------------------------------------- /metrics/average_precision.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import traceback 7 | from numbers import Number 8 | from typing import Dict, Union 9 | 10 | import numpy as np 11 | from sklearn.metrics import average_precision_score 12 | from torch import Tensor 13 | from torch.nn import functional as F 14 | 15 | from metrics import METRICS_REGISTRY 16 | from metrics.metric_base import EpochMetric 17 | from utils import logger 18 | 19 | 20 | @METRICS_REGISTRY.register("average_precision") 21 | class AveragePrecisionMetric(EpochMetric): 22 | def compute_with_aggregates( 23 | self, y_pred: Tensor, y_true: Tensor 24 | ) -> Union[Number, Dict[str, Number]]: 25 | y_pred, y_true = self.get_aggregates() 26 | 27 | y_pred = F.softmax(y_pred, dim=-1).numpy().astype(np.float32) 28 | y_true = y_true.numpy().astype(np.float32) 29 | 30 | # Clip predictions to reduce chance of getting INF 31 | y_pred = y_pred.clip(0, 1) 32 | 33 | if y_pred.ndim == 1 or y_pred.ndim == 2 and y_pred.shape[1] == 1: 34 | pass  # single-score predictions are already in the expected format 
35 | elif y_pred.ndim == 2 and y_pred.shape[1] == 2: 36 | y_pred = y_pred[:, 1] 37 | else: 38 | logger.warning( 39 | "Expected only two classes, got prediction Tensor of shape {}".format( 40 | y_pred.shape 41 | ) 42 | ) 43 | 44 | try: 45 | ap = 100 * average_precision_score(y_true, y_pred, average=None) 46 | except ValueError as e: 47 | logger.warning("Could not compute Average Precision: {}".format(str(e))) 48 | traceback.print_exc() 49 | ap = 0 # we don't want the job to fail over a metric computation issue 50 | 51 | return ap 52 | -------------------------------------------------------------------------------- /metrics/metric_base_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Any, Dict, Union 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from metrics.metric_base import AverageMetric 12 | 13 | 14 | class DummyMetric(AverageMetric): 15 | def gather_metrics( 16 | self, 17 | prediction: Union[Tensor, Dict], 18 | target: Union[Tensor, Dict], 19 | extras: Dict[str, Any], 20 | ) -> Union[Tensor, Dict[str, Tensor]]: 21 | return prediction 22 | 23 | 24 | def test_average_metric_distributed_batchsize(mocker): 25 | mocker.patch("torch.distributed.is_initialized", return_value=True) 26 | mocker.patch("torch.distributed.get_world_size", return_value=2) 27 | mocker.patch("torch.distributed.all_reduce", lambda x, *_, **__: x.add_(1)) 28 | 29 | metric = DummyMetric(None, is_distributed=True) 30 | metric.update(torch.tensor([2.0]), None, batch_size=torch.tensor([2])) 31 | 32 | # Value is 2 and batch size is 2, but we're simulating the second device 33 | # having value 1 and batch size 1 by making sure all_reduce adds 1 to both 34 | # the value and the batch size. It's as if we have [2, 2] in GPU1 and [1] 35 | # in GPU 2. Therefore the expected average is 5/3. 36 | 37 | expected_value = (2 * 2 + 1 * 1) / 3 38 | assert metric.compute() == expected_value 39 | -------------------------------------------------------------------------------- /metrics/probability_histograms.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | from numbers import Number 8 | from typing import Dict, Union 9 | 10 | import numpy as np 11 | from torch import Tensor 12 | from torch.nn import functional as F 13 | 14 | from metrics import METRICS_REGISTRY 15 | from metrics.metric_base import EpochMetric 16 | from utils import logger 17 | 18 | 19 | @METRICS_REGISTRY.register("prob_hist") 20 | class ProbabilityHistogramMetric(EpochMetric): 21 | def __init__( 22 | self, 23 | opts: argparse.Namespace = None, 24 | is_distributed: bool = False, 25 | pred: str = None, 26 | target: str = None, 27 | ): 28 | super().__init__(opts, is_distributed, pred, target) 29 | self.num_bins = getattr(self.opts, "stats.metrics.prob_hist.num_bins") 30 | 31 | @classmethod 32 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 33 | """Add metric specific arguments""" 34 | if cls == ProbabilityHistogramMetric: 35 | parser.add_argument( 36 | "--stats.metrics.prob-hist.num-bins", type=int, default=10 37 | ) 38 | return parser 39 | 40 | def compute_with_aggregates( 41 | self, y_pred: Tensor, y_true: Tensor 42 | ) -> Union[Number, Dict[str, Number]]: 43 | y_pred = F.softmax(y_pred, dim=-1).numpy() 44 | y_true = y_true.numpy() 45 | 46 | max_confs = y_pred.max(axis=-1) 47 | max_hist = np.histogram(max_confs, bins=self.num_bins, range=[0, 1])[0] 48 | max_hist = max_hist / max_hist.sum() 49 | 50 | target_confs = np.take_along_axis(y_pred, y_true.reshape(-1, 1), 1) 51 | target_hist = np.histogram(target_confs, bins=self.num_bins, range=[0, 1])[0] 52 | target_hist = target_hist / target_hist.sum() 53 | 54 | return { 55 | "max": max_hist.tolist(), 56 | "target": target_hist.tolist(), 57 | } 58 | -------------------------------------------------------------------------------- /optim/sgd.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | from typing import Dict, Iterable, Union 8 | 9 | from torch import Tensor 10 | from torch.optim import SGD 11 | 12 | from optim import OPTIM_REGISTRY 13 | from optim.base_optim import BaseOptim 14 | 15 | 16 | @OPTIM_REGISTRY.register(name="sgd") 17 | class SGDOptimizer(BaseOptim, SGD): 18 | """ 19 | `SGD `_ optimizer 20 | 21 | Args: 22 | opts: Command-line arguments 23 | model_params: Model parameters 24 | """ 25 | 26 | def __init__( 27 | self, 28 | opts: argparse.Namespace, 29 | model_params: Iterable[Union[Tensor, Dict]], 30 | *args, 31 | **kwargs 32 | ) -> None: 33 | BaseOptim.__init__(self, opts=opts) 34 | nesterov = getattr(opts, "optim.sgd.nesterov") 35 | momentum = getattr(opts, "optim.sgd.momentum") 36 | 37 | SGD.__init__( 38 | self, 39 | params=model_params, 40 | lr=self.lr, 41 | momentum=momentum, 42 | weight_decay=self.weight_decay, 43 | nesterov=nesterov, 44 | ) 45 | 46 | @classmethod 47 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 48 | if cls != SGDOptimizer: 49 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 50 | return parser 51 | group = parser.add_argument_group(cls.__name__) 52 | group.add_argument( 53 | "--optim.sgd.momentum", 54 | default=0.9, 55 | type=float, 56 | help="The value of momentum in SGD. Defaults to 0.9", 57 | ) 58 | group.add_argument( 59 | "--optim.sgd.nesterov", 60 | action="store_true", 61 | default=False, 62 | help="Use nesterov momentum in SGD. 
Defaults to False.", 63 | ) 64 | return parser 65 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/options/__init__.py -------------------------------------------------------------------------------- /options/errors.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from common import is_test_env 7 | 8 | 9 | class UnrecognizedYamlConfigEntry(Warning): 10 | # TODO: consider converting UnrecognizedYamlConfigEntry Warning to an Exception. 11 | def __init__(self, key: str) -> None: 12 | message = ( 13 | f"Yaml config key '{key}' was not recognized by argparser. If you think that you have already added the " 14 | f"argument in the options/opts.py file, then check for typos. If not, then please add it to options/opts.py." 15 | ) 16 | super().__init__(message) 17 | 18 | if is_test_env(): 19 | # Currently, we only raise an exception in the test environment. 20 | raise ValueError(message) 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | 4 | [tool.black] 5 | extend-exclude = '.history' 6 | 7 | [tool.pytest.ini_options] 8 | junit_family = 'xunit2' 9 | addopts = '--junit-xml=./build/test-results/junit_reports/junit.xml' 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | psutil 2 | scikit-learn 3 | scikit-image 4 | 5 | # requirements for PyTorch, Torchvision, TorchText 6 | torch 7 | torchvision 8 | torchtext 9 | torchaudio 10 | torchdata 11 | 12 | # dependency for coremltools 13 | coremltools 14 | 15 | # dependency for MSCOCO dataset 16 | pycocotools 17 | 18 | # dependency for cityscape evaluation 19 | cityscapesscripts 20 | 21 | # added as a dependency to reproduce 3rd party models 22 | pytorchvideo 23 | 24 | # PyAV for video decoding 25 | av 26 | 27 | # FVCore for FLOP calculation 28 | fvcore 29 | 30 | # black for reformatting 31 | black==22.10.0 32 | 33 | isort==5.12.0 34 | 35 | # testing 36 | pytest 37 | pytest-mock 38 | 39 | # ftfy for multi-modal learning 40 | ftfy 41 | 42 | # for hdf5 reading 43 | h5py 44 | 45 | # for reading byte data 46 | pybase64 47 | -------------------------------------------------------------------------------- /requirements_docs.txt: -------------------------------------------------------------------------------- 1 | # docs 2 | sphinx 3 | sphinx-rtd-theme 4 | sphinx-argparse 5 | myst-parser 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/__init__.py -------------------------------------------------------------------------------- /tests/configs.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | import argparse 7 | 8 | from options.opts import get_training_arguments 9 | from options.utils import load_config_file 10 | 11 | 12 | def get_config( 13 | config_file: str = None, disable_ddp_distributed: bool = True 14 | ) -> argparse.Namespace: 15 | """Produces a resolved config (i.e. opts) object to be used in tests. 16 | 17 | Args: 18 | config_file: If provided, the contents of the @config_file path will override 19 | the default configs. 20 | disable_ddp_distributed: The ``ddp.use_distributed`` config entry is not defined in 21 | the parser, but rather set by the entrypoints on the fly based on the 22 | availability of multiple gpus. In the tests, we usually don't want to use 23 | ``ddp.use_distributed``, even if multiple gpus are available. 24 | """ 25 | parser = get_training_arguments(parse_args=False) 26 | opts = parser.parse_args([]) 27 | setattr(opts, "common.config_file", config_file) 28 | opts = load_config_file(opts) 29 | 30 | if disable_ddp_distributed: 31 | setattr(opts, "ddp.use_distributed", False) 32 | 33 | return opts 34 | 35 | 36 | # If slow, this can be turned into a "session"-scoped fixture 37 | # @pytest.fixture(scope='session') 38 | def default_training_opts() -> argparse.Namespace: 39 | opts = get_training_arguments(args=[]) 40 | return opts 41 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/datasets/audio_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/data/datasets/audio_classification/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | # 5 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/image_classification_dataset.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_train: "tests/data/datasets/classification/dummy_images/training" 4 | root_val: "tests/data/datasets/classification/dummy_images/validation" 5 | collate_fn_name_train: "image_classification_data_collate_fn" 6 | collate_fn_name_val: "image_classification_data_collate_fn" 7 | collate_fn_name_test: "image_classification_data_collate_fn" 8 | name: "dummy" 9 | category: "classification" 10 | train_batch_size0: 2 11 | val_batch_size0: 4 12 | eval_batch_size0: 4 13 | workers: 8 14 | persistent_workers: true 15 | pin_memory: true 16 | image_augmentation: 17 | # training related parameters 18 | random_resized_crop: 19 | enable: true 20 | interpolation: "bilinear" 21 | random_horizontal_flip: 22 | enable: true 23 | auto_augment: 24 | enable: true 25 | cutmix: 26 | alpha: 1.0 27 | enable: true 28 | p: 1.0 29 | mixup: 30 | alpha: 0.2 31 | enable: true 32 | p: 1.0 33 | # validation related parameters 34 | resize: 35 | enable: true 36 | size: 232 37 | interpolation: "bilinear" 38 | center_crop: 39 | enable: true 40 | size: 224 41 | 42 | sampler: 43 | name: "batch_sampler" 44 | bs: 45 | crop_size_width: 256 46 | crop_size_height: 256 47 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_train: "/mnt/imagenet/training" 4 | root_val: "/mnt/imagenet/validation" 5 | name: "imagenet" 6 | category: "classification" 7 | train_batch_size0: 2 8 | val_batch_size0: 2 9 | eval_batch_size0: 2 10 | workers: 8 11 | persistent_workers: true 12 | pin_memory: true 13 | 14 | image_augmentation: 15 | # training related parameters 16 | random_resized_crop: 17 | enable: true 18 | interpolation: "bilinear" 19 | random_horizontal_flip: 20 | enable: true 21 | auto_augment: 22 | enable: true 23 | cutmix: 24 | alpha: 1.0 25 | enable: true 26 | p: 1.0 27 | mixup: 28 | alpha: 0.2 29 | enable: true 30 | p: 1.0 31 | # validation related parameters 32 | resize: 33 | enable: true 34 | size: 232 35 | interpolation: "bilinear" 36 | center_crop: 37 | enable: true 38 | size: 224 39 | 40 | sampler: 41 | name: "variable_batch_sampler" 42 | vbs: 43 | crop_size_width: 224 44 | crop_size_height: 224 45 | max_n_scales: 5 46 | min_crop_size_width: 128 47 | max_crop_size_width: 320 48 | min_crop_size_height: 128 49 | max_crop_size_height: 320 50 | check_scale: 32 51 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_val: "/mnt/vision_datasets/imagenet-a-1.0.0/data/raw/" 4 | name: "imagenet_a" 5 | category: "classification" 6 | train_batch_size0: 2 7 | val_batch_size0: 2 8 | eval_batch_size0: 2 9 | workers: 8 10 | persistent_workers: true 11 | pin_memory: true 12 | 13 | model: 14 | classification: 15 | n_classes: 1000 16 | 17 | image_augmentation: 18 | # training related parameters 19 | random_resized_crop: 20 | enable: true 21 | interpolation: "bilinear" 22 | random_horizontal_flip: 23 | enable: true 24 | 
auto_augment: 25 | enable: true 26 | cutmix: 27 | alpha: 1.0 28 | enable: true 29 | p: 1.0 30 | mixup: 31 | alpha: 0.2 32 | enable: true 33 | p: 1.0 34 | # validation related parameters 35 | resize: 36 | enable: true 37 | size: 232 38 | interpolation: "bilinear" 39 | center_crop: 40 | enable: true 41 | size: 224 42 | 43 | sampler: 44 | name: "variable_batch_sampler" 45 | vbs: 46 | crop_size_width: 224 47 | crop_size_height: 224 48 | max_n_scales: 5 49 | min_crop_size_width: 128 50 | max_crop_size_width: 320 51 | min_crop_size_height: 128 52 | max_crop_size_height: 320 53 | check_scale: 32 54 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_val: "/mnt/vision_datasets/imagenet-r-1.0.0/data/raw/" 4 | name: "imagenet_r" 5 | category: "classification" 6 | train_batch_size0: 2 7 | val_batch_size0: 2 8 | eval_batch_size0: 2 9 | workers: 8 10 | persistent_workers: true 11 | pin_memory: true 12 | 13 | model: 14 | classification: 15 | n_classes: 1000 16 | 17 | image_augmentation: 18 | # training related parameters 19 | random_resized_crop: 20 | enable: true 21 | interpolation: "bilinear" 22 | random_horizontal_flip: 23 | enable: true 24 | auto_augment: 25 | enable: true 26 | cutmix: 27 | alpha: 1.0 28 | enable: true 29 | p: 1.0 30 | mixup: 31 | alpha: 0.2 32 | enable: true 33 | p: 1.0 34 | # validation related parameters 35 | resize: 36 | enable: true 37 | size: 232 38 | interpolation: "bilinear" 39 | center_crop: 40 | enable: true 41 | size: 224 42 | 43 | sampler: 44 | name: "variable_batch_sampler" 45 | vbs: 46 | crop_size_width: 224 47 | crop_size_height: 224 48 | max_n_scales: 5 49 | min_crop_size_width: 128 50 | max_crop_size_width: 320 51 | min_crop_size_height: 128 52 | max_crop_size_height: 320 53 | check_scale: 32 54 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_val: "/mnt/vision_datasets/imagenet-sketch-1.0.0/data/raw/" 4 | name: "imagenet_sketch" 5 | category: "classification" 6 | train_batch_size0: 2 7 | val_batch_size0: 2 8 | eval_batch_size0: 2 9 | workers: 8 10 | persistent_workers: true 11 | pin_memory: true 12 | 13 | model: 14 | classification: 15 | n_classes: 1000 16 | 17 | image_augmentation: 18 | # training related parameters 19 | random_resized_crop: 20 | enable: true 21 | interpolation: "bilinear" 22 | random_horizontal_flip: 23 | enable: true 24 | auto_augment: 25 | enable: true 26 | cutmix: 27 | alpha: 1.0 28 | enable: true 29 | p: 1.0 30 | mixup: 31 | alpha: 0.2 32 | enable: true 33 | p: 1.0 34 | # validation related parameters 35 | resize: 36 | enable: true 37 | size: 232 38 | interpolation: "bilinear" 39 | center_crop: 40 | enable: true 41 | size: 224 42 | 43 | sampler: 44 | name: "variable_batch_sampler" 45 | vbs: 46 | crop_size_width: 224 47 | crop_size_height: 224 48 | max_n_scales: 5 49 | min_crop_size_width: 128 50 | max_crop_size_width: 320 51 | min_crop_size_height: 128 52 | max_crop_size_height: 320 53 | check_scale: 32 54 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/data/datasets/multi_modal_img_text/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/data/datasets/multi_modal_img_text/zero_shot/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot/dummy_imagenet_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | category: "multi_modal_image_text" 4 | multi_modal_img_text: 5 | context_length: 77 6 | zero_shot_eval: true 7 | zero_shot: 8 | name: "imagenet" 9 | root_val: "/mnt/imagenet/validation" 10 | 11 | image_augmentation: 12 | # training related parameters 13 | random_resized_crop: 14 | enable: true 15 | interpolation: "bilinear" 16 | random_horizontal_flip: 17 | enable: true 18 | auto_augment: 19 | enable: true 20 | cutmix: 21 | alpha: 1.0 22 | enable: true 23 | p: 1.0 24 | mixup: 25 | alpha: 0.2 26 | enable: true 27 | p: 1.0 28 | # validation related parameters 29 | resize: 30 | enable: true 31 | size: 232 32 | interpolation: "bilinear" 33 | center_crop: 34 | enable: true 35 | size: 224 36 | 37 | sampler: 38 | name: "variable_batch_sampler" 39 | vbs: 40 | crop_size_width: 224 41 | crop_size_height: 224 42 | max_n_scales: 5 43 | min_crop_size_width: 128 44 | max_crop_size_width: 320 45 | min_crop_size_height: 128 46 | max_crop_size_height: 320 47 | check_scale: 32 48 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot/mock_imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | import random 7 | from typing import Optional 8 | 9 | from PIL import Image 10 | 11 | from data.datasets.multi_modal_img_text.zero_shot.imagenet import ( 12 | ImageNetDatasetZeroShot, 13 | generate_text_prompts_clip, 14 | ) 15 | 16 | TOTAL_SAMPLES = 100 17 | 18 | 19 | class MockImageNetDatasetZeroShot(ImageNetDatasetZeroShot): 20 | """Mock the ImageNetDatasetZeroShot without initializing from image folders.""" 21 | 22 | def __init__( 23 | self, 24 | opts: argparse.Namespace, 25 | is_training: bool = False, 26 | is_evaluation: bool = False, 27 | *args, 28 | **kwargs 29 | ) -> None: 30 | """Mock the init logic for the ImageNet dataset. 31 | 32 | Specifically, we replace the samples and targets with random data so that the 33 | actual dataset is not required for testing purposes. 34 | """ 35 | # super() is not called here intentionally. 
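# Calling super().__init__() would build the sample list by scanning real image folders on disk; the attributes set below replicate that state with in-memory fixtures (the class ids 1-4 and the "cat"/"dog" prompt classes are arbitrary test values).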
36 | self.opts = opts 37 | self.root = None 38 | self.samples = [ 39 | ["img_path", random.randint(1, 4)] for _ in range(TOTAL_SAMPLES) 40 | ] 41 | self.text_prompts = [ 42 | generate_text_prompts_clip(class_name) for class_name in ["cat", "dog"] 43 | ] 44 | self.targets = [class_id for img_path, class_id in self.samples] 45 | self.imgs = [img_path for img_path, class_id in self.samples] 46 | self.is_training = is_training 47 | self.is_evaluation = is_evaluation 48 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot/test_mock_imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Type 7 | 8 | from tests.configs import get_config 9 | from tests.data.datasets.multi_modal_img_text.zero_shot.mock_imagenet import ( 10 | MockImageNetDatasetZeroShot, 11 | ) 12 | 13 | 14 | def test_imagenet_dataset_zero_shot( 15 | config_file_path: str = "tests/data/datasets/multi_modal_img_text/zero_shot/dummy_imagenet_config.yaml", 16 | mock_dataset_class: Type[MockImageNetDatasetZeroShot] = MockImageNetDatasetZeroShot, 17 | ) -> None: 18 | """Test for ImageNet zero-shot. 19 | 20 | Similar to the ImageNet test, but only for validation, because zero-shot datasets are 21 | not supposed to be used for training. We also test the text prompts. 22 | """ 23 | opts = get_config(config_file=config_file_path) 24 | 25 | imagenet_zero_shot_dataset = mock_dataset_class( 26 | opts, is_training=False, is_evaluation=False 27 | ) 28 | 29 | for image_id in range(2): 30 | data = imagenet_zero_shot_dataset[image_id] 31 | # values from config file 32 | assert len(data) == 3, "ImageNet zero shot should return a tuple of 3." 33 | img_path, text_prompts, target = data 34 | assert isinstance(img_path, str), "ImageNet zero shot should return (str, ...)." 35 | assert ( 36 | isinstance(text_prompts, list) 37 | and isinstance(text_prompts[0], list) 38 | and isinstance(text_prompts[0][0], str) 39 | ), "ImageNet zero shot should return (..., list[list[str]], ...)." 40 | assert isinstance(target, int), "ImageNet zero shot should return (..., int)." 41 | -------------------------------------------------------------------------------- /tests/data/datasets/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /tests/data/datasets/segmentation/dummy_ade20k_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_train: "/mnt/vision_datasets/ADEChallengeData2016/" 4 | root_val: "/mnt/vision_datasets/ADEChallengeData2016/" 5 | name: "ade20k" 6 | category: "segmentation" 7 | train_batch_size0: 4 8 | val_batch_size0: 4 9 | eval_batch_size0: 2 10 | workers: 4 11 | persistent_workers: false 12 | pin_memory: false 13 | image_augmentation: 14 | random_crop: 15 | enable: true 16 | seg_class_max_ratio: 0.75 17 | pad_if_needed: true 18 | mask_fill: 0 # background idx is 0 19 | random_horizontal_flip: 20 | enable: true 21 | resize: 22 | enable: true 23 | size: [512, 512] 24 | interpolation: "bilinear" 25 | random_short_size_resize: 26 | enable: true 27 | interpolation: "bilinear" 28 | short_side_min: 256 29 | short_side_max: 768 30 | max_img_dim: 1024 31 | photo_metric_distort: 32 | enable: true 33 | random_rotate: 34 | enable: true 35 | angle: 10 36 | mask_fill: 0 # background idx is 0 37 | random_gaussian_noise: 38 | enable: true 39 | sampler: 40 | name: "batch_sampler" 41 | bs: 42 | crop_size_width: 512 43 | crop_size_height: 512 44 | evaluation: 45 | segmentation: 46 | resize_input_images: false 47 | -------------------------------------------------------------------------------- /tests/data/datasets/segmentation/mock_ade20k.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | from typing import Optional 7 | 8 | from PIL import Image 9 | 10 | from data.datasets.segmentation.ade20k import ADE20KDataset 11 | 12 | TOTAL_SAMPLES = 100 13 | 14 | 15 | class MockADE20KDataset(ADE20KDataset): 16 | def __init__( 17 | self, 18 | opts: argparse.Namespace, 19 | is_training: bool = False, 20 | is_evaluation: bool = False, 21 | *args, 22 | **kwargs 23 | ) -> None: 24 | """Mock the init logic for the ADE20K dataset. 25 | 26 | Specifically, we replace the image and mask paths with dummy values so that the 27 | actual dataset is not required for testing purposes. 28 | """ 29 | # super() is not called here intentionally. 30 | self.opts = opts 31 | self.root = None 32 | self.images = ["dummy_img_path.jpg" for _ in range(TOTAL_SAMPLES)] 33 | self.masks = ["dummy_mask_path.png" for _ in range(TOTAL_SAMPLES)] 34 | self.ignore_label = 255 35 | self.background_idx = 0 36 | self.is_training = is_training 37 | self.is_evaluation = is_evaluation 38 | self.check_dataset() 39 | 40 | @staticmethod 41 | def read_image_pil(path: str) -> Optional[Image.Image]: 42 | """Mock the read_image_pil function. 43 | 44 | Instead of reading a PIL RGB image at the location specified by `path`, a PIL 45 | RGB image of size (20, 30) is returned. 46 | """ 47 | return Image.new("RGB", (20, 30)) 48 | 49 | @staticmethod 50 | def read_mask_pil(path: str) -> Optional[Image.Image]: 51 | """Mock the read_mask_pil function. 52 | 53 | Instead of reading a mask at the location specified by `path`, a PIL mask image of 54 | size (20, 30) is returned. 
55 | """ 56 | return Image.new("L", (20, 30)) 57 | -------------------------------------------------------------------------------- /tests/data/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/dummy_silent_video.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/data/dummy_silent_video.mov -------------------------------------------------------------------------------- /tests/data/dummy_video.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/data/dummy_video.mov -------------------------------------------------------------------------------- /tests/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/samplers/test_batch_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | sampler: 3 | name: "batch_sampler" 4 | num_repeats: 1 5 | truncated_repeat_aug_sampler: false 6 | bs: 7 | crop_size_width: 224 8 | crop_size_height: 224 9 | -------------------------------------------------------------------------------- /tests/data/samplers/test_chain_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | # dummy configuration for testing the chain sampler 3 | # We override different options inside the test to study different cases of the chain sampler 4 | sampler: 5 | name: "chain_sampler" 6 | chain_sampler_mode: "sequential" 7 | chain_sampler: 8 | - task_name: "task_1" 9 | train_batch_size0: 128 10 | val_batch_size0: 100 11 | sampler_config: 12 | name: "variable_batch_sampler" 13 | vbs: 14 | crop_size_width: 224 15 | crop_size_height: 224 16 | max_n_scales: 25 17 | min_crop_size_width: 128 18 | max_crop_size_width: 320 19 | min_crop_size_height: 128 20 | max_crop_size_height: 320 21 | check_scale: 16 22 | -------------------------------------------------------------------------------- /tests/data/samplers/test_multi_scale_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | sampler: 3 | name: "multi_scale_sampler" 4 | num_repeats: 1 5 | truncated_repeat_aug_sampler: false 6 | msc: 7 | crop_size_width: 224 8 | crop_size_height: 224 9 | max_n_scales: 5 10 | min_crop_size_width: 128 11 | max_crop_size_width: 320 12 | min_crop_size_height: 128 13 | max_crop_size_height: 320 14 | check_scale: 32 15 | -------------------------------------------------------------------------------- /tests/data/samplers/test_variable_batch_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | sampler: 3 | name: "variable_batch_sampler" 4 | vbs: 5 | crop_size_width: 224 6 | crop_size_height: 224 7 | max_n_scales: 5 8 | min_crop_size_width: 128 9 
| max_crop_size_width: 320 10 | min_crop_size_height: 128 11 | max_crop_size_height: 320 12 | check_scale: 32 13 | -------------------------------------------------------------------------------- /tests/dummy_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | 7 | def train_val_datasets(opts): 8 | dataset_category = getattr(opts, "dataset.category", None) 9 | dataset_name = getattr(opts, "dataset.name", None) 10 | 11 | assert dataset_category is not None 12 | assert dataset_name is not None 13 | 14 | # we may not have access to the dataset, so for CI/CD, we only compute loss 15 | setattr(opts, "stats.train", "loss") 16 | # relaxing val statistics to test different metrics 17 | # setattr(opts, "stats.val", "loss") 18 | setattr(opts, "stats.checkpoint_metric", "loss") 19 | setattr(opts, "stats.checkpoint_metric_max", False) 20 | 21 | if dataset_category == "classification": 22 | # image classification 23 | from tests.dummy_datasets.classification import ( 24 | DummyClassificationDataset as dataset_cls, 25 | ) 26 | elif dataset_category == "detection" and dataset_name.find("ssd") > -1: 27 | # Object detection using SSD 28 | from tests.dummy_datasets.ssd_detection import ( 29 | DummySSDDetectionDataset as dataset_cls, 30 | ) 31 | elif dataset_category == "segmentation": 32 | from tests.dummy_datasets.segmentation import ( 33 | DummySegmentationDataset as dataset_cls, 34 | ) 35 | elif dataset_category == "video_classification": 36 | from tests.dummy_datasets.video_classification import ( 37 | DummyVideoClassificationDataset as dataset_cls, 38 | ) 39 | elif dataset_category == "multi_modal_image_text": 40 | from tests.dummy_datasets.multi_modal_img_text import ( 41 | DummyMultiModalImageTextDataset as dataset_cls, 42 | ) 43 | else: 44 | raise NotImplementedError( 45 | "Dummy datasets for {} not yet implemented".format(dataset_category) 46 | ) 47 | 48 | train_dataset = dataset_cls(opts) 49 | valid_dataset = dataset_cls(opts) 50 | 51 | return train_dataset, valid_dataset 52 | -------------------------------------------------------------------------------- /tests/dummy_datasets/classification.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Dict, Tuple 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | 12 | class DummyClassificationDataset(data.Dataset): 13 | """ 14 | Dummy Classification Dataset for CI/CD testing 15 | 16 | Args: 17 | opts: command-line arguments 18 | 19 | """ 20 | 21 | def __init__(self, opts, *args, **kwargs) -> None: 22 | super().__init__() 23 | 24 | self.n_classes = 1000 25 | setattr(opts, "model.classification.n_classes", self.n_classes) 26 | setattr( 27 | opts, 28 | "dataset.collate_fn_name_train", 29 | "image_classification_data_collate_fn", 30 | ) 31 | setattr( 32 | opts, "dataset.collate_fn_name_val", "image_classification_data_collate_fn" 33 | ) 34 | setattr( 35 | opts, "dataset.collate_fn_name_test", "image_classification_data_collate_fn" 36 | ) 37 | 38 | def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: 39 | """ 40 | :param batch_indexes_tup: Tuple of the form (Crop_size_H, Crop_size_W, Image_ID) 41 | :return: dictionary containing input image, label, and sample_id. 
42 | """ 43 | crop_size_h, crop_size_w, img_index = batch_indexes_tup 44 | 45 | input_img = torch.randn(size=(3, crop_size_h, crop_size_w), dtype=torch.float) 46 | target = torch.randint(low=0, high=self.n_classes, size=(1,)).long() 47 | 48 | return { 49 | "samples": input_img, 50 | "targets": target, 51 | "sample_id": torch.randint(low=0, high=1000, size=(1,)).long(), 52 | } 53 | 54 | def __len__(self) -> int: 55 | return 10 56 | -------------------------------------------------------------------------------- /tests/dummy_datasets/multi_modal_img_text.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Dict, Tuple 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | 12 | class DummyMultiModalImageTextDataset(data.Dataset): 13 | """ 14 | Dummy Dataset for CI/CD testing 15 | 16 | Args: 17 | opts: command-line arguments 18 | 19 | """ 20 | 21 | def __init__(self, opts, *args, **kwargs) -> None: 22 | super().__init__() 23 | 24 | self.context_length = 5 25 | self.vocab_size = 100 26 | setattr(opts, "dataset.text_vocab_size", self.vocab_size) 27 | setattr(opts, "dataset.text_context_length", self.context_length) 28 | 29 | setattr( 30 | opts, "dataset.collate_fn_name_train", "multi_modal_img_text_collate_fn" 31 | ) 32 | setattr(opts, "dataset.collate_fn_name_val", "multi_modal_img_text_collate_fn") 33 | setattr(opts, "dataset.collate_fn_name_test", "multi_modal_img_text_collate_fn") 34 | 35 | def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: 36 | """ 37 | :param batch_indexes_tup: Tuple of the form (Crop_size_H, Crop_size_W, Image_ID) 38 | :return: dictionary containing the input image, text tokens, padding mask, and target label. 39 | """ 40 | crop_size_h, crop_size_w, img_index = batch_indexes_tup 41 | 42 | input_img = torch.randn(size=(3, crop_size_h, crop_size_w), dtype=torch.float) 43 | text = torch.randint( 44 | low=0, high=self.vocab_size, size=(self.context_length,), dtype=torch.int 45 | ) 46 | 47 | return { 48 | "samples": {"image": input_img, "text": text, "padding_mask": None}, 49 | "targets": torch.randint(low=0, high=1, size=(1,)).long(), 50 | } 51 | 52 | def __len__(self) -> int: 53 | return 10 54 | -------------------------------------------------------------------------------- /tests/dummy_datasets/segmentation.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Dict, Tuple 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | 12 | class DummySegmentationDataset(data.Dataset): 13 | """ 14 | Dummy Segmentation Dataset for CI/CD testing 15 | 16 | Args: 17 | opts: command-line arguments 18 | 19 | """ 20 | 21 | def __init__(self, opts, *args, **kwargs) -> None: 22 | super().__init__() 23 | 24 | self.n_classes = 20 25 | setattr(opts, "model.segmentation.n_classes", self.n_classes) 26 | setattr(opts, "dataset.collate_fn_name_train", "default_collate_fn") 27 | setattr(opts, "dataset.collate_fn_name_val", "default_collate_fn") 28 | setattr(opts, "dataset.collate_fn_name_test", None) 29 | 30 | def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: 31 | """ 32 | :param batch_indexes_tup: Tuple of the form (Crop_size_H, Crop_size_W, Image_ID) 33 | :return: dictionary containing the input image and the target segmentation mask. 
34 | """ 35 | crop_size_h, crop_size_w, img_index = batch_indexes_tup 36 | 37 | input_img = torch.randn(size=(3, crop_size_h, crop_size_w), dtype=torch.float) 38 | target = torch.randint( 39 | low=0, high=self.n_classes, size=(crop_size_h, crop_size_w) 40 | ).long() 41 | 42 | return {"samples": input_img, "targets": target} 43 | 44 | def __len__(self) -> int: 45 | return 10 46 | -------------------------------------------------------------------------------- /tests/loss_fns/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /tests/loss_fns/test_class_weighting.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | 6 | from loss_fn.utils.class_weighting import compute_class_weights 7 | 8 | 9 | def test_class_weighting(): 10 | # Test the class weighting method. 11 | targets = torch.tensor([1, 1, 1, 2, 2, 3], dtype=torch.long) 12 | n_classes = 4 13 | norm_val = 1.0 14 | 15 | weights = compute_class_weights( 16 | target=targets, n_classes=n_classes, norm_val=norm_val 17 | ) 18 | weights = torch.round(weights, decimals=2) 19 | 20 | expected_weights = torch.tensor([0.0, 2.47, 3.48, 6.49]) 21 | 22 | torch.testing.assert_allclose(actual=weights, expected=expected_weights) 23 | -------------------------------------------------------------------------------- /tests/loss_fns/test_contrastive_loss.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | import pytest 7 | import torch 8 | 9 | from loss_fn.multi_modal_img_text.contrastive_loss_clip import ContrastiveLossClip 10 | 11 | 12 | @pytest.mark.parametrize("batch_size", [1, 2]) 13 | @pytest.mark.parametrize("projection_dim", [256, 512]) 14 | def test_contrastive_loss_in_out(batch_size: int, projection_dim: int) -> None: 15 | # This test checks that the input and output formats are correct. 
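# Note: ContrastiveLossClip scores image embeddings against text embeddings (CLIP-style), so the ground-truth pairing is implicit in the batch order; this is why input_sample and targets can be None below.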
16 | parser = argparse.ArgumentParser() 17 | parser = ContrastiveLossClip.add_arguments(parser) 18 | 19 | opts = parser.parse_args([]) 20 | criteria = ContrastiveLossClip(opts) 21 | 22 | image_features = torch.randn(size=(batch_size, projection_dim)) 23 | text_features = torch.randn(size=(batch_size, projection_dim)) 24 | 25 | input_sample = None 26 | targets = None 27 | 28 | prediction = {"image": image_features, "text": text_features} 29 | 30 | loss_output = criteria(input_sample, prediction, targets) 31 | expected_output_keys = {"total_loss", "image_loss", "text_loss", "logit_scale"} 32 | assert expected_output_keys.issubset(loss_output.keys()) 33 | 34 | for loss_name, loss_val in loss_output.items(): 35 | if loss_name == "logit_scale" and isinstance(loss_val, (float, int)): 36 | loss_val = torch.tensor(loss_val) 37 | assert isinstance( 38 | loss_val, torch.Tensor 39 | ), "Loss should be an instance of torch.Tensor" 40 | assert loss_val.dim() == 0, "Loss value should be a scalar" 41 | -------------------------------------------------------------------------------- /tests/loss_fns/test_neural_aug.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | import pytest 7 | import torch 8 | 9 | from loss_fn.neural_augmentation import NeuralAugmentation 10 | 11 | 12 | @pytest.mark.parametrize("batch_size", [1, 2]) 13 | def test_neural_aug_loss_in_out(batch_size: int) -> None: 14 | # This test checks that the input and output formats are correct. 15 | # build configuration 16 | parser = argparse.ArgumentParser() 17 | parser = NeuralAugmentation.add_arguments(parser) 18 | 19 | opts = parser.parse_args([]) 20 | setattr(opts, "scheduler.max_epochs", 20) 21 | 22 | # build loss function 23 | neural_aug_loss_fn = NeuralAugmentation(opts) 24 | pred_tensor = { 25 | "augmented_tensor": torch.zeros( 26 | size=(batch_size, 3, 224, 224), dtype=torch.float 27 | ) 28 | } 29 | 30 | # Three input cases: 31 | # Case 1: Input image is a tensor 32 | # Case 2: Input is a dictionary, with image as a mandatory key and value as a batch of input image tensors 33 | # Case 3: Input is a dictionary, with image as a mandatory key and value as a list of input image tensors 34 | input_case_1 = torch.randint(low=0, high=1, size=(batch_size, 3, 224, 224)) 35 | input_case_2 = { 36 | "image": torch.randint(low=0, high=1, size=(batch_size, 3, 224, 224)) 37 | } 38 | input_case_3 = { 39 | "image": [torch.randint(low=0, high=1, size=(1, 3, 224, 224))] * batch_size 40 | } 41 | 42 | for inp in [input_case_1, input_case_2, input_case_3]: 43 | loss = neural_aug_loss_fn(inp, pred_tensor) 44 | assert isinstance( 45 | loss, torch.Tensor 46 | ), "Loss should be an instance of torch.Tensor" 47 | assert loss.dim() == 0, "Loss value should be a scalar" 48 | -------------------------------------------------------------------------------- /tests/loss_fns/test_neural_aug_compatibility.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
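# This test walks every YAML config under config/ and examples/ and fails if any option value still references a deprecated `_with_na` loss function.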
3 | 4 | import re 5 | import sys 6 | from pathlib import Path 7 | 8 | sys.path.append("..") 9 | 10 | from tests.configs import get_config 11 | from tests.test_model import exclude_yaml_from_test 12 | 13 | 14 | def test_neural_aug_backward_compatibility(config_file: str): 15 | opts = get_config(config_file=config_file) 16 | 17 | opts_dict = vars(opts) 18 | for k, v in opts_dict.items(): 19 | if isinstance(v, str) and re.search(".*_with_na$", v): 20 | raise DeprecationWarning( 21 | "We deprecated the usage of _with_na loss functions. " 22 | "Please see examples/range_augment for examples." 23 | ) 24 | 25 | 26 | def pytest_generate_tests(metafunc): 27 | configs = [ 28 | str(x) 29 | for x in Path("config").rglob("**/*.yaml") 30 | if not exclude_yaml_from_test(x) 31 | ] 32 | configs += [ 33 | str(x) 34 | for x in Path("examples").rglob("**/*.yaml") 35 | if not exclude_yaml_from_test(x) 36 | ] 37 | metafunc.parametrize("config_file", configs) 38 | -------------------------------------------------------------------------------- /tests/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/metrics/test_image_text_retrieval_metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import pytest 7 | import torch 8 | 9 | from metrics.stats import Statistics 10 | from tests.configs import default_training_opts 11 | 12 | 13 | @pytest.mark.parametrize("batch_size", (2, 4)) 14 | @pytest.mark.parametrize("num_captions", (1, 5)) 15 | @pytest.mark.parametrize("hidden_dim", (8,)) 16 | @pytest.mark.parametrize("text_dim", (2, 3)) 17 | def test_image_text_retrieval( 18 | batch_size: int, num_captions: int, hidden_dim: int, text_dim: int 19 | ) -> None: 20 | stats = Statistics( 21 | opts=default_training_opts(), metric_names=["image_text_retrieval"] 22 | ) 23 | for _ in range(3): 24 | image_emb = torch.randn(batch_size, hidden_dim) 25 | text_emb = torch.randn(batch_size, num_captions, hidden_dim) 26 | if text_dim == 2: 27 | text_emb = text_emb.reshape(-1, hidden_dim) 28 | stats.update({"image": image_emb, "text": text_emb}, {}, {}) 29 | 30 | metrics = stats._compute_avg_statistics_all() 31 | img_text_metrics = metrics["image_text_retrieval"] 32 | 33 | parent_keys = ["text2image", "image2text"] 34 | child_keys = ["recall@1", "recall@5", "recall@10", "mean_rank", "median_rank"] 35 | for parent_key in parent_keys: 36 | assert parent_key in img_text_metrics 37 | for child_key in child_keys: 38 | assert child_key in img_text_metrics[parent_key] 39 | -------------------------------------------------------------------------------- /tests/metrics/test_iou.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from typing import Callable 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from metrics.stats import Statistics 12 | from tests.metrics.base import transform_args 13 | 14 | 15 | def test_gather_iou_metrics(transform_args: Callable): 16 | # [Batch, num_classes, height, width] 17 | # in this example, [1, 2, 2, 3] 18 | prediction = torch.tensor( 19 | [ 20 | [ 21 | [[0.2, 0.8, 0.2], [0.9, 0.2, 0.1]], 22 | [[0.8, 0.2, 0.8], [0.1, 0.8, 0.9]], # spatial dims 23 | ] # classes 24 | ] # batch 25 | ) 26 | 27 | target = torch.tensor([[[0, 0, 0], [0, 1, 1]]]) 28 | 29 | metric_names, stats_args = transform_args(["iou"], prediction, target) 30 | 31 | expected_inter = np.array([2.0, 2.0]) 32 | expected_union = np.array([4.0, 4.0]) 33 | 34 | expected_iou = np.mean(expected_inter / expected_union) * 100 35 | 36 | stats = Statistics(opts=None, metric_names=metric_names) 37 | stats.update(*stats_args) 38 | 39 | np.testing.assert_equal( 40 | actual=stats.avg_statistics(metric_names[0]), desired=expected_iou 41 | ) 42 | -------------------------------------------------------------------------------- /tests/metrics/test_probability_histogram.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Callable 7 | 8 | import numpy as np 9 | 10 | from metrics.stats import Statistics 11 | from tests.configs import default_training_opts 12 | from tests.metrics.base import sample_classification_outputs, transform_args 13 | 14 | 15 | def test_probability_histogram(transform_args: Callable): 16 | metric_names, stats_args = transform_args( 17 | ["prob_hist"], *sample_classification_outputs() 18 | ) 19 | 20 | stats = Statistics(opts=default_training_opts(), metric_names=metric_names) 21 | stats.update(*stats_args) 22 | 23 | # max values -> 0.91, 0.81, 0.51 24 | max_conf_hist = stats.avg_statistics(metric_names[0], "max") 25 | np.testing.assert_almost_equal( 26 | max_conf_hist, 27 | [0, 0, 0, 0, 0, 0.33, 0, 0, 0.33, 0.33], 28 | decimal=2, 29 | ) 30 | 31 | # target values -> 0.05, 0.16, 0.51 32 | target_conf_hist = stats.avg_statistics(metric_names[0], "target") 33 | np.testing.assert_almost_equal( 34 | target_conf_hist, 35 | [0.33, 0.33, 0, 0, 0, 0.33, 0, 0, 0, 0], 36 | decimal=2, 37 | ) 38 | -------------------------------------------------------------------------------- /tests/metrics/test_psnr.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import math 7 | from typing import Callable 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from metrics.stats import Statistics 13 | from tests.metrics.base import transform_args 14 | 15 | 16 | def test_gather_psnr_metrics(transform_args: Callable): 17 | # Test the case where the input and target tensors are identical. 18 | inp_tensor = torch.randn((3, 2), dtype=torch.float) 19 | target_tensor = inp_tensor 20 | 21 | # Ideally, the PSNR should be infinite when input and target are the same, because the error between 22 | # signal and noise is 0. However, we add a small eps value (error of 1e-10) in the computation 23 | # for numerical stability. Therefore, PSNR will not be infinite. 
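# Sanity check on the expected value, assuming the standard formula PSNR = 10 * log10(MAX**2 / MSE) with MAX = 255: the eps takes the place of the MSE here, so 10 * math.log10(255.0**2 / 1e-10) evaluates to roughly 148.13 dB.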
24 | expected_psnr = 10.0 * math.log10(255.0**2 / 1e-10) 25 | 26 | metric_names, stats_args = transform_args(["psnr"], inp_tensor, target_tensor) 27 | 28 | stats = Statistics(opts=None, metric_names=metric_names) 29 | stats.update(*stats_args) 30 | 31 | np.testing.assert_almost_equal( 32 | stats.avg_statistics(metric_names[0]), expected_psnr, decimal=2 33 | ) 34 | -------------------------------------------------------------------------------- /tests/metrics/test_topk_accuracy.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Callable 7 | 8 | import numpy as np 9 | 10 | from metrics.stats import Statistics 11 | from tests.metrics.base import sample_classification_outputs, transform_args 12 | 13 | 14 | def test_gather_top_k_metrics(transform_args: Callable): 15 | metric_names, stats_args = transform_args( 16 | ["top1", "top5"], *sample_classification_outputs() 17 | ) 18 | 19 | stats = Statistics(opts=None, metric_names=metric_names) 20 | stats.update(*stats_args) 21 | top1_acc = round(stats.avg_statistics(metric_names[0]), 2) 22 | top5_acc = round(stats.avg_statistics(metric_names[1]), 2) 23 | 24 | np.testing.assert_almost_equal(top1_acc, 33.33, decimal=2) 25 | np.testing.assert_almost_equal(top5_acc, 100.00, decimal=2) 26 | -------------------------------------------------------------------------------- /tests/misc/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /tests/models/audio_classification/test_base_audio_classification.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from cvnets.models.audio_classification import base_audio_classification 9 | 10 | 11 | def test_base_audio_classification_adds_arguments() -> None: 12 | opts = argparse.Namespace() 13 | model = base_audio_classification.BaseAudioClassification(opts) 14 | 15 | parser = argparse.ArgumentParser() 16 | model.add_arguments(parser) 17 | assert hasattr(parser.parse_args([]), "model.audio_classification.name") 18 | -------------------------------------------------------------------------------- /tests/models/audio_classification/test_byteformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | 8 | from cvnets.models.audio_classification import audio_byteformer 9 | from cvnets.models.classification import byteformer as image_byteformer 10 | from tests.models.classification import test_byteformer 11 | 12 | 13 | def test_audio_byteformer() -> None: 14 | # Make sure it matches the image classification network. 
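# AudioByteFormer is expected to differ from ByteFormer only in how it unpacks its input (a {"audio": tensor} dict), so with identical weights the two should produce element-wise identical outputs.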
15 | opts = test_byteformer.get_opts() 16 | 17 | byteformer1 = image_byteformer.ByteFormer(opts) 18 | byteformer2 = audio_byteformer.AudioByteFormer(opts) 19 | 20 | # Make their state_dicts match. 21 | byteformer2.load_state_dict(byteformer1.state_dict()) 22 | 23 | batch_size, sequence_length = 2, 32 24 | 25 | x = torch.randint(0, 128, [batch_size, sequence_length]) 26 | 27 | assert torch.all(byteformer1(x) == byteformer2({"audio": x})) 28 | -------------------------------------------------------------------------------- /tests/models/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/models/classification/__init__.py -------------------------------------------------------------------------------- /tests/models/classification/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/models/classification/config/__init__.py -------------------------------------------------------------------------------- /tests/models/classification/config/test_byteformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | import pytest 9 | 10 | from cvnets.models.classification.config import byteformer 11 | 12 | 13 | @pytest.mark.parametrize("mode", ["tiny", "small", "base", "huge"]) 14 | def test_get_configuration(mode) -> None: 15 | opts = argparse.Namespace() 16 | setattr(opts, "model.classification.byteformer.mode", mode) 17 | setattr(opts, "model.classification.byteformer.dropout", 0.0) 18 | setattr(opts, "model.classification.byteformer.norm_layer", "layer_norm") 19 | byteformer.get_configuration(opts) 20 | -------------------------------------------------------------------------------- /tests/models/test_neural_aug_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import sys 7 | 8 | import pytest 9 | 10 | sys.path.append("../..") 11 | 12 | from cvnets.neural_augmentor.utils.neural_aug_utils import * 13 | 14 | 15 | @pytest.mark.parametrize("noise_var", [0.0001, 0.01, 0.1]) 16 | def test_random_noise(noise_var): 17 | in_channels = 3 18 | in_height = 224 19 | in_width = 224 20 | x = torch.ones(size=(1, in_channels, in_width, in_height), dtype=torch.float) 21 | 22 | aug_out = random_noise(x, variance=torch.tensor(noise_var, dtype=torch.float)) 23 | 24 | torch.testing.assert_allclose(actual=x.shape, expected=aug_out.shape) 25 | 26 | 27 | @pytest.mark.parametrize("magnitude", [0.1, 1.0, 2.0]) 28 | def test_random_brightness(magnitude): 29 | in_channels = 3 30 | in_height = 224 31 | in_width = 224 32 | x = torch.ones(size=(1, in_channels, in_width, in_height), dtype=torch.float) 33 | 34 | aug_out = random_brightness(x, magnitude=torch.tensor(magnitude, dtype=torch.float)) 35 | 36 | torch.testing.assert_allclose(actual=x.shape, expected=aug_out.shape) 37 | 38 | 39 | @pytest.mark.parametrize("magnitude", [0.1, 1.0, 2.0]) 40 | def test_random_contrast(magnitude): 41 | in_channels = 3 42 | in_height = 224 43 | in_width = 224 44 | x = torch.ones(size=(1, in_channels, in_width, in_height), dtype=torch.float) 45 | 46 | aug_out = random_contrast(x, magnitude=torch.tensor(magnitude, dtype=torch.float)) 47 | 48 | torch.testing.assert_allclose(actual=x.shape, expected=aug_out.shape) 49 | -------------------------------------------------------------------------------- /tests/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/modules/__init__.py -------------------------------------------------------------------------------- /tests/modules/test_transformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | import torch 9 | 10 | from cvnets.modules import TransformerEncoder 11 | 12 | 13 | def get_opts() -> argparse.Namespace: 14 | opts = argparse.Namespace() 15 | setattr(opts, "model.normalization.groups", None) 16 | setattr(opts, "model.normalization.momentum", 0.9) 17 | setattr(opts, "model.activation.name", "relu") 18 | setattr(opts, "model.activation.inplace", False) 19 | setattr(opts, "model.activation.neg_slope", False) 20 | return opts 21 | 22 | 23 | def ensure_equal_in_range(t: torch.Tensor, start: int, end: int) -> None: 24 | """ 25 | Ensure values of @t are equal from @start to @end, but not after @end. 26 | 27 | The tensor can have any number of dimensions greater than 0. The first 28 | dimension is the dimension indexed by @start and @end. 29 | 30 | Args: 31 | t: The tensor to check. 32 | start: The start index. 33 | end: The end index. 34 | """ 35 | prototype = t[start] 36 | assert torch.all((prototype - t[start:end]).abs() < 1e-3) 37 | assert torch.all(prototype != t[end:]) 38 | 39 | 40 | def test_masked_attention() -> None: 41 | opts = get_opts() 42 | 43 | B, N, C = 2, 64 + 2, 8 44 | t = TransformerEncoder(opts, embed_dim=C, ffn_latent_dim=4 * C) 45 | prototype = torch.randn([C]) 46 | x = torch.ones([B, N, C]) 47 | x[:, :] = prototype 48 | 49 | key_padding_mask = torch.zeros([B, N]) 50 | key_padding_mask[0, 63:] = float("-inf") 51 | # Mask the @x values at the masked positions. 
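# The key_padding_mask is additive: -inf entries disable attention to positions 63+ of the first batch element. Zeroing those positions in @x as well means every unmasked token sees an identical context, so the unmasked outputs should collapse to a single prototype value.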
52 | x[0, 63:] = 0 53 | 54 | y = t(x, key_padding_mask=key_padding_mask) 55 | 56 | prototype = y[0, 0] 57 | assert torch.all(prototype == y[0, :63]) 58 | assert torch.all(prototype != y[0, 63:]) 59 | 60 | prototype = y[1, 0] 61 | assert torch.all(prototype == y[1, :]) 62 | -------------------------------------------------------------------------------- /tests/options/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/options/test_parse_args.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | from typing import Any, Dict, List, Tuple, Union 7 | 8 | import pytest 9 | 10 | from options.parse_args import JsonValidator 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "expected_type,valid,valid_parsed,invalid", 15 | [ 16 | (None, "null", None, "1"), 17 | (int, "1", 1, "1.0"), 18 | (float, "1.0", 1.0, '"1"'), 19 | (float, "1", 1.0, "s"), 20 | (bool, "true", True, "null"), 21 | (List[int], "[1, 2,3]", [1, 2, 3], "{1: 2}"), 22 | (List[int], "[]", [], '["s"]'), 23 | (Tuple[int, int], "[1, 2]", (1, 2), "[1, 2, 3]"), 24 | (Dict[str, Tuple[int, float]], '{"x": [1, 2]}', {"x": (1, 2.0)}, '{"x": "y"}'), 25 | (Union[Tuple[int, Any], int], "[1,null]", (1, None), "[null,1]"), 26 | (Union[Tuple[int, int], int], "1", 1, '"1"'), 27 | ], 28 | ) 29 | def test_json_validator( 30 | expected_type: type, valid: str, valid_parsed: Any, invalid: str 31 | ): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--x", type=JsonValidator(expected_type)) 34 | 35 | class ArgparseFailure(Exception): 36 | pass 37 | 38 | def _exit(status, message): 39 | raise ArgparseFailure(f"Unexpected argparse failure: {message}") 40 | 41 | parser.exit = ( 42 | _exit # override exit to raise exception, rather than invoking sys.exit() 43 | ) 44 | 45 | opts = parser.parse_args([f"--x={valid}"]) 46 | assert opts.x == valid_parsed 47 | assert repr(opts.x) == repr(valid_parsed) # check types 48 | 49 | with pytest.raises(ArgparseFailure): 50 | parser.parse_args([f"--x={invalid}"]) 51 | -------------------------------------------------------------------------------- /tests/options/test_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from tests.configs import get_config 10 | 11 | 12 | def test_load_config_file_produces_no_false_warnings() -> None: 13 | get_config() 14 | 15 | 16 | def test_load_config_file_produces_true_warning( 17 | tmp_path: Path, 18 | ) -> None: 19 | config_path = tmp_path.joinpath("config.yaml") 20 | config_path.write_text("an_invalid_key: 2") 21 | with pytest.raises(ValueError, match="an_invalid_key"): 22 | get_config(str(config_path)) 23 | -------------------------------------------------------------------------------- /tests/test_image_pil.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | 8 | import pytest 9 | import torch 10 | 11 | from data.transforms import image_pil 12 | 13 | 14 | def test_to_tensor() -> None: 15 | parser = argparse.ArgumentParser() 16 | parser = image_pil.ToTensor.add_arguments(parser) 17 | opts = parser.parse_args([]) 18 | 19 | to_tensor = image_pil.ToTensor(opts=opts) 20 | 21 | H, W, C = 2, 2, 3 22 | num_masks = 2 23 | data = { 24 | "image": torch.rand([H, W, C]), 25 | "mask": torch.randint(0, 1, [num_masks, H, W]), 26 | } 27 | 28 | output = to_tensor(data) 29 | 30 | assert output["image"].shape == (H, W, C) 31 | assert output["mask"].shape == (num_masks, H, W) 32 | 33 | 34 | def test_to_tensor_bad_mask() -> None: 35 | parser = argparse.ArgumentParser() 36 | parser = image_pil.ToTensor.add_arguments(parser) 37 | opts = parser.parse_args([]) 38 | 39 | to_tensor = image_pil.ToTensor(opts=opts) 40 | 41 | H, W, C = 2, 2, 3 42 | num_categories = 2 43 | data = { 44 | "image": torch.rand([H, W, C]), 45 | "mask": torch.randint(0, 1, [num_categories, 1, H, W]), 46 | } 47 | 48 | with pytest.raises(SystemExit): 49 | to_tensor(data) 50 | -------------------------------------------------------------------------------- /tests/test_pos_embeddings.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import sys 7 | 8 | import numpy as np 9 | import pytest 10 | 11 | sys.path.append("..") 12 | 13 | from cvnets.layers.positional_embedding import PositionalEmbedding 14 | 15 | 16 | @pytest.mark.parametrize("is_learnable", [True, False]) 17 | @pytest.mark.parametrize("input_seq_len", [34, 128, 192]) 18 | @pytest.mark.parametrize("sequence_first", [True, False]) 19 | @pytest.mark.parametrize("padding_idx", [None, 0]) 20 | def test_pos_embedding( 21 | is_learnable: bool, input_seq_len: int, sequence_first: bool, padding_idx: int 22 | ): 23 | num_embeddings = 128 24 | pos_embedding = PositionalEmbedding( 25 | opts=None, 26 | num_embeddings=num_embeddings, 27 | embedding_dim=512, 28 | padding_idx=padding_idx, 29 | is_learnable=is_learnable, 30 | sequence_first=sequence_first, 31 | ) 32 | seq_dim = 0 if sequence_first else 1 33 | 34 | out = pos_embedding(input_seq_len) 35 | np.testing.assert_equal(out.shape[seq_dim], input_seq_len) 36 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import sys 7 | from argparse import Namespace 8 | 9 | import torch 10 | 11 | sys.path.append("..") 12 | 13 | from data.text_tokenizer.clip_tokenizer import ClipTokenizer 14 | 15 | 16 | def test_clip_tokenizer(): 17 | opts = Namespace() 18 | 19 | setattr( 20 | opts, 21 | "text_tokenizer.clip.merges_path", 22 | "http://download.pytorch.org/models/text/clip_merges.bpe", 23 | ) 24 | setattr( 25 | opts, 26 | "text_tokenizer.clip.encoder_json_path", 27 | "http://download.pytorch.org/models/text/clip_encoder.json", 28 | ) 29 | 30 | tokenizer = ClipTokenizer(opts=opts) 31 | out = tokenizer("the quick brown fox jumped over the lazy dog") 32 | 33 | expected_data = [ 34 | 49406, # Start token 35 | 518, 36 | 3712, 37 | 2866, 38 | 3240, 39 | 16901, 40 | 962, 41 | 518, 42 | 10753, 43 | 1929, 44 | 49407, # end token 45 | ] 46 | expected_out = torch.tensor(expected_data, dtype=out.dtype) 47 | torch.testing.assert_close(actual=out, expected=expected_out) 48 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | import re 7 | from typing import Tuple 8 | 9 | import pytest 10 | 11 | 12 | def unset_pretrained_models_from_opts(opts: argparse.Namespace) -> None: 13 | """Unset the argument corresponding to pretrained model path in opts during tests""" 14 | opts_as_dict = vars(opts) 15 | for k, v in opts_as_dict.items(): 16 | if is_pretrained_model_key(k): 17 | setattr(opts, k, None) 18 | 19 | 20 | def is_pretrained_model_key(key_name: str) -> bool: 21 | """Check if arguments corresponding to model have a pretrained key or not.""" 22 | return True if re.search(r".*model\..*\.pretrained$", key_name) else False 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "key_name_expected_output", 27 | [ 28 | ("model.classification.pretrained", True), 29 | ("model.segmentation.pretrained", True), 30 | ("model.video_classification.pretrained", True), 31 | ("teacher.model.classification.pretrained", True), 32 | ("loss.classification.pretrained", False), 33 | ("model.classification.pretrained_dummy", False), 34 | ("model.classification.mypretrained", False), 35 | ("model.classification.my.pretrained", True), 36 | ], 37 | ) 38 | def test_is_pretrained_model_key(key_name_expected_output: Tuple[str, bool]): 39 | key_name = key_name_expected_output[0] 40 | expected_output = key_name_expected_output[1] 41 | assert is_pretrained_model_key(key_name) == expected_output 42 | -------------------------------------------------------------------------------- /tests/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/tests/transforms/__init__.py -------------------------------------------------------------------------------- /tests/transforms/test_audio_bytes.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | 8 | import pytest 9 | import torch 10 | 11 | from data.transforms import audio_bytes 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "format,encoding_dtype,num_samples,expected_length", 16 | [ 17 | ("wav", "float32", 4, 74), 18 | ("wav", "float32", 8, 90), 19 | ("wav", "int32", 8, 112), 20 | ("wav", "int16", 8, 60), 21 | ("wav", "uint8", 8, 52), 22 | ("mp3", None, 8, 216), 23 | ], 24 | ) 25 | def test_audio_save(format, encoding_dtype, num_samples, expected_length) -> None: 26 | opts = argparse.Namespace() 27 | setattr(opts, "audio_augmentation.torchaudio_save.encoding_dtype", encoding_dtype) 28 | setattr(opts, "audio_augmentation.torchaudio_save.format", format) 29 | t = audio_bytes.TorchaudioSave(opts) 30 | 31 | x = { 32 | "samples": {"audio": torch.randn([2, num_samples])}, 33 | "metadata": {"audio_fps": 16}, 34 | } 35 | 36 | outputs = t(x)["samples"]["audio"] 37 | assert torch.all(0 <= outputs) 38 | assert torch.all(outputs <= 255) 39 | assert outputs.shape == (expected_length,) 40 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/utils/test_common_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | from utils.common_utils import unwrap_model_fn 11 | 12 | 13 | def check_models( 14 | original_unwrapped_model: torch.nn.Module, model_after_unwrapping: torch.nn.Module 15 | ) -> None: 16 | """Helper function to test original and unwrapped models are the same.""" 17 | for layer_id in range(len(original_unwrapped_model)): 18 | # for unwrapped models, we should be able to index them 19 | assert repr(model_after_unwrapping[layer_id]) == repr( 20 | original_unwrapped_model[layer_id] 21 | ) 22 | 23 | 24 | def test_unwrap_model_fn(): 25 | """Test for unwrap_model_fn""" 26 | 27 | dummy_model = torch.nn.Sequential( 28 | torch.nn.Linear(10, 20), 29 | torch.nn.Linear(20, 40), 30 | ) 31 | 32 | # test DataParallel wrapping 33 | wrapped_model_dp = torch.nn.DataParallel(dummy_model) 34 | unwrapped_model_dp = unwrap_model_fn(wrapped_model_dp) 35 | check_models(dummy_model, unwrapped_model_dp) 36 | 37 | # Initialize the distributed environment 38 | os.environ["MASTER_ADDR"] = "localhost" 39 | os.environ["MASTER_PORT"] = "1234" 40 | dist.init_process_group(backend="gloo", rank=0, world_size=1) 41 | 42 | # test DDP wrapping 43 | wrapped_model_ddp = torch.nn.parallel.DistributedDataParallel(dummy_model) 44 | unwrapped_model_ddp = unwrap_model_fn(wrapped_model_ddp) 45 | 46 | check_models(dummy_model, unwrapped_model_ddp) 47 | # clean up DDP environment 48 | dist.destroy_process_group() 49 | -------------------------------------------------------------------------------- /tests/utils/test_dict_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | from utils.dict_utils import filter_keys 6 | 7 | 8 | def test_extract_keys(): 9 | d = {"x": 2, "y": 3, "z": 4} 10 | 11 | assert filter_keys(d, ["x", "y"]) == {"x": 2, "y": 3} 12 | assert filter_keys(d, ["w"]) == {} 13 | -------------------------------------------------------------------------------- /tests/utils/test_import_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import sys 7 | from pathlib import Path 8 | 9 | from pytest_mock import MockerFixture 10 | 11 | from utils import import_utils 12 | from utils.import_utils import import_modules_from_folder 13 | 14 | 15 | def test_import_utils(tmp_path: Path, mocker: MockerFixture) -> None: 16 | tmp_path_str = str(tmp_path) 17 | sys.path.append(tmp_path_str) 18 | mocker.patch.object(import_utils, "LIBRARY_ROOT", tmp_path) 19 | try: 20 | files = [ 21 | "my_test_parent/child/module.py", 22 | "my_test_parent/child/nested/module.py", 23 | "my_test_parent/sibling.py", 24 | "my_internal/my_test_parent/child/module.py", 25 | "my_internal/my_test_parent/sibling.py", 26 | "my_internal/projects/A/my_test_parent/child/module.py", 27 | "my_internal/projects/B/my_test_parent/child/module.py", 28 | ] 29 | for path in files: 30 | path = tmp_path / path 31 | for package in path.parents: 32 | if package == tmp_path: 33 | break 34 | package.mkdir(exist_ok=True, parents=True) 35 | if not package.joinpath("__init__.py").exists(): 36 | package.joinpath("__init__.py").write_bytes(b"") 37 | path.write_bytes(b"") 38 | 39 | import_modules_from_folder( 40 | "my_test_parent/child", 41 | extra_roots=["my_internal", "my_internal/projects/*"], 42 | ) 43 | assert "my_test_parent.child.module" in sys.modules 44 | assert "my_test_parent.child.nested.module" in sys.modules 45 | assert "my_test_parent.sibling" not in sys.modules 46 | assert "my_internal.my_test_parent.child.module" in sys.modules 47 | assert "my_internal.my_test_parent.sibling" not in sys.modules 48 | assert "my_internal.projects.A.my_test_parent.child.module" in sys.modules 49 | assert "my_internal.projects.B.my_test_parent.child.module" in sys.modules 50 | finally: 51 | sys.path.remove(tmp_path_str) 52 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/utils/__init__.py -------------------------------------------------------------------------------- /utils/color_map.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import List, Optional 7 | 8 | import numpy as np 9 | 10 | 11 | class Colormap(object): 12 | """ 13 | Generate colormap for visualizing segmentation masks or bounding boxes. 
14 | 15 | This is based on the MATLAB code in the PASCAL VOC repository: 16 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 17 | """ 18 | 19 | def __init__(self, n: Optional[int] = 256, normalized: Optional[bool] = False): 20 | super(Colormap, self).__init__() 21 | self.n = n 22 | self.normalized = normalized 23 | 24 | @staticmethod 25 | def get_bit_at_idx(val, idx): 26 | return (val & (1 << idx)) != 0 27 | 28 | def get_color_map(self) -> np.ndarray: 29 | 30 | dtype = "float32" if self.normalized else "uint8" 31 | color_map = np.zeros((self.n, 3), dtype=dtype) 32 | for i in range(self.n): 33 | r = g = b = 0 34 | c = i 35 | for j in range(8): 36 | r = r | (self.get_bit_at_idx(c, 0) << 7 - j) 37 | g = g | (self.get_bit_at_idx(c, 1) << 7 - j) 38 | b = b | (self.get_bit_at_idx(c, 2) << 7 - j) 39 | c = c >> 3 40 | 41 | color_map[i] = np.array([r, g, b]) 42 | color_map = color_map / 255 if self.normalized else color_map 43 | return color_map 44 | 45 | def get_box_color_codes(self) -> List: 46 | box_codes = [] 47 | 48 | for i in range(self.n): 49 | r = g = b = 0 50 | c = i 51 | for j in range(8): 52 | r = r | (self.get_bit_at_idx(c, 0) << 7 - j) 53 | g = g | (self.get_bit_at_idx(c, 1) << 7 - j) 54 | b = b | (self.get_bit_at_idx(c, 2) << 7 - j) 55 | c = c >> 3 56 | box_codes.append((int(r), int(g), int(b))) 57 | return box_codes 58 | 59 | def get_color_map_list(self) -> List: 60 | cmap = self.get_color_map() 61 | cmap = np.asarray(cmap).flatten() 62 | return list(cmap) 63 | -------------------------------------------------------------------------------- /utils/dict_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Collection, Dict, Optional 6 | 7 | 8 | def filter_keys( 9 | d: Dict, 10 | whitelist: Optional[Collection[str]] = None, 11 | ) -> Dict: 12 | """Returns a copy of the input dict @d, with a subset of keys that are in 13 | @whitelist. 14 | 15 | Args: 16 | d: Input dictionary that will be copied with a subset of keys. 17 | whitelist: List of keys to keep in the output (if they exist in the input dict). 18 | """ 19 | 20 | return {key: d[key] for key in whitelist if key in d} 21 | -------------------------------------------------------------------------------- /utils/download_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from utils.download_utils_base import get_basic_local_path 7 | 8 | try: 9 | from internal.utils.blobby_utils import get_local_path_blobby 10 | 11 | get_local_path = get_local_path_blobby 12 | 13 | except ModuleNotFoundError: 14 | get_local_path = get_basic_local_path 15 |
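As a quick sanity check on the bit-interleaving loop in Colormap.get_color_map above: its first entries reproduce the standard PASCAL VOC palette, where the low bits of each class index drive the high bits of R, G, and B. A short usage sketch:

from utils.color_map import Colormap

cmap = Colormap(n=21).get_color_map()
assert list(cmap[0]) == [0, 0, 0]    # class 0 (background): all index bits are zero
assert list(cmap[1]) == [128, 0, 0]  # bit 0 of the index sets the top bit of red
assert list(cmap[2]) == [0, 128, 0]  # bit 1 sets the top bit of green
assert list(cmap[4]) == [0, 0, 128]  # bit 2 sets the top bit of blue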
-------------------------------------------------------------------------------- /utils/import_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import importlib 7 | import os 8 | from typing import Sequence 9 | 10 | from common import LIBRARY_ROOT 11 | from utils import logger 12 | 13 | 14 | def import_modules_from_folder( 15 | folder_name: str, extra_roots: Sequence[str] = () 16 | ) -> None: 17 | """Automatically import all modules from the public library root folder, in addition 18 | to the @extra_roots directories. 19 | 20 | The @folder_name directory must exist in LIBRARY_ROOT, but existence in @extra_roots 21 | is optional. 22 | 23 | Args: 24 | folder_name: Name of the folder to search for its internal and public modules. 25 | extra_roots: By default, this function only imports from 26 | `LIBRARY_ROOT/{folder_name}/**/*.py`. For any extra_root provided, it will 27 | also import `LIBRARY_ROOT/{extra_root}/{folder_name}/**/*.py` modules. 28 | """ 29 | if not LIBRARY_ROOT.joinpath(folder_name).exists(): 30 | logger.error( 31 | f"{folder_name} doesn't exist in the public library root directory." 32 | ) 33 | 34 | for base_dir in [".", *extra_roots]: 35 | for path in LIBRARY_ROOT.glob(os.path.join(base_dir, folder_name, "**/*.py")): 36 | filename = path.name 37 | if filename[0] not in (".", "_"): 38 | module_name = str( 39 | path.relative_to(LIBRARY_ROOT).with_suffix("") 40 | ).replace(os.sep, ".") 41 | importlib.import_module(module_name) 42 | -------------------------------------------------------------------------------- /utils/math_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional, Union 7 | 8 | 9 | def make_divisible( 10 | v: Union[float, int], 11 | divisor: Optional[int] = 8, 12 | min_value: Optional[Union[float, int]] = None, 13 | ) -> Union[float, int]: 14 | """ 15 | This function is taken from the original TensorFlow repo. 16 | It ensures that all layers have a channel number that is divisible by @divisor. 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: value (e.g., a channel count) to adjust. 20 | :param divisor: the returned value is divisible by this number (default: 8). 21 | :param min_value: optional lower bound on the returned value (default: @divisor). 22 | :return: the adjusted value, never less than 90% of @v. 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | def bound_fn( 34 | min_val: Union[float, int], max_val: Union[float, int], value: Union[float, int] 35 | ) -> Union[float, int]: 36 | return max(min_val, min(max_val, value)) 37 |
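A few hand-computed cases for make_divisible above, illustrating both the rounding rule and the 10% guard:

from utils.math_utils import make_divisible

assert make_divisible(16, 8) == 16  # already divisible: unchanged
assert make_divisible(30, 8) == 32  # int(30 + 4) // 8 * 8 = 32, and 32 >= 0.9 * 30
assert make_divisible(27, 8) == 32  # rounds down to 24, but 24 < 0.9 * 27 = 24.3, so bump by one divisor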
-------------------------------------------------------------------------------- /utils/object_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from numbers import Number 7 | from typing import Dict 8 | 9 | 10 | from common import is_test_env 11 | from utils import logger 12 | 13 | 14 | def is_iterable(x): 15 | return hasattr(x, "__iter__") and not isinstance(x, (str, bytes)) 16 | 17 | 18 | def apply_recursively(x, cb, *args, **kwargs): 19 | if isinstance(x, dict): 20 | return {k: apply_recursively(v, cb, *args, **kwargs) for k, v in x.items()} 21 | elif is_iterable(x): 22 | x_type = type(x) 23 | return x_type([apply_recursively(y, cb, *args, **kwargs) for y in x]) 24 | else: 25 | return cb(x, *args, **kwargs) 26 | 27 | 28 | def flatten_to_dict( 29 | x, name: str, dict_sep: str = "/", list_sep: str = "_" 30 | ) -> Dict[str, Number]: 31 | if x is None: 32 | return {} 33 | elif isinstance(x, Number): 34 | return {name: x} 35 | elif isinstance(x, list): 36 | return { 37 | k: v 38 | for i, inner in enumerate(x) 39 | for k, v in flatten_to_dict( 40 | inner, 41 | name=name + list_sep + str(i), 42 | dict_sep=dict_sep, 43 | list_sep=list_sep, 44 | ).items() 45 | } 46 | elif isinstance(x, dict): 47 | return { 48 | k: v 49 | for iname, inner in x.items() 50 | for k, v in flatten_to_dict( 51 | inner, 52 | name=name + dict_sep + iname, 53 | dict_sep=dict_sep, 54 | list_sep=list_sep, 55 | ).items() 56 | } 57 | 58 | logger.error("This should never be reached!") 59 | return {} 60 | 61 | 62 | def is_pytest_environment() -> bool: 63 | """Helper function to check whether the code is running in a pytest environment.""" 64 | logger.warning( 65 | DeprecationWarning( 66 | "utils.object_utils.is_pytest_environment is deprecated. Please use" 67 | " common.is_test_env instead." 68 | ) 69 | ) 70 | return is_test_env() 71 | -------------------------------------------------------------------------------- /utils/object_utils_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from utils.object_utils import apply_recursively, flatten_to_dict 7 | 8 | 9 | def test_apply_on_values(): 10 | d = { 11 | "top1": 1.112311, 12 | "prob_hist": {"max": [0.10003, 0.3, 0.5, 0.09997]}, 13 | "accuracy_per_class": [0.8286, 0.9124], 14 | } 15 | 16 | new_d = apply_recursively(d, lambda x: round(x, 2)) 17 | 18 | assert str(new_d["top1"]) == "1.11" 19 | assert str(new_d["prob_hist"]["max"][0]) == "0.1" 20 | 21 | 22 | def test_flatten_to_dict(): 23 | original = { 24 | "top1": 1.112311, 25 | "prob_hist": {"max": [0.10003, 0.3, 0.5, 0.09997]}, 26 | "accuracy_per_class": [0.8286, 0.9124], 27 | } 28 | flattened = { 29 | "metric/top1": 1.112311, 30 | "metric/prob_hist/max_0": 0.10003, 31 | "metric/prob_hist/max_1": 0.3, 32 | "metric/prob_hist/max_2": 0.5, 33 | "metric/prob_hist/max_3": 0.09997, 34 | "metric/accuracy_per_class_0": 0.8286, 35 | "metric/accuracy_per_class_1": 0.9124, 36 | } 37 | assert flatten_to_dict(original, "metric") == flattened 38 |
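One subtlety of apply_recursively worth noting: containers are rebuilt via x_type(...), so tuples stay tuples instead of collapsing into lists. A small sketch:

from utils.object_utils import apply_recursively

nested = {"scores": (0.1234, 0.5678), "counts": [3, 7]}
rounded = apply_recursively(nested, lambda v: round(v, 1))
assert rounded == {"scores": (0.1, 0.6), "counts": [3, 7]}
assert isinstance(rounded["scores"], tuple)  # container type preserved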
-------------------------------------------------------------------------------- /utils/registry_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from utils.registry import Registry 7 | 8 | 9 | def test_functional_registry() -> None: 10 | reg = Registry("registry_name") 11 | reg.register("awesome_dict")(dict) 12 | 13 | assert "awesome_dict" in reg 14 | assert "awesome_dict(name=hello)" in reg 15 | 16 | obj = reg["awesome_dict(name=hello, type=fifo)"]() 17 | 18 | assert obj == {"name": "hello", "type": "fifo"} 19 | 20 | 21 | def test_basic_registration() -> None: 22 | my_registry = Registry("registry_name") 23 | 24 | @my_registry.register("awesome_class_or_func") 25 | def my_awesome_class_or_func(param): 26 | pass 27 | 28 | assert "awesome_class_or_func" in my_registry 29 | assert "awesome_class_or_func(param=value)" in my_registry 30 | -------------------------------------------------------------------------------- /utils/resources.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | try: 7 | from internal.utils.resources import cpu_count 8 | except ImportError: 9 | from multiprocessing import cpu_count 10 | 11 | __all__ = ["cpu_count"] 12 | -------------------------------------------------------------------------------- /utils/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cvnets/77717569ab4a852614dae01f010b32b820cb33bb/utils/third_party/__init__.py --------------------------------------------------------------------------------
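Distilled from registry_test.py above: a registered name can be looked up with inline key=value parameters, which the registry binds as keyword arguments of the registered callable. A usage sketch; the registry name "example" and entry "config_dict" are illustrative, and the behavior is inferred from the tests rather than from the Registry implementation itself:

from utils.registry import Registry

reg = Registry("example")
reg.register("config_dict")(dict)  # register the builtin dict as a stand-in callable

factory = reg["config_dict(mode=train)"]  # parameters are parsed out of the key...
assert factory() == {"mode": "train"}     # ...and passed as kwargs when called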