├── .dockerignore ├── .flake8 ├── .gitattributes ├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── assets ├── cat.jpeg └── dog.jpeg ├── conftest.py ├── corenet ├── __init__.py ├── __main__.py ├── __version__.py ├── cli │ ├── __init__.py │ ├── entrypoints.py │ ├── main.py │ ├── main_benchmark.py │ ├── main_conversion.py │ ├── main_eval.py │ ├── main_eval_llmadapters.py │ └── main_train.py ├── constants.py ├── data │ ├── __init__.py │ ├── collate_fns │ │ ├── __init__.py │ │ ├── byteformer_collate_functions.py │ │ └── collate_functions.py │ ├── data_loaders.py │ ├── datasets │ │ ├── __init__.py │ │ ├── audio_classification │ │ │ ├── __init__.py │ │ │ └── speech_commands_v2.py │ │ ├── classification │ │ │ ├── __init__.py │ │ │ ├── base_image_classification_dataset.py │ │ │ ├── base_imagenet_shift_dataset.py │ │ │ ├── coco.py │ │ │ ├── imagenet.py │ │ │ ├── imagenet_a.py │ │ │ ├── imagenet_r.py │ │ │ ├── imagenet_sketch.py │ │ │ ├── imagenet_synsets.py │ │ │ ├── imagenet_v2.py │ │ │ ├── places365.py │ │ │ └── wordnet_tagged_classification.py │ │ ├── dataset_base.py │ │ ├── detection │ │ │ ├── __init__.py │ │ │ ├── base_detection.py │ │ │ ├── coco_base.py │ │ │ ├── coco_mask_rcnn.py │ │ │ └── coco_ssd.py │ │ ├── language_modeling │ │ │ ├── __init__.py │ │ │ ├── base_lm.py │ │ │ ├── commonsense_170k.py │ │ │ └── general_lm.py │ │ ├── multi_modal_img_text │ │ │ ├── __init__.py │ │ │ ├── base_multi_modal_img_text.py │ │ │ ├── flickr.py │ │ │ ├── img_text_tar_dataset.py │ │ │ └── zero_shot_image_classification │ │ │ │ ├── __init__.py │ │ │ │ ├── base_zero_shot_image_classification.py │ │ │ │ ├── imagenet.py │ │ │ │ ├── imagenet_a.py │ │ │ │ ├── imagenet_class_names.py │ │ │ │ ├── imagenet_r.py │ │ │ │ ├── imagenet_sketch.py │ │ │ │ └── templates.py │ │ ├── segmentation │ │ │ ├── __init__.py │ │ │ ├── ade20k.py │ │ │ ├── base_segmentation.py │ │ │ ├── coco_segmentation.py │ │ │ ├── coco_stuff.py │ │ │ └── pascal_voc.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── text.py │ │ │ └── video.py │ ├── io │ │ ├── __init__.py │ │ └── transfer_clients.py │ ├── loader │ │ ├── __init__.py │ │ └── dataloader.py │ ├── sampler │ │ ├── __init__.py │ │ ├── base_sampler.py │ │ ├── batch_sampler.py │ │ ├── chain_sampler.py │ │ ├── multi_scale_sampler.py │ │ ├── utils.py │ │ ├── variable_batch_sampler.py │ │ ├── video_batch_sampler.py │ │ ├── video_clip_batch_sampler.py │ │ └── video_variable_seq_sampler.py │ ├── text_tokenizer │ │ ├── __init__.py │ │ ├── base_tokenizer.py │ │ ├── clip_tokenizer.py │ │ └── sentencepiece_tokenizer.py │ ├── transforms │ │ ├── __init__.py │ │ ├── audio.py │ │ ├── audio_aux │ │ │ ├── __init__.py │ │ │ └── mfccs.py │ │ ├── audio_bytes.py │ │ ├── base_transforms.py │ │ ├── common.py │ │ ├── image_bytes.py │ │ ├── image_pil.py │ │ ├── image_torch.py │ │ ├── utils.py │ │ └── video.py │ └── video_reader │ │ ├── __init__.py │ │ ├── base_av_reader.py │ │ ├── decord_reader.py │ │ ├── ffmpeg_reader.py │ │ ├── ffmpeg_utils.py │ │ └── pyav_reader.py ├── engine │ ├── __init__.py │ ├── default_trainer.py │ ├── detection_utils │ │ ├── __init__.py │ │ └── coco_map.py │ ├── eval_detection.py │ ├── eval_segmentation.py │ ├── evaluation_engine.py │ ├── fsdp_trainer.py │ ├── segmentation_utils │ │ ├── __init__.py │ │ └── cityscapes_iou.py │ └── utils.py ├── loss_fn │ ├── __init__.py │ ├── base_criteria.py │ ├── classification │ │ ├── __init__.py │ │ ├── base_classification_criteria.py │ │ ├── binary_cross_entropy.py 
│ │ ├── cross_entropy.py │ │ └── focal_loss.py │ ├── composite_loss.py │ ├── detection │ │ ├── __init__.py │ │ ├── base_detection_criteria.py │ │ ├── mask_rcnn_loss.py │ │ └── ssd_multibox_loss.py │ ├── distillation │ │ ├── __init__.py │ │ ├── base_distillation.py │ │ ├── hard_distillation.py │ │ └── soft_kl_distillation.py │ ├── language_modeling │ │ ├── __init__.py │ │ ├── base_lm.py │ │ ├── cross_entropy.py │ │ └── cross_entropy_for_kv_prediction.py │ ├── multi_modal_img_text │ │ ├── __init__.py │ │ ├── base_multi_modal_img_text_criteria.py │ │ └── contrastive_loss_clip.py │ ├── neural_augmentation.py │ ├── segmentation │ │ ├── __init__.py │ │ ├── base_segmentation_criteria.py │ │ └── cross_entropy.py │ └── utils │ │ ├── __init__.py │ │ ├── build_helper.py │ │ └── class_weighting.py ├── metrics │ ├── __init__.py │ ├── average_precision.py │ ├── coco_map.py │ ├── confusion_mat.py │ ├── image_text_retrieval.py │ ├── intersection_over_union.py │ ├── metric_base.py │ ├── metric_base_test.py │ ├── misc.py │ ├── multiclass_classification_pr.py │ ├── probability_histograms.py │ ├── psnr.py │ ├── retrieval_cmc.py │ ├── stats.py │ ├── topk_accuracy.py │ └── vqa_preset_score.py ├── modeling │ ├── __init__.py │ ├── anchor_generator │ │ ├── __init__.py │ │ ├── base_anchor_generator.py │ │ └── ssd_anchor_generator.py │ ├── image_projection_layers │ │ ├── __init__.py │ │ ├── attention_pool_2d.py │ │ ├── base_image_projection.py │ │ ├── global_pool_2d.py │ │ └── simple_projection_head.py │ ├── layers │ │ ├── __init__.py │ │ ├── activation │ │ │ ├── __init__.py │ │ │ ├── gelu.py │ │ │ ├── hard_sigmoid.py │ │ │ ├── hard_swish.py │ │ │ ├── leaky_relu.py │ │ │ ├── prelu.py │ │ │ ├── relu.py │ │ │ ├── relu6.py │ │ │ ├── sigmoid.py │ │ │ ├── swish.py │ │ │ └── tanh.py │ │ ├── adaptive_pool.py │ │ ├── base_layer.py │ │ ├── conv_layer.py │ │ ├── dropout.py │ │ ├── embedding.py │ │ ├── flash_multi_head_attention.py │ │ ├── flatten.py │ │ ├── global_pool.py │ │ ├── identity.py │ │ ├── linear_attention.py │ │ ├── linear_layer.py │ │ ├── multi_head_attention.py │ │ ├── normalization │ │ │ ├── __init__.py │ │ │ ├── batch_norm.py │ │ │ ├── group_norm.py │ │ │ ├── instance_norm.py │ │ │ ├── layer_norm.py │ │ │ ├── rms_norm.py │ │ │ └── sync_batch_norm.py │ │ ├── normalization_layers.py │ │ ├── pixel_shuffle.py │ │ ├── pooling.py │ │ ├── positional_embedding.py │ │ ├── positional_encoding.py │ │ ├── random_layers.py │ │ ├── rotary_embeddings.py │ │ ├── single_head_attention.py │ │ ├── softmax.py │ │ ├── stochastic_depth.py │ │ ├── token_merging.py │ │ └── upsample.py │ ├── matcher_det │ │ ├── __init__.py │ │ ├── base_matcher.py │ │ └── ssd_matcher.py │ ├── misc │ │ ├── __init__.py │ │ ├── averaging_utils.py │ │ ├── box_utils.py │ │ ├── common.py │ │ └── init_utils.py │ ├── models │ │ ├── __init__.py │ │ ├── audio_classification │ │ │ ├── __init__.py │ │ │ ├── audio_byteformer.py │ │ │ └── base_audio_classification.py │ │ ├── base_model.py │ │ ├── classification │ │ │ ├── __init__.py │ │ │ ├── base_image_encoder.py │ │ │ ├── byteformer.py │ │ │ ├── config │ │ │ │ ├── __init__.py │ │ │ │ ├── byteformer.py │ │ │ │ ├── efficientnet.py │ │ │ │ ├── fastvit.py │ │ │ │ ├── mobilenetv1.py │ │ │ │ ├── mobilenetv2.py │ │ │ │ ├── mobilenetv3.py │ │ │ │ ├── mobileone.py │ │ │ │ ├── mobilevit.py │ │ │ │ ├── mobilevit_v2.py │ │ │ │ ├── regnet.py │ │ │ │ ├── resnet.py │ │ │ │ ├── swin_transformer.py │ │ │ │ └── vit.py │ │ │ ├── efficientnet.py │ │ │ ├── fastvit.py │ │ │ ├── mobilenetv1.py │ │ │ ├── mobilenetv2.py │ │ │ ├── 
mobilenetv3.py │ │ │ ├── mobileone.py │ │ │ ├── mobilevit.py │ │ │ ├── mobilevit_v2.py │ │ │ ├── regnet.py │ │ │ ├── resnet.py │ │ │ ├── swin_transformer.py │ │ │ └── vit.py │ │ ├── detection │ │ │ ├── __init__.py │ │ │ ├── base_detection.py │ │ │ ├── mask_rcnn.py │ │ │ ├── ssd.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ └── rcnn_utils.py │ │ ├── fsdp_wrapper.py │ │ ├── language_modeling │ │ │ ├── __init__.py │ │ │ ├── base_lm.py │ │ │ ├── general_gpt.py │ │ │ └── kv_prediction.py │ │ ├── multi_modal_img_text │ │ │ ├── __init__.py │ │ │ ├── base_multi_modal_img_text.py │ │ │ └── clip.py │ │ ├── segmentation │ │ │ ├── __init__.py │ │ │ ├── base_seg.py │ │ │ ├── enc_dec.py │ │ │ └── heads │ │ │ │ ├── __init__.py │ │ │ │ ├── base_seg_head.py │ │ │ │ ├── deeplabv3.py │ │ │ │ ├── pspnet.py │ │ │ │ └── simple_seg_head.py │ │ └── video_classification │ │ │ ├── __init__.py │ │ │ └── base_video_encoder.py │ ├── modules │ │ ├── __init__.py │ │ ├── aspp_block.py │ │ ├── base_module.py │ │ ├── efficientnet.py │ │ ├── fastvit.py │ │ ├── feature_pyramid.py │ │ ├── flash_transformer.py │ │ ├── mobilenetv2.py │ │ ├── mobileone_block.py │ │ ├── mobilevit_block.py │ │ ├── pspnet_module.py │ │ ├── regnet_modules.py │ │ ├── resnet_modules.py │ │ ├── squeeze_excitation.py │ │ ├── ssd_heads.py │ │ ├── swin_transformer_block.py │ │ ├── transformer.py │ │ └── windowed_transformer.py │ ├── neural_augmentor │ │ ├── __init__.py │ │ ├── neural_aug.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── neural_aug_utils.py │ └── text_encoders │ │ ├── __init__.py │ │ ├── base_text_encoder.py │ │ └── transformer.py ├── optims │ ├── __init__.py │ ├── adam.py │ ├── adamw.py │ ├── base_optim.py │ ├── scheduler │ │ ├── __init__.py │ │ ├── base_scheduler.py │ │ ├── cosine.py │ │ ├── cyclic.py │ │ ├── fixed.py │ │ ├── multi_step.py │ │ └── polynomial.py │ └── sgd.py ├── options │ ├── __init__.py │ ├── errors.py │ ├── opts.py │ ├── parse_args.py │ └── utils.py ├── third_party │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ └── text_tokenizer │ │ │ ├── __init__.py │ │ │ └── openai_clip_tokenizer.py │ └── modeling │ │ ├── __init__.py │ │ ├── lora.py │ │ └── ssd_utils.py ├── train_eval_pipelines │ ├── __init__.py │ ├── base.py │ ├── default_train_eval.py │ └── fsdp_train_eval.py └── utils │ ├── __init__.py │ ├── activation_checkpointing_wrapper.py │ ├── check.py │ ├── checkpoint_utils.py │ ├── color_map.py │ ├── common_utils.py │ ├── context_managers.py │ ├── ddp_utils.py │ ├── dict_utils.py │ ├── download_utils.py │ ├── file_logger.py │ ├── fpdb.py │ ├── hf_adapter_utils.py │ ├── import_utils.py │ ├── io_utils.py │ ├── logger.py │ ├── math_utils.py │ ├── object_utils.py │ ├── object_utils_test.py │ ├── pytorch_to_coreml.py │ ├── registry.py │ ├── registry_test.py │ ├── resources.py │ ├── retry_utils.py │ ├── tensor_utils.py │ └── visualization_utils.py ├── mlx_examples ├── clip │ ├── README.md │ ├── __init__.py │ ├── clip.py │ ├── image_processor.py │ ├── main_clip_to_mlx.py │ ├── main_test_clip_mlx.py │ ├── model.py │ ├── requirements.txt │ ├── results │ │ └── .gitkeep │ └── tokenizer.py ├── open_elm │ ├── README.md │ ├── __init__.py │ ├── convert.py │ ├── inference.py │ └── open_elm.py └── requirements.txt ├── projects ├── byteformer │ ├── README.md │ ├── imagenet_file_encodings │ │ ├── encoding_type=PNG.yaml │ │ ├── encoding_type=TIFF.yaml │ │ ├── encoding_type=fCHW.yaml │ │ └── encoding_type=fHWC.yaml │ ├── imagenet_jpeg_q100 │ │ ├── conv_kernel_size=16.yaml │ │ ├── conv_kernel_size=32.yaml │ │ └── 
conv_kernel_size=8.yaml │ ├── imagenet_jpeg_q60 │ │ ├── conv_kernel_size=16,window_sizes=[128].yaml │ │ ├── conv_kernel_size=16,window_sizes=[32].yaml │ │ ├── conv_kernel_size=32,window_sizes=[128].yaml │ │ ├── conv_kernel_size=32,window_sizes=[32].yaml │ │ ├── conv_kernel_size=4,window_sizes=[128].yaml │ │ ├── conv_kernel_size=4,window_sizes=[32].yaml │ │ ├── conv_kernel_size=8,window_sizes=[128].yaml │ │ └── conv_kernel_size=8,window_sizes=[32].yaml │ ├── imagenet_jpeg_shuffle_bytes │ │ ├── mode=cyclic_half_length.yaml │ │ ├── mode=random_shuffle.yaml │ │ ├── mode=reverse.yaml │ │ ├── mode=stride.yaml │ │ └── mode=window_shuffle.yaml │ ├── imagenet_obfuscation │ │ ├── width_range=[-10,10].yaml │ │ ├── width_range=[-20,20].yaml │ │ ├── width_range=[-5,5].yaml │ │ └── width_range=[0,0].yaml │ ├── imagenet_privacy_preserving_camera │ │ ├── keep_frac=0.03,conv_kernel_size=4.yaml │ │ ├── keep_frac=0.05,conv_kernel_size=4.yaml │ │ ├── keep_frac=0.1,conv_kernel_size=4.yaml │ │ ├── keep_frac=0.25,conv_kernel_size=8.yaml │ │ ├── keep_frac=0.5,conv_kernel_size=16.yaml │ │ └── keep_frac=0.75,conv_kernel_size=32.yaml │ ├── model_arch.png │ ├── speech_commands_mp3 │ │ ├── conv_kernel_size=4,window_size=[128].yaml │ │ ├── conv_kernel_size=4,window_size=[32].yaml │ │ ├── conv_kernel_size=8,window_size=[128].yaml │ │ └── conv_kernel_size=8,window_size=[32].yaml │ └── speech_commands_wav │ │ ├── encoding_dtype=float32,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=float32,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=int16,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=int16,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=int16,conv_kernel_size=8.yaml │ │ ├── encoding_dtype=int32,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=int32,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=uint8,conv_kernel_size=16.yaml │ │ ├── encoding_dtype=uint8,conv_kernel_size=32.yaml │ │ ├── encoding_dtype=uint8,conv_kernel_size=4.yaml │ │ └── encoding_dtype=uint8,conv_kernel_size=8.yaml ├── catlip │ ├── README-multi-label-object-classification.md │ ├── README-object-detection.md │ ├── README-pretraining.md │ ├── README-semantic-segmentation.md │ ├── README-single-label-object-classification.md │ ├── README.md │ ├── image_classification │ │ ├── imagenet │ │ │ ├── vit_base.yaml │ │ │ ├── vit_base_512x512.yaml │ │ │ ├── vit_huge.yaml │ │ │ ├── vit_huge_512x512.yaml │ │ │ ├── vit_large.yaml │ │ │ └── vit_large_512x512.yaml │ │ └── places365 │ │ │ ├── vit_base.yaml │ │ │ ├── vit_base_512x512.yaml │ │ │ ├── vit_huge.yaml │ │ │ ├── vit_huge_512x512.yaml │ │ │ ├── vit_large.yaml │ │ │ └── vit_large_512x512.yaml │ ├── multi_label_image_classification │ │ ├── vit_base.yaml │ │ └── vit_large.yaml │ ├── object_detection │ │ ├── maskrcnn_vit_base.yaml │ │ ├── maskrcnn_vit_huge.yaml │ │ └── maskrcnn_vit_large.yaml │ ├── pretraining │ │ ├── vit_base.yaml │ │ ├── vit_huge.yaml │ │ └── vit_large.yaml │ └── semantic_segmentation │ │ ├── deeplabv3_vit_base.yaml │ │ ├── deeplabv3_vit_huge.yaml │ │ └── deeplabv3_vit_large.yaml ├── clip │ ├── README.md │ └── clip_vit_base.yaml ├── fastvit │ ├── README.md │ └── classification │ │ └── fastvit_t8_in1k.yaml ├── kv-prediction │ ├── README.md │ ├── model_arch.png │ ├── openelm │ │ ├── openelm_1_1B_0_25.yaml │ │ ├── openelm_1_1B_0_50.yaml │ │ ├── openelm_1_1B_0_75.yaml │ │ ├── openelm_1_1B_kvp_c_270M.yaml │ │ ├── openelm_1_1B_kvp_c_450M.yaml │ │ ├── openelm_1_1B_kvp_lp_0_25.yaml │ │ ├── openelm_1_1B_kvp_lp_0_50.yaml │ │ ├── openelm_1_1B_kvp_lp_0_75.yaml │ │ ├── openelm_3B_kvp_c_1_1B.yaml │ │ ├── 
openelm_3B_kvp_c_270M.yaml │ │ ├── openelm_3B_kvp_c_450M.yaml │ │ ├── openelm_3B_kvp_lp_0_25.yaml │ │ ├── openelm_3B_kvp_lp_0_50.yaml │ │ ├── openelm_3B_kvp_lp_0_75.yaml │ │ ├── openelm_base_3B_aux_0_25l.yaml │ │ ├── openelm_base_3B_aux_0_50l.yaml │ │ └── openelm_base_3B_aux_0_75l.yaml │ └── triviaqa-template.yaml ├── mobilenet_v1 │ ├── README.md │ └── classification │ │ └── mobilenetv1_1.0_in1k.yaml ├── mobilenet_v2 │ ├── README.md │ ├── classification │ │ └── mobilenetv2_1.0_in1k.yaml │ └── segmentation │ │ └── deeplabv3_ade20k.yaml ├── mobilenet_v3 │ ├── README.md │ └── classification │ │ └── mobilenetv3_large_in1k.yaml ├── mobileone │ ├── README.md │ └── classification │ │ └── mobileone_s1_in1k.yaml ├── mobilevit │ └── README.md ├── mobilevit_v2 │ ├── README.md │ ├── classification │ │ ├── mobilevitv2_2.0_ft_384x384.yaml │ │ └── mobilevitv2_2.0_in1k.yaml │ ├── detection │ │ └── mobilevitv2_2.0_ssd_coco.yaml │ └── segmentation │ │ └── deeplabv3_mobilevitv2_1.0_ade20k.yaml ├── openelm │ ├── README-instruct.md │ ├── README-peft.md │ ├── README-pretraining.md │ ├── README.md │ ├── instruction_tuning │ │ └── openelm-instruct.yaml │ ├── peft_configs │ │ ├── openelm_lora_1_1B.yaml │ │ ├── openelm_lora_270M.yaml │ │ ├── openelm_lora_270M_eval.yaml │ │ ├── openelm_lora_3B.yaml │ │ └── openelm_lora_450M.yaml │ └── pretraining_configs │ │ ├── openelm_1_1B.yaml │ │ ├── openelm_270M.yaml │ │ ├── openelm_3B.yaml │ │ └── openelm_450M.yaml ├── range_augment │ ├── README-classification.md │ ├── README-clip.md │ ├── README-distillation.md │ ├── README-object-detection.md │ ├── README-segmentation.md │ ├── README.md │ ├── classification │ │ ├── efficientnet_b0.yaml │ │ ├── efficientnet_b1.yaml │ │ ├── efficientnet_b2.yaml │ │ ├── efficientnet_b3.yaml │ │ ├── mobilenet_v1.yaml │ │ ├── mobilenet_v2.yaml │ │ ├── mobilenet_v3.yaml │ │ ├── mobilevit_v1.yaml │ │ ├── regnety_16gf.yaml │ │ ├── resnet_101.yaml │ │ ├── resnet_50.yaml │ │ ├── se_resnet_50.yaml │ │ ├── swin_transformer_small.yaml │ │ └── swin_transformer_tiny.yaml │ ├── clip │ │ ├── clip_vit_base.yaml │ │ └── clip_vit_huge.yaml │ ├── clip_finetune_imagenet │ │ ├── clip_vit_base.yaml │ │ └── clip_vit_huge.yaml │ ├── detection │ │ ├── maskrcnn_efficientnet_b3.yaml │ │ ├── maskrcnn_mobilenet_v1.yaml │ │ ├── maskrcnn_mobilenet_v2.yaml │ │ ├── maskrcnn_mobilenet_v3.yaml │ │ ├── maskrcnn_mobilevit.yaml │ │ ├── maskrcnn_resnet_101.yaml │ │ └── maskrcnn_resnet_50.yaml │ ├── distillation │ │ ├── teacher_resnet101_student_mobilenet_v1.yaml │ │ ├── teacher_resnet101_student_mobilenet_v2.yaml │ │ ├── teacher_resnet101_student_mobilenet_v3.yaml │ │ └── teacher_resnet101_student_mobilevit.yaml │ └── segmentation │ │ ├── ade20k │ │ ├── deeplabv3_efficientnet_b3.yaml │ │ ├── deeplabv3_mobilenet_v1.yaml │ │ ├── deeplabv3_mobilenet_v2.yaml │ │ ├── deeplabv3_mobilenet_v3.yaml │ │ ├── deeplabv3_mobilevit.yaml │ │ ├── deeplabv3_resnet_101.yaml │ │ └── deeplabv3_resnet_50.yaml │ │ └── pascal_voc │ │ ├── deeplabv3_efficientnet_b3.yaml │ │ ├── deeplabv3_mobilenet_v1.yaml │ │ ├── deeplabv3_mobilenet_v2.yaml │ │ ├── deeplabv3_mobilenet_v3.yaml │ │ ├── deeplabv3_resnet_101.yaml │ │ └── deeplabv3_resnet_50.yaml ├── resnet │ ├── README.md │ ├── classification │ │ └── resnet50_in1k.yaml │ └── detection │ │ └── ssd_resnet50_coco.yaml └── vit │ ├── README.md │ └── classification │ └── vit_base_in1k.yaml ├── pyproject.toml ├── requirements-optional.txt ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── configs.py ├── data │ ├── __init__.py │ ├── coco │ │ └── 
annotations │ │ │ ├── instances_train2017.json │ │ │ └── instances_val2017.json │ ├── collate_fns │ │ ├── __init__.py │ │ ├── test_byteformer_collate_fn.py │ │ └── test_collate_functions.py │ ├── datasets │ │ ├── __init__.py │ │ ├── audio_classification │ │ │ ├── __init__.py │ │ │ └── test_speech_commands_v2.py │ │ ├── classification │ │ │ ├── __init__.py │ │ │ ├── dummy_configs │ │ │ │ ├── coco.yaml │ │ │ │ ├── image_classification_dataset.yaml │ │ │ │ ├── imagenet.yaml │ │ │ │ ├── imagenet_a.yaml │ │ │ │ ├── imagenet_r.yaml │ │ │ │ ├── imagenet_sketch.yaml │ │ │ │ └── wordnet_tagged_classification.yaml │ │ │ ├── dummy_images │ │ │ │ ├── training │ │ │ │ │ ├── class1 │ │ │ │ │ │ ├── dummy_image1.jpg │ │ │ │ │ │ └── dummy_image2.jpg │ │ │ │ │ └── class2 │ │ │ │ │ │ ├── dummy_image1.jpg │ │ │ │ │ │ └── dummy_image2.jpg │ │ │ │ └── validation │ │ │ │ │ ├── class1 │ │ │ │ │ ├── dummy_image1.jpg │ │ │ │ │ └── dummy_image2.jpg │ │ │ │ │ └── class2 │ │ │ │ │ ├── dummy_image1.jpg │ │ │ │ │ └── dummy_image2.jpg │ │ │ ├── mock_coco.py │ │ │ ├── mock_imagenet.py │ │ │ ├── mock_wordnet_tagged_classification.py │ │ │ ├── test_base_image_classification_dataset.py │ │ │ ├── test_mock_coco.py │ │ │ ├── test_mock_imagenet.py │ │ │ └── test_wordnet_tagged_classification.py │ │ ├── detection │ │ │ ├── __init__.py │ │ │ ├── mock_coco_mask_rcnn.py │ │ │ └── mock_coco_ssd.py │ │ ├── language_modeling │ │ │ ├── __init__.py │ │ │ ├── dummy_commonsense_170k.yaml │ │ │ ├── dummy_lm_dataset.yaml │ │ │ ├── mock_general_lm.py │ │ │ ├── test_commonsense_170k.py │ │ │ └── test_general_lm.py │ │ ├── multi_modal_img_text │ │ │ ├── __init__.py │ │ │ ├── dummy_img_text_tar_dataset.yaml │ │ │ ├── mock_img_text_tar_dataset.py │ │ │ ├── test_img_text_tar_dataset.py │ │ │ └── zero_shot_image_classification │ │ │ │ ├── __init__.py │ │ │ │ ├── dummy_configs │ │ │ │ ├── imagenet.yaml │ │ │ │ ├── imagenet_a.yaml │ │ │ │ ├── imagenet_r.yaml │ │ │ │ └── imagenet_sketch.yaml │ │ │ │ ├── mock_imagenet.py │ │ │ │ └── test_mock_imagenet.py │ │ ├── segmentation │ │ │ ├── __init__.py │ │ │ ├── dummy_ade20k_config.yaml │ │ │ ├── dummy_cocostuff_config.yaml │ │ │ ├── mock_ade20k.py │ │ │ ├── mock_coco_stuff.py │ │ │ ├── test_mock_ade20k.py │ │ │ └── test_mock_coco_stuff.py │ │ ├── test_dataset_base.py │ │ ├── test_image_pil.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── test_common.py │ │ │ └── test_video.py │ ├── dummy_silent_video.mov │ ├── dummy_video.mov │ ├── io │ │ ├── __init__.py │ │ └── test_transfer_clients.py │ ├── samplers │ │ ├── __init__.py │ │ ├── test_batch_sampler_config.yaml │ │ ├── test_chain_sampler.py │ │ ├── test_chain_sampler_config.yaml │ │ ├── test_data_samplers.py │ │ ├── test_multi_scale_sampler_config.yaml │ │ ├── test_variable_batch_sampler_config.yaml │ │ └── test_video_clip_batch_sampler_config.yaml │ ├── text_tokenizer │ │ ├── __init__.py │ │ ├── test_clip_tokenizer.py │ │ └── test_openai_clip_tokenizer.py │ └── video_reader │ │ ├── __init__.py │ │ ├── test_av_reader.py │ │ └── test_ffmpeg_utils.py ├── engine │ ├── __init__.py │ ├── dummy_configs │ │ ├── ade20k_segmentation │ │ │ └── deeplabv3_mobilenetv2.yaml │ │ ├── coco_detection │ │ │ ├── resnet_mask_rcnn.yaml │ │ │ └── resnet_ssd.yaml │ │ ├── image_text_clip │ │ │ └── clip_vit.yaml │ │ ├── imagenet_classification │ │ │ ├── efficientnet_b0.yaml │ │ │ ├── mobilevit.yaml │ │ │ └── mobilevit_v2.yaml │ │ └── language_modeling_gpt │ │ │ └── gpt.yaml │ └── test_training_engine.py ├── loss_fns │ ├── __init__.py │ ├── language_modeling │ │ ├── __init__.py │ │ 
├── test_cross_entropy.py │ │ └── test_cross_entropy_for_kv_prediction.py │ ├── test_class_weighting.py │ ├── test_classification_loss.py │ ├── test_composite_loss.py │ ├── test_contrastive_loss.py │ ├── test_detection_loss.py │ ├── test_focal_loss.py │ ├── test_neural_aug.py │ ├── test_neural_aug_compatibility.py │ └── test_segmentation_loss.py ├── metrics │ ├── __init__.py │ ├── base.py │ ├── test_coco_map.py │ ├── test_image_text_retrieval_metrics.py │ ├── test_iou.py │ ├── test_misc.py │ ├── test_multiclass_classification_pr.py │ ├── test_probability_histogram.py │ ├── test_psnr.py │ ├── test_retrieval_cmc_metrics.py │ ├── test_topk_accuracy.py │ └── test_vqa_preset_score_metrics.py ├── misc │ ├── __init__.py │ ├── dummy_clip_config.yaml │ ├── dummy_linear_probe_config.yaml │ └── test_common.py ├── modeling │ ├── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── normalization_layers │ │ │ ├── __init__.py │ │ │ └── test_rms_norm.py │ │ ├── test_conv_layer.py │ │ ├── test_multi_head_attn.py │ │ ├── test_pos_embeddings.py │ │ ├── test_rotary_embeddings.py │ │ └── test_token_merging.py │ ├── models │ │ ├── __init__.py │ │ ├── audio_classification │ │ │ ├── __init__.py │ │ │ ├── test_base_audio_classification.py │ │ │ └── test_byteformer.py │ │ ├── classification │ │ │ ├── __init__.py │ │ │ ├── config │ │ │ │ ├── __init__.py │ │ │ │ ├── test_byteformer.py │ │ │ │ └── vit_config.yaml │ │ │ ├── test_byteformer.py │ │ │ └── test_vit.py │ │ ├── language_modeling │ │ │ ├── __init__.py │ │ │ ├── config │ │ │ │ ├── gpt_config.yaml │ │ │ │ └── kv_prediction_config.yaml │ │ │ ├── test_general_gpt.py │ │ │ └── test_kv_prediction.py │ │ ├── test_activation_checkpointing_wrapper.py │ │ ├── test_lora.py │ │ └── test_neural_aug_utils.py │ ├── modules │ │ ├── __init__.py │ │ ├── test_transformer.py │ │ └── test_windowed_transformer.py │ └── test_model.py ├── optims │ ├── __init__.py │ └── scheduler │ │ ├── __init__.py │ │ └── test_scheduler.py ├── options │ ├── __init__.py │ ├── test_parse_args.py │ └── test_utils.py ├── test_conventions.py ├── test_utils.py ├── transforms │ ├── __init__.py │ ├── test_audio.py │ ├── test_audio_bytes.py │ ├── test_image.py │ ├── test_image_bytes.py │ └── test_video.py └── utils │ ├── __init__.py │ ├── test_check.py │ ├── test_common_utils.py │ ├── test_dict_utils.py │ ├── test_download_utils.py │ ├── test_file_logger.py │ └── test_import_utils.py ├── tools ├── __init__.py └── converter_coco_stuff.py ├── tox.ini └── tutorials ├── clip.ipynb ├── guide_slurm_and_multi_node_training.md ├── object_detection.ipynb ├── semantic_segmentation.ipynb └── train_a_new_model_on_a_new_dataset_from_scratch.ipynb /.dockerignore: -------------------------------------------------------------------------------- 1 | # Docker-specific: 2 | **/Dockerfile 3 | setup_bolt.sh 4 | .git 5 | .gitignore 6 | .dockerignore 7 | 8 | # Mirroring .gitignore 9 | # Note: dockerignore matches paths only from the root dir, while gitignore matches 10 | # paths from any nested directory. Thus, you'll need to add **/ for some paths. 
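# For example, a .gitignore pattern like "*.pyc" already matches nested files such as
# corenet/cli/main.pyc, whereas .dockerignore patterns are anchored at the build-context
# root; that is why both "*.pyc" and "**/*.pyc" appear in the list below.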
11 | .vscode/ 12 | .coverage 13 | *.pyc 14 | **/*.pyc 15 | __pycache__ 16 | **/__pycache__ 17 | .DS_STORE 18 | **/.DS_STORE 19 | .idea 20 | 21 | *.swp 22 | .pytest_cache 23 | .mypy_cache 24 | .corenet_data_cache 25 | 26 | build/ 27 | 28 | results* 29 | vision_datasets/ 30 | exp_results/ 31 | exp_results* 32 | results_* 33 | 34 | *.so 35 | **/*.so 36 | model_zoo 37 | model_zoo/* 38 | pipeline.yaml 39 | 40 | *.egg-info 41 | **/*.egg-info 42 | 43 | venv/ 44 | 45 | trash/ 46 | mlx_model/ 47 | .tox/ 48 | 49 | # The playground_* files get generated by tutorials/train_new_model_on_new_dataset_from_scratch.ipynb 50 | **/playground_* 51 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.tar filter=lfs diff=lfs merge=lfs -text 2 | *.gz filter=lfs diff=lfs merge=lfs -text 3 | *.zip filter=lfs diff=lfs merge=lfs -text 4 | *.zst filter=lfs diff=lfs merge=lfs -text 5 | *.zstd filter=lfs diff=lfs merge=lfs -text 6 | *.ipynb filter=lfs diff=lfs merge=lfs -text 7 | *.png filter=lfs diff=lfs merge=lfs -text 8 | *.jpeg filter=lfs diff=lfs merge=lfs -text 9 | *.jpg filter=lfs diff=lfs merge=lfs -text 10 | *.whl filter=lfs diff=lfs merge=lfs -text 11 | *.npy filter=lfs diff=lfs merge=lfs -text 12 | *.npz filter=lfs diff=lfs merge=lfs -text 13 | *.pt filter=lfs diff=lfs merge=lfs -text 14 | *.mov filter=lfs diff=lfs merge=lfs -text 15 | *.mp4 filter=lfs diff=lfs merge=lfs -text 16 | *.pdf filter=lfs diff=lfs merge=lfs -text 17 | *.tif filter=lfs diff=lfs merge=lfs -text 18 | *.tiff filter=lfs diff=lfs merge=lfs -text 19 | *.model filter=lfs diff=lfs merge=lfs -text 20 | *.parquet filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.safetensors filter=lfs diff=lfs merge=lfs -text 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Note: please mirror changes of this file to .dockerignore. 2 | 3 | .vscode/ 4 | .coverage 5 | *.pyc 6 | __pycache__ 7 | .DS_STORE 8 | .idea 9 | *.swp 10 | .pytest_cache 11 | .mypy_cache 12 | .corenet_data_cache 13 | 14 | build/ 15 | 16 | /results* 17 | vision_datasets/ 18 | exp_results/ 19 | exp_results* 20 | results_* 21 | 22 | *.so 23 | model_zoo 24 | model_zoo/* 25 | pipeline.yaml 26 | 27 | *.egg-info 28 | 29 | venv/ 30 | 31 | trash 32 | 33 | mlx_model/ 34 | .tox/ 35 | 36 | # The playground_* files get generated by tutorials/train_new_model_on_new_dataset_from_scratch.ipynb 37 | playground_* 38 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 
6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /assets/cat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/assets/cat.jpeg -------------------------------------------------------------------------------- /assets/dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/assets/dog.jpeg -------------------------------------------------------------------------------- /corenet/__main__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/__main__.py -------------------------------------------------------------------------------- /corenet/__version__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | version = "0.1.1" 6 | -------------------------------------------------------------------------------- /corenet/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/cli/__init__.py -------------------------------------------------------------------------------- /corenet/cli/entrypoints.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Dict, Tuple 7 | 8 | # Entrypoints is a mapping from shell executable name to (module, function) pair. 9 | # Having too many entrypoints in setup.py limits our ability to add features or 10 | # refactor the code, because users who pull the latest changes will have to re-install 11 | # corenet in order for `setup.py` changes to apply. 12 | # A better practice is to stop introducing new entrypoints, and add subcommands to the 13 | # main `corenet` entrypoint. Currently, `corenet train` is identical to `corenet-train`.
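# For example, the "corenet-train" entry below points at corenet.cli.main_train:main_worker;
# the `console_scripts` list at the bottom of this file expands it into the setuptools
# string "corenet-train = corenet.cli.main_train:main_worker", which becomes a shell
# executable when corenet is installed.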
14 | entrypoints: Dict[str, Tuple[str, str]] = { 15 | "corenet-train": ("corenet.cli.main_train", "main_worker"), 16 | "corenet-eval": ("corenet.cli.main_eval", "main_worker"), 17 | "corenet-eval-llmadapters": ( 18 | "corenet.cli.main_eval_llmadapters", 19 | "main_eval_llmadapters", 20 | ), 21 | "corenet-eval-seg": ("corenet.cli.main_eval", "main_worker_segmentation"), 22 | "corenet-eval-det": ("corenet.cli.main_eval", "main_worker_detection"), 23 | "corenet-convert": ("corenet.cli.main_conversion", "main_worker_conversion"), 24 | "corenet": ("corenet.cli.main", "main"), 25 | } 26 | 27 | console_scripts = [ 28 | f"{entrypoint} = {module}:{func}" 29 | for entrypoint, (module, func) in entrypoints.items() 30 | ] 31 | -------------------------------------------------------------------------------- /corenet/cli/main.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import importlib 8 | from itertools import chain 9 | from typing import Dict, List, Optional, Tuple 10 | 11 | from corenet.cli.entrypoints import entrypoints as oss_entrypoints 12 | 13 | try: 14 | from corenet.internal.cli.entrypoints import entrypoints as internal_entrypoints 15 | except ModuleNotFoundError: 16 | internal_entrypoints = {} 17 | 18 | 19 | def main(args: Optional[List[str]] = None) -> None: 20 | """ 21 | We are planning to deprecate `corenet-train`, `corenet-eval`, ... commands for 22 | `corenet train` (the dash is removed), `corenet eval`, ... because adding/renaming 23 | entrypoints will require `pip install -e .`. Most users don't reinstall corenet 24 | after pulling the git repo. Hence, relying on a single entrypoint `corenet` with 25 | subcommands is more future proof. 26 | """ 27 | entrypoints = { 28 | k.replace("corenet-", ""): v 29 | for k, v in chain(oss_entrypoints.items(), internal_entrypoints.items()) 30 | } 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("entrypoint", choices=list(entrypoints.keys())) 33 | entrypoint_opts, args = parser.parse_known_args(args) 34 | module_name, func_name = entrypoints[entrypoint_opts.entrypoint] 35 | getattr(importlib.import_module(module_name), func_name)(args) 36 | -------------------------------------------------------------------------------- /corenet/cli/main_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import List, Optional 6 | 7 | from corenet.options.opts import get_training_arguments 8 | from corenet.train_eval_pipelines import ( 9 | TRAIN_EVAL_PIPELINE_REGISTRY, 10 | BaseTrainEvalPipeline, 11 | ) 12 | 13 | 14 | def main(train_eval_pipeline: BaseTrainEvalPipeline): 15 | """ 16 | This function will be invoked on each gpu worker process. 17 | 18 | Args: 19 | train_eval_pipeline: Provides major pipeline components. The class to be used is 20 | configurable by "--train-eval-pipeline.name" opt. By default, an instance of 21 | ``train_eval_pipelines.TrainEvalPipeline`` will be passed to this function. 
22 | """ 23 | evaluation_engine = train_eval_pipeline.evaluation_engine 24 | evaluation_engine.run() 25 | 26 | 27 | def main_worker(args: Optional[List[str]] = None): 28 | opts = get_training_arguments(args=args) 29 | pipeline_name = getattr(opts, "train_eval_pipeline.name") 30 | train_eval_pipeline = TRAIN_EVAL_PIPELINE_REGISTRY[pipeline_name](opts=opts) 31 | launcher = train_eval_pipeline.launcher 32 | launcher(main) 33 | 34 | 35 | # for segmentation and detection, we follow a different evaluation pipeline that allows to save the results too 36 | def main_worker_segmentation(args: Optional[List[str]] = None, **kwargs): 37 | from corenet.engine.eval_segmentation import main_segmentation_evaluation 38 | 39 | main_segmentation_evaluation(args=args, **kwargs) 40 | 41 | 42 | def main_worker_detection(args: Optional[List[str]] = None, **kwargs): 43 | from corenet.engine.eval_detection import main_detection_evaluation 44 | 45 | main_detection_evaluation(args=args, **kwargs) 46 | 47 | 48 | if __name__ == "__main__": 49 | main_worker() 50 | -------------------------------------------------------------------------------- /corenet/cli/main_train.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import List, Optional 7 | 8 | from torch.distributed.elastic.multiprocessing import errors 9 | 10 | from corenet.options.opts import get_training_arguments 11 | from corenet.train_eval_pipelines import ( 12 | TRAIN_EVAL_PIPELINE_REGISTRY, 13 | BaseTrainEvalPipeline, 14 | ) 15 | 16 | 17 | @errors.record 18 | def callback(train_eval_pipeline: BaseTrainEvalPipeline) -> None: 19 | """ 20 | This function will be invoked on each gpu worker process. 21 | 22 | Args: 23 | train_eval_pipeline: Provides major pipeline components. The class to be used is 24 | configurable by "--train-eval-pipeline.name" opt. By default, an instance of 25 | ``train_eval_pipelines.TrainEvalPipeline`` will be passed to this function. 26 | """ 27 | train_sampler = train_eval_pipeline.train_sampler 28 | train_eval_pipeline.training_engine.run(train_sampler=train_sampler) 29 | 30 | 31 | def main_worker(args: Optional[List[str]] = None): 32 | opts = get_training_arguments(args=args) 33 | pipeline_name = getattr(opts, "train_eval_pipeline.name") 34 | train_eval_pipeline = TRAIN_EVAL_PIPELINE_REGISTRY[pipeline_name](opts=opts) 35 | launcher = train_eval_pipeline.launcher 36 | launcher(callback) 37 | 38 | 39 | if __name__ == "__main__": 40 | main_worker() 41 | -------------------------------------------------------------------------------- /corenet/data/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from corenet.data.data_loaders import create_test_loader, create_train_val_loader 7 | -------------------------------------------------------------------------------- /corenet/data/datasets/audio_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /corenet/data/datasets/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/data/datasets/classification/imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.data.datasets import DATASET_REGISTRY 9 | from corenet.data.datasets.classification.base_image_classification_dataset import ( 10 | BaseImageClassificationDataset, 11 | ) 12 | 13 | 14 | @DATASET_REGISTRY.register(name="imagenet", type="classification") 15 | class ImageNetDataset(BaseImageClassificationDataset): 16 | """ 17 | ImageNet dataset that follows the structure of ImageClassificationDataset. 18 | 19 | "ImageNet: A large-scale hierarchical image database" 20 | Jia Deng; Wei Dong; Richard Socher; Li-Jia Li; Kai Li; Li Fei-Fei 21 | 2009 IEEE Conference on Computer Vision and Pattern Recognition 22 | """ 23 | 24 | def __init__( 25 | self, 26 | opts: argparse.Namespace, 27 | *args, 28 | **kwargs, 29 | ) -> None: 30 | BaseImageClassificationDataset.__init__( 31 | self, 32 | opts=opts, 33 | *args, 34 | **kwargs, 35 | ) 36 | -------------------------------------------------------------------------------- /corenet/data/datasets/classification/imagenet_a.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | """ImageNetA dataset, a distribution shift of ImageNet.""" 6 | import argparse 7 | 8 | from corenet.data.datasets import DATASET_REGISTRY 9 | from corenet.data.datasets.classification.base_imagenet_shift_dataset import ( 10 | BaseImageNetShiftDataset, 11 | ) 12 | from corenet.data.datasets.classification.imagenet_synsets import ( 13 | IMAGENET_A_SYNSETS, 14 | IMAGENET_SYNSETS, 15 | ) 16 | 17 | IMAGENET_A_CLASS_SUBLIST = [ 18 | IMAGENET_SYNSETS.index(IMAGENET_A_SYNSETS[synset]) 19 | for synset in range(len(IMAGENET_A_SYNSETS)) 20 | ] 21 | 22 | 23 | @DATASET_REGISTRY.register(name="imagenet_a", type="classification") 24 | class ImageNetADataset(BaseImageNetShiftDataset): 25 | """ImageNetA dataset, a distribution shift of ImageNet. 26 | 27 | ImageNet-A contains real-world, unmodified natural images that cause model accuracy 28 | to substantially degrade. 
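    The module-level ``IMAGENET_A_CLASS_SUBLIST`` defined above precomputes, for each
    ImageNet-A class index, the index of the corresponding synset in the full ImageNet
    class list; ``class_id_to_imagenet_class_id`` below simply looks up this mapping.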
29 | 30 | @article{hendrycks2021nae, 31 | title={Natural Adversarial Examples}, 32 | author={Dan Hendrycks and Kevin Zhao and Steven Basart and Jacob Steinhardt and Dawn 33 | Song}, 34 | journal={CVPR}, 35 | year={2021} 36 | } 37 | """ 38 | 39 | def __init__( 40 | self, 41 | opts: argparse.Namespace, 42 | *args, 43 | **kwargs, 44 | ) -> None: 45 | """Initialize ImageNetA.""" 46 | BaseImageNetShiftDataset.__init__(self, opts=opts, *args, **kwargs) 47 | 48 | @staticmethod 49 | def class_id_to_imagenet_class_id(class_id: int) -> int: 50 | """Return the mapped class index using precomputed mapping.""" 51 | return IMAGENET_A_CLASS_SUBLIST[class_id] 52 | -------------------------------------------------------------------------------- /corenet/data/datasets/classification/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | """ImageNetSketch dataset, a distribution shift of ImageNet.""" 6 | import argparse 7 | 8 | from corenet.data.datasets import DATASET_REGISTRY 9 | from corenet.data.datasets.classification.base_imagenet_shift_dataset import ( 10 | BaseImageNetShiftDataset, 11 | ) 12 | 13 | 14 | @DATASET_REGISTRY.register(name="imagenet_sketch", type="classification") 15 | class ImageNetSketchDataset(BaseImageNetShiftDataset): 16 | """ImageNetSketch dataset, a distribution shift of ImageNet. 17 | 18 | Data set is created from Google Image queries "sketch of __", where __ is the 19 | standard class name. Search is only within the "black and white" color scheme. 20 | 21 | @inproceedings{wang2019learning, 22 | title={Learning Robust Global Representations by Penalizing Local Predictive 23 | Power}, 24 | author={Wang, Haohan and Ge, Songwei and Lipton, Zachary and Xing, Eric P}, 25 | booktitle={Advances in Neural Information Processing Systems}, 26 | pages={10506--10518}, 27 | year={2019} 28 | } 29 | """ 30 | 31 | def __init__( 32 | self, 33 | opts: argparse.Namespace, 34 | *args, 35 | **kwargs, 36 | ) -> None: 37 | """Initialize ImageNetSketchDataset.""" 38 | BaseImageNetShiftDataset.__init__(self, opts=opts, *args, **kwargs) 39 | 40 | @staticmethod 41 | def class_id_to_imagenet_class_id(class_id: int) -> int: 42 | """Return `class_id` as the ImageNet Sketch classes are the same as ImageNet.""" 43 | return class_id 44 | -------------------------------------------------------------------------------- /corenet/data/datasets/classification/places365.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.data.datasets import DATASET_REGISTRY 9 | from corenet.data.datasets.classification.base_image_classification_dataset import ( 10 | BaseImageClassificationDataset, 11 | ) 12 | 13 | 14 | @DATASET_REGISTRY.register(name="places365", type="classification") 15 | class Places365Dataset(BaseImageClassificationDataset): 16 | """ 17 | Places365 dataset that follows the structure of ImageClassificationDataset. 18 | 19 | "Places: A 10 million Image Database for Scene Recognition" 20 | B. Zhou, A. Lapedriza, A. Khosla, A. Oliva, and A. 
Torralba 21 | IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017 22 | """ 23 | 24 | def __init__( 25 | self, 26 | opts: argparse.Namespace, 27 | *args, 28 | **kwargs, 29 | ) -> None: 30 | BaseImageClassificationDataset.__init__( 31 | self, 32 | opts=opts, 33 | *args, 34 | **kwargs, 35 | ) 36 | -------------------------------------------------------------------------------- /corenet/data/datasets/detection/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/data/datasets/language_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/data/datasets/language_modeling/__init__.py -------------------------------------------------------------------------------- /corenet/data/datasets/multi_modal_img_text/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.data.datasets.multi_modal_img_text.base_multi_modal_img_text import ( 9 | BaseMultiModalImgText, 10 | ) 11 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification import ( 12 | arguments_zero_shot_image_classification_dataset, 13 | ) 14 | 15 | 16 | def arguments_multi_modal_img_text( 17 | parser: argparse.ArgumentParser, 18 | ) -> argparse.ArgumentParser: 19 | 20 | parser = arguments_zero_shot_image_classification_dataset(parser) 21 | parser = BaseMultiModalImgText.add_arguments(parser) 22 | return parser 23 | -------------------------------------------------------------------------------- /corenet/data/datasets/multi_modal_img_text/zero_shot_image_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | 8 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.base_zero_shot_image_classification import ( 9 | BaseZeroShotImageClassificationDataset, 10 | ) 11 | from corenet.utils.registry import Registry 12 | 13 | ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY = Registry( 14 | registry_name="zero_shot_datasets", 15 | base_class=BaseZeroShotImageClassificationDataset, 16 | lazy_load_dirs=[ 17 | "corenet/data/datasets/multi_modal_img_text/zero_shot_image_classification" 18 | ], 19 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 20 | ) 21 | 22 | 23 | def arguments_zero_shot_image_classification_dataset( 24 | parser: argparse.ArgumentParser, 25 | ) -> argparse.ArgumentParser: 26 | """Helper function to get zero-shot dataset arguments""" 27 | parser = BaseZeroShotImageClassificationDataset.add_arguments(parser=parser) 28 | parser = ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY.all_arguments(parser) 29 | return parser 30 | 31 | 32 | def build_zero_shot_image_classification_dataset( 33 | opts: argparse.Namespace, *args, **kwargs 34 | ) -> BaseZeroShotImageClassificationDataset: 35 | """Helper function to build the zero shot image classification dataset.""" 36 | zero_shot_dataset_name = getattr( 37 | opts, "dataset.multi_modal_img_text.zero_shot_img_cls_dataset_name" 38 | ) 39 | return ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY[zero_shot_dataset_name]( 40 | opts, *args, **kwargs 41 | ) 42 | -------------------------------------------------------------------------------- /corenet/data/datasets/multi_modal_img_text/zero_shot_image_classification/imagenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | from typing import List 8 | 9 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification import ( 10 | ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY, 11 | BaseZeroShotImageClassificationDataset, 12 | ) 13 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.imagenet_class_names import ( 14 | IMAGENET_CLASS_NAMES, 15 | ) 16 | 17 | 18 | @ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY.register(name="imagenet") 19 | class ImageNetDatasetZeroShot(BaseZeroShotImageClassificationDataset): 20 | """ImageNet dataset for zero-shot evaluation of image-text models.""" 21 | 22 | @property 23 | def class_names(self) -> List[str]: 24 | """Return the name of the classes present in the dataset.""" 25 | return IMAGENET_CLASS_NAMES 26 | -------------------------------------------------------------------------------- /corenet/data/datasets/multi_modal_img_text/zero_shot_image_classification/imagenet_a.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from typing import List 7 | 8 | from corenet.data.datasets.classification.imagenet_a import ( 9 | IMAGENET_A_CLASS_SUBLIST, 10 | ImageNetADataset, 11 | ) 12 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification import ( 13 | ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY, 14 | ) 15 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.base_zero_shot_image_classification import ( 16 | BaseZeroShotImageClassificationDataset, 17 | ) 18 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.imagenet_class_names import ( 19 | IMAGENET_CLASS_NAMES, 20 | ) 21 | 22 | 23 | @ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY.register(name="imagenet_a") 24 | class ImageNetADatasetZeroShot(BaseZeroShotImageClassificationDataset): 25 | """ImageNetA Dataset for zero-shot evaluation of Image-text models.""" 26 | 27 | @property 28 | def class_names(self) -> List[str]: 29 | """Return the name of the classes present in the dataset.""" 30 | 31 | return [ 32 | IMAGENET_CLASS_NAMES[ImageNetADataset.class_id_to_imagenet_class_id(i)] 33 | for i in range(len(IMAGENET_A_CLASS_SUBLIST)) 34 | ] 35 | -------------------------------------------------------------------------------- /corenet/data/datasets/multi_modal_img_text/zero_shot_image_classification/imagenet_r.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import List 7 | 8 | from corenet.data.datasets.classification.imagenet_r import ( 9 | IMAGENET_R_CLASS_SUBLIST, 10 | ImageNetRDataset, 11 | ) 12 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification import ( 13 | ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY, 14 | ) 15 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.base_zero_shot_image_classification import ( 16 | BaseZeroShotImageClassificationDataset, 17 | ) 18 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.imagenet import ( 19 | IMAGENET_CLASS_NAMES, 20 | ) 21 | 22 | 23 | @ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY.register(name="imagenet_r") 24 | class ImageNetRDatasetZeroShot( 25 | BaseZeroShotImageClassificationDataset, ImageNetRDataset 26 | ): 27 | """ImageNet-R dataset for zero-shot evaluation of Image-text models.""" 28 | 29 | @property 30 | def class_names(self) -> List[str]: 31 | """Return the name of the classes present in the dataset.""" 32 | return [ 33 | IMAGENET_CLASS_NAMES[ImageNetRDataset.class_id_to_imagenet_class_id(i)] 34 | for i in range(len(IMAGENET_R_CLASS_SUBLIST)) 35 | ] 36 | -------------------------------------------------------------------------------- /corenet/data/datasets/multi_modal_img_text/zero_shot_image_classification/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from typing import List 7 | 8 | from corenet.data.datasets.classification.imagenet_sketch import ImageNetSketchDataset 9 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification import ( 10 | ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY, 11 | ) 12 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.base_zero_shot_image_classification import ( 13 | BaseZeroShotImageClassificationDataset, 14 | ) 15 | from corenet.data.datasets.multi_modal_img_text.zero_shot_image_classification.imagenet import ( 16 | IMAGENET_CLASS_NAMES, 17 | ) 18 | 19 | 20 | @ZERO_SHOT_IMAGE_CLASSIFICATION_DATASET_REGISTRY.register(name="imagenet_sketch") 21 | class ImageNetSketchDatasetZeroShot( 22 | BaseZeroShotImageClassificationDataset, ImageNetSketchDataset 23 | ): 24 | """ImageNet-Sketch Dataset for zero-shot evaluation of Image-text models.""" 25 | 26 | @property 27 | def class_names(self) -> List[str]: 28 | """Return the name of the classes present in the dataset.""" 29 | return IMAGENET_CLASS_NAMES 30 | -------------------------------------------------------------------------------- /corenet/data/datasets/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/data/datasets/segmentation/__init__.py -------------------------------------------------------------------------------- /corenet/data/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/data/datasets/utils/text.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import re 7 | import urllib 8 | 9 | import ftfy 10 | 11 | 12 | def caption_preprocessing(caption: str) -> str: 13 | """Removes the unwanted tokens (e.g., HTML tokens, next line, unwanted spaces) from 14 | the text.""" 15 | # captions may contain HTML tokens. 
Remove them 16 | html_re = re.compile("<.*?>") 17 | caption = urllib.parse.unquote(str(caption)) 18 | caption = caption.replace("+", " ") 19 | caption = re.sub(html_re, "", str(caption)) 20 | # remove the next line 21 | caption = caption.strip("\n") 22 | # remove unwanted spaces 23 | caption = re.sub(" +", " ", caption) 24 | 25 | caption = ftfy.fix_text(caption) 26 | return caption.strip().lower() 27 | -------------------------------------------------------------------------------- /corenet/data/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/data/io/__init__.py -------------------------------------------------------------------------------- /corenet/data/loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/data/loader/__init__.py -------------------------------------------------------------------------------- /corenet/data/text_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.data.text_tokenizer.base_tokenizer import BaseTextTokenizer 9 | from corenet.utils import logger 10 | from corenet.utils.registry import Registry 11 | 12 | TOKENIZER_REGISTRY = Registry( 13 | "tokenizer", 14 | base_class=BaseTextTokenizer, 15 | lazy_load_dirs=[ 16 | "corenet/data/text_tokenizer", 17 | "corenet/third_party/data/text_tokenizer", 18 | ], 19 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 20 | ) 21 | 22 | 23 | def arguments_tokenizer(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 24 | # add arguments for text_tokenizer 25 | parser = BaseTextTokenizer.add_arguments(parser) 26 | 27 | # add class specific arguments 28 | parser = TOKENIZER_REGISTRY.all_arguments(parser) 29 | return parser 30 | 31 | 32 | def build_tokenizer(opts, *args, **kwargs) -> BaseTextTokenizer: 33 | """Helper function to build the text tokenizer from command-line arguments. 34 | 35 | Args: 36 | opts: Command-line arguments 37 | 38 | Returns: 39 | The constructed text tokenizer. 40 | """ 41 | tokenizer_name = getattr(opts, "text_tokenizer.name", None) 42 | 43 | # We registered the base class using a special `name` (i.e., `__base__`) 44 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 45 | # be used. Therefore, we raise an error for such cases. 46 | if tokenizer_name == "__base__": 47 | logger.error("__base__ can't be used as a tokenizer name. Please check.") 48 | 49 | tokenizer = TOKENIZER_REGISTRY[tokenizer_name](opts, *args, **kwargs) 50 | return tokenizer 51 | -------------------------------------------------------------------------------- /corenet/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | import argparse 7 | 8 | from corenet.data.transforms.base_transforms import BaseTransformation 9 | from corenet.utils.registry import Registry 10 | 11 | TRANSFORMATIONS_REGISTRY = Registry( 12 | "transformation", 13 | base_class=BaseTransformation, 14 | lazy_load_dirs=["corenet/data/transforms"], 15 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 16 | ) 17 | 18 | 19 | def arguments_augmentation(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 20 | # add arguments for the base transformation class 21 | parser = BaseTransformation.add_arguments(parser) 22 | 23 | # add augmentation specific arguments 24 | parser = TRANSFORMATIONS_REGISTRY.all_arguments(parser) 25 | return parser 26 | -------------------------------------------------------------------------------- /corenet/data/transforms/audio_aux/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/data/transforms/audio_aux/__init__.py -------------------------------------------------------------------------------- /corenet/data/transforms/base_transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | from typing import Dict 8 | 9 | 10 | class BaseTransformation(object): 11 | """ 12 | Base class for augmentation methods 13 | """ 14 | 15 | def __init__(self, opts, *args, **kwargs) -> None: 16 | self.opts = opts 17 | 18 | def __call__(self, data: Dict) -> Dict: 19 | raise NotImplementedError 20 | 21 | def __repr__(self) -> str: 22 | return "{}()".format(self.__class__.__name__) 23 | 24 | @classmethod 25 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 26 | return parser 27 | -------------------------------------------------------------------------------- /corenet/data/transforms/common.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Dict, List 6 | 7 | from corenet.data.transforms import TRANSFORMATIONS_REGISTRY, BaseTransformation 8 | 9 | 10 | @TRANSFORMATIONS_REGISTRY.register(name="compose", type="common") 11 | class Compose(BaseTransformation): 12 | """ 13 | This class applies a list of transforms in a sequential fashion. 14 | """ 15 | 16 | def __init__(self, opts, img_transforms: List, *args, **kwargs) -> None: 17 | super().__init__(opts=opts) 18 | self.img_transforms = img_transforms 19 | 20 | def __call__(self, data: Dict) -> Dict: 21 | for t in self.img_transforms: 22 | data = t(data) 23 | return data 24 | 25 | def __repr__(self) -> str: 26 | transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.img_transforms) 27 | repr_str = "{}({}\n\t\t)".format(self.__class__.__name__, transform_str) 28 | return repr_str 29 | 30 | 31 | @TRANSFORMATIONS_REGISTRY.register(name="identity", type="common") 32 | class Identity(BaseTransformation): 33 | """ 34 | This is a no-op transformation that returns its inputs unchanged.
35 | """ 36 | 37 | def __call__(self, data: Dict) -> Dict: 38 | return data 39 | 40 | def __repr__(self) -> str: 41 | return f"{self.__class__.__name__}()" 42 | -------------------------------------------------------------------------------- /corenet/data/transforms/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Any 7 | 8 | import numpy as np 9 | 10 | 11 | def setup_size(size: Any, error_msg="Need a tuple of length 2"): 12 | if size is None: 13 | raise ValueError("Size can't be None") 14 | 15 | if isinstance(size, int): 16 | return size, size 17 | elif isinstance(size, (list, tuple)) and len(size) == 1: 18 | return size[0], size[0] 19 | 20 | if len(size) != 2: 21 | raise ValueError(error_msg) 22 | 23 | return size 24 | 25 | 26 | def intersect(box_a, box_b): 27 | """Computes the intersection between box_a and box_b""" 28 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 29 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 30 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 31 | return inter[:, 0] * inter[:, 1] 32 | 33 | 34 | def jaccard_numpy(box_a: np.ndarray, box_b: np.ndarray): 35 | """ 36 | Computes the intersection of two boxes. 37 | Args: 38 | box_a (np.ndarray): Boxes of shape [Num_boxes_A, 4] 39 | box_b (np.ndarray): Box osf shape [Num_boxes_B, 4] 40 | 41 | Returns: 42 | intersection over union scores. Shape is [box_a.shape[0], box_a.shape[1]] 43 | """ 44 | inter = intersect(box_a, box_b) 45 | area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) # [A,B] 46 | area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) # [A,B] 47 | union = area_a + area_b - inter 48 | return inter / union # [A,B] 49 | -------------------------------------------------------------------------------- /corenet/data/video_reader/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.data.video_reader.base_av_reader import BaseAVReader 9 | from corenet.utils import logger 10 | from corenet.utils.ddp_utils import is_master 11 | from corenet.utils.registry import Registry 12 | 13 | VIDEO_READER_REGISTRY = Registry( 14 | "video_reader", 15 | base_class=BaseAVReader, 16 | lazy_load_dirs=["corenet/data/video_reader"], 17 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 18 | ) 19 | 20 | 21 | def arguments_video_reader(parser: argparse.ArgumentParser): 22 | parser = BaseAVReader.add_arguments(parser=parser) 23 | 24 | # add video reader specific arguments 25 | parser = VIDEO_READER_REGISTRY.all_arguments(parser) 26 | return parser 27 | 28 | 29 | def get_video_reader( 30 | opts: argparse.Namespace, log: bool = True, *args, **kwargs 31 | ) -> BaseAVReader: 32 | """Helper function to build the video reader from command-line arguments. 33 | 34 | Args: 35 | opts: Command-line arguments 36 | log: When True, the video reader details will be logged to stdout. 37 | """ 38 | 39 | video_reader_name = getattr(opts, "video_reader.name") 40 | 41 | # We registered the base class using a special `name` (i.e., `__base__`) 42 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 43 | # be used. 
Therefore, we raise an error for such cases 44 | if video_reader_name == "__base__": 45 | logger.error("__base__ can't be used as a video reader name. Please check.") 46 | 47 | video_reader = VIDEO_READER_REGISTRY[video_reader_name](opts, *args, **kwargs) 48 | 49 | is_master_node = is_master(opts) 50 | if log and is_master_node: 51 | logger.log("Video reader details: ") 52 | print("{}".format(video_reader)) 53 | return video_reader 54 | -------------------------------------------------------------------------------- /corenet/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from corenet.engine.default_trainer import DefaultTrainer 7 | from corenet.engine.evaluation_engine import Evaluator 8 | -------------------------------------------------------------------------------- /corenet/engine/detection_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/engine/detection_utils/__init__.py -------------------------------------------------------------------------------- /corenet/engine/segmentation_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/engine/segmentation_utils/__init__.py -------------------------------------------------------------------------------- /corenet/engine/segmentation_utils/cityscapes_iou.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | import glob 7 | import os 8 | 9 | import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_semseg_eval 10 | 11 | from corenet.utils import logger 12 | 13 | 14 | def eval_cityscapes(pred_dir: str, gt_dir: str) -> None: 15 | """Utility to evaluate on cityscapes dataset""" 16 | cityscapes_semseg_eval.args.predictionPath = pred_dir 17 | cityscapes_semseg_eval.args.predictionWalk = None 18 | cityscapes_semseg_eval.args.JSONOutput = False 19 | cityscapes_semseg_eval.args.colorized = False 20 | 21 | gt_img_list = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png")) 22 | if len(gt_img_list) == 0: 23 | logger.error("Cannot find ground truth images at: {}".format(gt_dir)) 24 | 25 | pred_img_list = [] 26 | for gt in gt_img_list: 27 | pred_img_list.append( 28 | cityscapes_semseg_eval.getPrediction(cityscapes_semseg_eval.args, gt) 29 | ) 30 | 31 | results = cityscapes_semseg_eval.evaluateImgLists( 32 | pred_img_list, gt_img_list, cityscapes_semseg_eval.args 33 | ) 34 | 35 | logger.info("Evaluation results summary") 36 | eval_res_str = "\n\t IoU_cls: {:.2f} \n\t iIOU_cls: {:.2f} \n\t IoU_cat: {:.2f} \n\t iIOU_cat: {:.2f}".format( 37 | 100.0 * results["averageScoreClasses"], 38 | 100.0 * results["averageScoreInstClasses"], 39 | 100.0 * results["averageScoreCategories"], 40 | 100.0 * results["averageScoreInstCategories"], 41 | ) 42 | print(eval_res_str) 43 | -------------------------------------------------------------------------------- /corenet/loss_fn/base_criteria.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import abc 7 | import argparse 8 | from typing import Any 9 | 10 | from torch import nn 11 | 12 | from corenet.utils import logger 13 | 14 | 15 | class BaseCriteria(nn.Module, abc.ABC): 16 | """Base class for defining loss functions. Sub-classes must implement forward function. 17 | 18 | Args: 19 | opts: command line arguments 20 | """ 21 | 22 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 23 | super(BaseCriteria, self).__init__() 24 | self.opts = opts 25 | # small value for numerical stability purposes that sub-classes may want to use. 26 | self.eps = 1e-7 27 | 28 | @classmethod 29 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 30 | """Add criteria-specific arguments to the parser.""" 31 | if cls != BaseCriteria: 32 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 33 | return parser 34 | group = parser.add_argument_group(cls.__name__) 35 | 36 | group.add_argument( 37 | "--loss.category", 38 | type=str, 39 | default=None, 40 | help="Loss function category (e.g., classification). Defaults to None.", 41 | ) 42 | return parser 43 | 44 | @abc.abstractmethod 45 | def forward( 46 | self, input_sample: Any, prediction: Any, target: Any, *args, **kwargs 47 | ) -> Any: 48 | """Compute the loss. 49 | 50 | Args: 51 | input_sample: Input to the model.
52 | prediction: Model's output 53 | target: Ground truth labels 54 | """ 55 | raise NotImplementedError 56 | 57 | def extra_repr(self) -> str: 58 | return "" 59 | 60 | def __repr__(self) -> str: 61 | return "{}({}\n)".format(self.__class__.__name__, self.extra_repr()) 62 | -------------------------------------------------------------------------------- /corenet/loss_fn/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/loss_fn/detection/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/loss_fn/detection/base_detection_criteria.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from corenet.loss_fn import LOSS_REGISTRY, BaseCriteria 7 | 8 | 9 | @LOSS_REGISTRY.register(name="__base__", type="detection") 10 | class BaseDetectionCriteria(BaseCriteria): 11 | """Base class for defining detection loss functions. Sub-classes must implement forward function. 12 | 13 | Args: 14 | opts: command line arguments 15 | """ 16 | 17 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 18 | super().__init__(opts, *args, **kwargs) 19 | 20 | @classmethod 21 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | if cls != BaseDetectionCriteria: 23 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 24 | return parser 25 | 26 | group = parser.add_argument_group(cls.__name__) 27 | group.add_argument( 28 | "--loss.detection.name", 29 | type=str, 30 | default=None, 31 | help=f"Name of the loss function in {cls.__name__}. Defaults to None.", 32 | ) 33 | return parser 34 | -------------------------------------------------------------------------------- /corenet/loss_fn/distillation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/loss_fn/language_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/loss_fn/language_modeling/__init__.py -------------------------------------------------------------------------------- /corenet/loss_fn/multi_modal_img_text/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /corenet/loss_fn/multi_modal_img_text/base_multi_modal_img_text_criteria.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from corenet.loss_fn import LOSS_REGISTRY, BaseCriteria 7 | 8 | 9 | @LOSS_REGISTRY.register(name="__base__", type="multi_modal_image_text") 10 | class BaseMultiModalImageTextCriteria(BaseCriteria): 11 | """Base class for defining multi-modal image-text loss functions. Sub-classes must implement forward function. 12 | 13 | Args: 14 | opts: command line arguments 15 | """ 16 | 17 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 18 | super().__init__(opts, *args, **kwargs) 19 | 20 | @classmethod 21 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | if cls != BaseMultiModalImageTextCriteria: 23 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 24 | return parser 25 | 26 | group = parser.add_argument_group(cls.__name__) 27 | group.add_argument( 28 | "--loss.multi-modal-image-text.name", 29 | type=str, 30 | default=None, 31 | help="Name of the loss function. Defaults to None.", 32 | ) 33 | return parser 34 | -------------------------------------------------------------------------------- /corenet/loss_fn/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/loss_fn/segmentation/base_segmentation_criteria.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from corenet.loss_fn import LOSS_REGISTRY, BaseCriteria 7 | 8 | 9 | @LOSS_REGISTRY.register(name="__base__", type="segmentation") 10 | class BaseSegmentationCriteria(BaseCriteria): 11 | """Base class for defining segmentation loss functions. Sub-classes must implement forward function. 12 | 13 | Args: 14 | opts: command line arguments 15 | """ 16 | 17 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 18 | super().__init__(opts, *args, **kwargs) 19 | 20 | @classmethod 21 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | if cls != BaseSegmentationCriteria: 23 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 24 | return parser 25 | 26 | group = parser.add_argument_group(cls.__name__) 27 | group.add_argument( 28 | "--loss.segmentation.name", 29 | type=str, 30 | default=None, 31 | help="Name of the loss function. 
Defaults to None.", 32 | ) 33 | return parser 34 | -------------------------------------------------------------------------------- /corenet/loss_fn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/loss_fn/utils/__init__.py -------------------------------------------------------------------------------- /corenet/loss_fn/utils/build_helper.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | from torch import nn 7 | 8 | from corenet.constants import is_test_env 9 | from corenet.modeling.models import get_model 10 | from corenet.options.utils import extract_opts_with_prefix_replacement 11 | from corenet.utils import logger 12 | 13 | 14 | def build_cls_teacher_from_opts(opts: argparse.Namespace) -> nn.Module: 15 | """Helper function to build a classification teacher model from command-line arguments 16 | 17 | Args: 18 | opts: command-line arguments 19 | 20 | Returns: 21 | A teacher model 22 | """ 23 | pretrained_model = getattr(opts, "teacher.model.classification.pretrained") 24 | 25 | pytest_env = is_test_env() 26 | if not pytest_env and pretrained_model is None: 27 | logger.error( 28 | "For distillation, please specify teacher weights using teacher.model.classification.pretrained" 29 | ) 30 | teacher_opts = extract_opts_with_prefix_replacement( 31 | opts, "teacher.model.", "model." 32 | ) 33 | 34 | # build teacher model 35 | return get_model(teacher_opts, category="classification") 36 | -------------------------------------------------------------------------------- /corenet/loss_fn/utils/class_weighting.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | 10 | def compute_class_weights( 11 | target: Tensor, n_classes: int, norm_val: float = 1.1 12 | ) -> Tensor: 13 | """Implementation of a class-weighting scheme, as defined in Section 5.2 14 | of `ENet `_ paper. 15 | 16 | Args: 17 | target: Tensor of shape [Batch_size, *] containing values in the range `[0, C)`. 18 | n_classes: Integer specifying the number of classes :math:`C` 19 | norm_val: Normalization value. Defaults to 1.1. This value is decided based on the 20 | `ESPNetv2 paper `_. 21 | Link: https://github.com/sacmehta/ESPNetv2/blob/b78e323039908f31347d8ca17f49d5502ef1a594/segmentation/loadData.py#L16 22 | 23 | Returns: 24 | A :math:`C`-dimensional tensor containing class weights 25 | """ 26 | 27 | class_hist = torch.histc(target.float(), bins=n_classes, min=0, max=n_classes - 1) 28 | mask_indices = class_hist == 0 29 | 30 | # normalize between 0 and 1 by dividing by the sum 31 | norm_hist = torch.div(class_hist, class_hist.sum()) 32 | norm_hist = torch.add(norm_hist, norm_val) 33 | 34 | # compute class weights. 
35 | # samples with more frequency will have less weight and vice-versa 36 | class_wts = torch.div(torch.ones_like(class_hist), torch.log(norm_hist)) 37 | 38 | # mask the classes which do not have samples in the current batch 39 | class_wts[mask_indices] = 0.0 40 | 41 | return class_wts.to(device=target.device) 42 | -------------------------------------------------------------------------------- /corenet/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.utils.registry import Registry 9 | 10 | METRICS_REGISTRY = Registry( 11 | "metrics", 12 | lazy_load_dirs=["corenet/metrics"], 13 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 14 | ) 15 | 16 | 17 | def arguments_stats(parser: argparse.ArgumentParser): 18 | group = parser.add_argument_group(title="Statistics", description="Statistics") 19 | group.add_argument( 20 | "--stats.val", type=str, default=["loss"], nargs="+", help="Name of statistics" 21 | ) 22 | group.add_argument( 23 | "--stats.train", 24 | type=str, 25 | default=["loss"], 26 | nargs="+", 27 | help="Name of statistics", 28 | ) 29 | group.add_argument( 30 | "--stats.checkpoint-metric", 31 | type=str, 32 | default="loss", 33 | help="Metric to use for saving checkpoints", 34 | ) 35 | group.add_argument( 36 | "--stats.checkpoint-metric-max", 37 | action="store_true", 38 | default=False, 39 | help="Maximize checkpoint metric", 40 | ) 41 | group.add_argument( 42 | "--stats.coco-map.iou-types", 43 | type=str, 44 | default=["bbox"], 45 | nargs="+", 46 | choices=("bbox", "segm"), 47 | help="Types of IOU to compute for MSCoco.", 48 | ) 49 | 50 | return parser 51 | -------------------------------------------------------------------------------- /corenet/metrics/average_precision.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import traceback 7 | from numbers import Number 8 | from typing import Dict, Union 9 | 10 | import numpy as np 11 | from sklearn.metrics import average_precision_score 12 | from torch import Tensor 13 | from torch.nn import functional as F 14 | 15 | from corenet.metrics import METRICS_REGISTRY 16 | from corenet.metrics.metric_base import EpochMetric 17 | from corenet.utils import logger 18 | 19 | 20 | @METRICS_REGISTRY.register("average_precision") 21 | class AveragePrecisionMetric(EpochMetric): 22 | def compute_with_aggregates( 23 | self, y_pred: Tensor, y_true: Tensor 24 | ) -> Union[Number, Dict[str, Number]]: 25 | y_pred, y_true = self.get_aggregates() 26 | 27 | y_pred = F.softmax(y_pred, dim=-1).numpy().astype(np.float32) 28 | y_true = y_true.numpy().astype(np.float32) 29 | 30 | # Clip predictions to reduce chance of getting INF 31 | y_pred = y_pred.clip(0, 1) 32 | 33 | if y_pred.ndim == 1 or y_pred.ndim == 2 and y_pred.shape[1] == 1: 34 | pass # TODO? 
35 | elif y_pred.ndim == 2 and y_pred.shape[1] == 2: 36 | y_pred = y_pred[:, 1] 37 | else: 38 | logger.warning( 39 | "Expected only two classes, got prediction Tensor of shape {}".format( 40 | y_pred.shape 41 | ) 42 | ) 43 | 44 | try: 45 | ap = 100 * average_precision_score(y_true, y_pred, average=None) 46 | except ValueError as e: 47 | logger.warning("Could not compute Average Precision: {}".format(str(e))) 48 | traceback.print_exc() 49 | ap = 0 # we don't want the job to fail over a metric computation issue 50 | 51 | return ap 52 | -------------------------------------------------------------------------------- /corenet/metrics/metric_base_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Any, Dict, Union 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from corenet.metrics.metric_base import AverageMetric 12 | 13 | 14 | class DummyMetric(AverageMetric): 15 | def gather_metrics( 16 | self, 17 | prediction: Union[Tensor, Dict], 18 | target: Union[Tensor, Dict], 19 | extras: Dict[str, Any], 20 | ) -> Union[Tensor, Dict[str, Tensor]]: 21 | return prediction 22 | 23 | 24 | def test_average_metric_distributed_batchsize(mocker): 25 | mocker.patch("torch.distributed.is_initialized", return_value=True) 26 | mocker.patch("torch.distributed.get_world_size", return_value=2) 27 | mocker.patch("torch.distributed.all_reduce", lambda x, *_, **__: x.add_(1)) 28 | 29 | metric = DummyMetric(None, is_distributed=True) 30 | metric.update(torch.tensor([2.0]), None, batch_size=torch.tensor([2])) 31 | 32 | # Value is 2 and batch size is 2, but we're simulating the second device 33 | # having value 1 and batch size 1 by making sure all_reduce adds 1 to both 34 | # the value and the batch size. It's as if we have [2, 2] in GPU1 and [1] 35 | # in GPU 2. Therefore the expected average is 5/3. 36 | 37 | expected_value = (2 * 2 + 1 * 1) / 3 38 | assert metric.compute() == expected_value 39 | -------------------------------------------------------------------------------- /corenet/modeling/anchor_generator/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | 8 | from corenet.modeling.anchor_generator.base_anchor_generator import BaseAnchorGenerator 9 | from corenet.utils import logger 10 | from corenet.utils.registry import Registry 11 | 12 | # register anchor generator 13 | ANCHOR_GEN_REGISTRY = Registry( 14 | "anchor_gen", 15 | base_class=BaseAnchorGenerator, 16 | lazy_load_dirs=["corenet/modeling/anchor_generator"], 17 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 18 | ) 19 | 20 | 21 | def arguments_anchor_gen(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | """Arguments related to anchor generator for object detection""" 23 | group = parser.add_argument_group("Anchor generator", "Anchor generator") 24 | group.add_argument( 25 | "--anchor-generator.name", type=str, help="Name of the anchor generator" 26 | ) 27 | 28 | # add class specific arguments 29 | parser = ANCHOR_GEN_REGISTRY.all_arguments(parser) 30 | return parser 31 | 32 | 33 | def build_anchor_generator(opts, *args, **kwargs): 34 | """Build anchor generator for object detection""" 35 | anchor_gen_name = getattr(opts, "anchor_generator.name") 36 | 37 | # We registered the base class using a special `name` (i.e., `__base__`) 38 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 39 | # be used. Therefore, we raise an error for such cases 40 | if anchor_gen_name == "__base__": 41 | logger.error("__base__ can't be used as an anchor generator name. Please check.") 42 | 43 | anchor_gen = ANCHOR_GEN_REGISTRY[anchor_gen_name](opts, *args, **kwargs) 44 | return anchor_gen 45 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/gelu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from corenet.modeling.layers.activation import register_act_fn 9 | 10 | 11 | @register_act_fn(name="gelu") 12 | class GELU(nn.GELU): 13 | """ 14 | Applies the `Gaussian Error Linear Units `_ function 15 | """ 16 | 17 | def __init__(self, *args, **kwargs) -> None: 18 | super().__init__() 19 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/hard_sigmoid.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | from torch.nn import functional as F 10 | 11 | from corenet.modeling.layers.activation import register_act_fn 12 | 13 | 14 | @register_act_fn(name="hard_sigmoid") 15 | class Hardsigmoid(nn.Hardsigmoid): 16 | """ 17 | Applies the `Hard Sigmoid `_ function 18 | """ 19 | 20 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 21 | super().__init__(inplace=inplace) 22 | 23 | def forward(self, input: Tensor, *args, **kwargs) -> Tensor: 24 | if hasattr(F, "hardsigmoid"): 25 | return F.hardsigmoid(input, self.inplace) 26 | else: 27 | return F.relu(input + 3) / 6 28 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/hard_swish.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | from torch.nn import functional as F 10 | 11 | from corenet.modeling.layers.activation import register_act_fn 12 | 13 | 14 | @register_act_fn(name="hard_swish") 15 | class Hardswish(nn.Hardswish): 16 | """ 17 | Applies the HardSwish function, as described in the paper 18 | `Searching for MobileNetv3 `_ 19 | """ 20 | 21 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 22 | super().__init__(inplace=inplace) 23 | 24 | def forward(self, input: Tensor, *args, **kwargs) -> Tensor: 25 | if hasattr(F, "hardswish"): 26 | return F.hardswish(input, self.inplace) 27 | else: 28 | x_hard_sig = F.relu(input + 3) / 6 29 | return input * x_hard_sig 30 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/leaky_relu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from corenet.modeling.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="leaky_relu") 14 | class LeakyReLU(nn.LeakyReLU): 15 | """ 16 | Applies a leaky relu function. See `Rectifier Nonlinearities Improve Neural Network Acoustic Models` 17 | for more details. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | negative_slope: Optional[float] = 1e-2, 23 | inplace: Optional[bool] = False, 24 | *args, 25 | **kwargs 26 | ) -> None: 27 | super().__init__(negative_slope=negative_slope, inplace=inplace) 28 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/prelu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from corenet.modeling.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="prelu") 14 | class PReLU(nn.PReLU): 15 | """ 16 | Applies the `Parametric Rectified Linear Unit `_ function 17 | """ 18 | 19 | def __init__( 20 | self, 21 | num_parameters: Optional[int] = 1, 22 | init: Optional[float] = 0.25, 23 | *args, 24 | **kwargs 25 | ) -> None: 26 | super().__init__(num_parameters=num_parameters, init=init) 27 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/relu.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from corenet.modeling.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="relu") 14 | class ReLU(nn.ReLU): 15 | """ 16 | Applies Rectified Linear Unit function 17 | """ 18 | 19 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 20 | super().__init__(inplace=inplace) 21 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/relu6.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from corenet.modeling.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="relu6") 14 | class ReLU6(nn.ReLU6): 15 | """ 16 | Applies the ReLU6 function 17 | """ 18 | 19 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 20 | super().__init__(inplace=inplace) 21 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/sigmoid.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from corenet.modeling.layers.activation import register_act_fn 9 | 10 | 11 | @register_act_fn(name="sigmoid") 12 | class Sigmoid(nn.Sigmoid): 13 | """ 14 | Applies the sigmoid function 15 | """ 16 | 17 | def __init__(self, *args, **kwargs) -> None: 18 | super().__init__() 19 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/swish.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from corenet.modeling.layers.activation import register_act_fn 11 | 12 | 13 | @register_act_fn(name="swish") 14 | class Swish(nn.SiLU): 15 | """ 16 | Applies the `Swish (also known as SiLU) `_ function. 
17 | """ 18 | 19 | def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: 20 | super().__init__(inplace=inplace) 21 | -------------------------------------------------------------------------------- /corenet/modeling/layers/activation/tanh.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from corenet.modeling.layers.activation import register_act_fn 9 | 10 | 11 | @register_act_fn(name="tanh") 12 | class Tanh(nn.Tanh): 13 | """ 14 | Applies Tanh function 15 | """ 16 | 17 | def __init__(self, *args, **kwargs) -> None: 18 | super().__init__() 19 | -------------------------------------------------------------------------------- /corenet/modeling/layers/adaptive_pool.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Tuple, Union 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d): 12 | """ 13 | Applies a 2D adaptive average pooling over an input tensor. 14 | 15 | Args: 16 | output_size (Optional, int or Tuple[int, int]): The target output size. If a single int :math:`h` is passed, 17 | then a square output of size :math:`hxh` is produced. If a tuple of size :math:`hxw` is passed, then an 18 | output of size `hxw` is produced. Default is 1. 19 | Shape: 20 | - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels, 21 | :math:`H` is the input height, and :math:`W` is the input width 22 | - Output: :math:`(N, C, h, h)` or :math:`(N, C, h, w)` 23 | """ 24 | 25 | def __init__( 26 | self, output_size: Union[int, Tuple[int, int]] = 1, *args, **kwargs 27 | ) -> None: 28 | super().__init__(output_size=output_size) 29 | -------------------------------------------------------------------------------- /corenet/modeling/layers/dropout.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class Dropout(nn.Dropout): 12 | """ 13 | This layer, during training, randomly zeroes some of the elements of the input tensor with probability `p` 14 | using samples from a Bernoulli distribution. 15 | 16 | Args: 17 | p: probability of an element to be zeroed. Default: 0.5 18 | inplace: If set to ``True``, will do this operation in-place. Default: ``False`` 19 | 20 | Shape: 21 | - Input: :math:`(N, *)` where :math:`N` is the batch size 22 | - Output: same as the input 23 | 24 | """ 25 | 26 | def __init__( 27 | self, p: Optional[float] = 0.5, inplace: Optional[bool] = False, *args, **kwargs 28 | ) -> None: 29 | super().__init__(p=p, inplace=inplace) 30 | 31 | 32 | class Dropout2d(nn.Dropout2d): 33 | """ 34 | This layer, during training, randomly zeroes some of the elements of the 4D input tensor with probability `p` 35 | using samples from a Bernoulli distribution. 36 | 37 | Args: 38 | p: probability of an element to be zeroed. Default: 0.5 39 | inplace: If set to ``True``, will do this operation in-place. 
Default: ``False`` 40 | 41 | Shape: 42 | - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the input channels, 43 | :math:`H` is the input tensor height, and :math:`W` is the input tensor width 44 | - Output: same as the input 45 | 46 | """ 47 | 48 | def __init__(self, p: float = 0.5, inplace: bool = False): 49 | super().__init__(p=p, inplace=inplace) 50 | -------------------------------------------------------------------------------- /corenet/modeling/layers/flatten.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class Flatten(nn.Flatten): 12 | r""" 13 | This layer flattens a contiguous range of dimensions into a tensor. 14 | 15 | Args: 16 | start_dim (Optional[int]): first dim to flatten. Default: 1 17 | end_dim (Optional[int]): last dim to flatten. Default: -1 18 | 19 | Shape: 20 | - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' 21 | where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any 22 | number of dimensions including none. 23 | - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. 24 | """ 25 | 26 | def __init__(self, start_dim: Optional[int] = 1, end_dim: Optional[int] = -1): 27 | super(Flatten, self).__init__(start_dim=start_dim, end_dim=end_dim) 28 | -------------------------------------------------------------------------------- /corenet/modeling/layers/identity.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor 7 | 8 | from corenet.modeling.layers.base_layer import BaseLayer 9 | 10 | 11 | class Identity(BaseLayer): 12 | """ 13 | This is a place-holder and returns the same tensor. 14 | """ 15 | 16 | def __init__(self): 17 | super(Identity, self).__init__() 18 | 19 | def forward(self, x: Tensor) -> Tensor: 20 | return x 21 | -------------------------------------------------------------------------------- /corenet/modeling/layers/normalization/group_norm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | from corenet.modeling.layers.normalization import register_norm_fn 11 | 12 | 13 | @register_norm_fn(name="group_norm") 14 | class GroupNorm(nn.GroupNorm): 15 | """ 16 | Applies a `Group Normalization `_ over an input tensor 17 | 18 | Args: 19 | num_groups (int): number of groups to separate the input channels into 20 | num_features (int): :math:`C` from an expected input of size :math:`(N, C, *)` 21 | eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 22 | affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` 23 | 24 | Shape: 25 | - Input: :math:`(N, C, *)` where :math:`N` is the batch size, :math:`C` is the number of input channels, 26 | and :math:`*` is the remaining dimensions of the input tensor 27 | - Output: same shape as the input 28 | 29 | .. 
note:: 30 | GroupNorm is the same as LayerNorm when `num_groups=1` and it is the same as InstanceNorm when 31 | `num_groups=C`. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | num_groups: int, 37 | num_features: int, 38 | eps: Optional[float] = 1e-5, 39 | affine: Optional[bool] = True, 40 | *args, 41 | **kwargs 42 | ) -> None: 43 | super().__init__( 44 | num_groups=num_groups, num_channels=num_features, eps=eps, affine=affine 45 | ) 46 | -------------------------------------------------------------------------------- /corenet/modeling/layers/normalization_layers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import math 6 | 7 | from torch import nn 8 | 9 | from corenet.modeling.layers.normalization import ( 10 | NORM_LAYER_CLS, 11 | build_normalization_layer, 12 | ) 13 | from corenet.utils import logger 14 | 15 | norm_layers_tuple = tuple(NORM_LAYER_CLS) 16 | 17 | 18 | get_normalization_layer = build_normalization_layer 19 | -------------------------------------------------------------------------------- /corenet/modeling/layers/pixel_shuffle.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | 9 | class PixelShuffle(nn.PixelShuffle): 10 | """ 11 | Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` 12 | to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. 13 | 14 | Args: 15 | upscale_factor (int): factor to increase spatial resolution by 16 | 17 | Shape: 18 | - Input: :math:`(*, C \times r^2, H, W)`, where * is zero or more dimensions 19 | - Output: :math:`(*, C, H \times r, W \times r)` 20 | """ 21 | 22 | def __init__(self, upscale_factor: int, *args, **kwargs) -> None: 23 | super(PixelShuffle, self).__init__(upscale_factor=upscale_factor) 24 | 25 | def __repr__(self): 26 | return "{}(upscale_factor={})".format( 27 | self.__class__.__name__, self.upscale_factor 28 | ) 29 | -------------------------------------------------------------------------------- /corenet/modeling/layers/random_layers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import random 7 | from typing import List, Optional 8 | 9 | from torch import Tensor 10 | 11 | from corenet.modeling.layers.base_layer import BaseLayer 12 | from corenet.utils.math_utils import bound_fn 13 | 14 | 15 | class RandomApply(BaseLayer): 16 | """ 17 | This layer randomly applies a list of modules during training. 18 | 19 | Args: 20 | module_list (List): List of modules 21 | keep_p (Optional[float]): Keep P modules from the list during training. 
Default: 0.8 (or 80%) 22 | """ 23 | 24 | def __init__( 25 | self, module_list: List, keep_p: Optional[float] = 0.8, *args, **kwargs 26 | ) -> None: 27 | super().__init__() 28 | n_modules = len(module_list) 29 | self.module_list = module_list 30 | 31 | self.module_indexes = [i for i in range(1, n_modules)] 32 | k = int(round(n_modules * keep_p)) 33 | self.keep_k = bound_fn(min_val=1, max_val=n_modules, value=k) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | if self.training: 37 | indexes = [0] + sorted(random.sample(self.module_indexes, k=self.keep_k)) 38 | for idx in indexes: 39 | x = self.module_list[idx](x) 40 | else: 41 | for layer in self.module_list: 42 | x = layer(x) 43 | return x 44 | 45 | def __repr__(self): 46 | format_string = "{}(apply_k (N={})={}, ".format( 47 | self.__class__.__name__, len(self.module_list), self.keep_k 48 | ) 49 | for layer in self.module_list: 50 | format_string += "\n\t {}".format(layer) 51 | format_string += "\n)" 52 | return format_string 53 | -------------------------------------------------------------------------------- /corenet/modeling/layers/softmax.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor, nn 9 | 10 | 11 | class Softmax(nn.Softmax): 12 | """ 13 | Applies the Softmax function to an input tensor along the specified dimension 14 | 15 | Args: 16 | dim (int): Dimension along which softmax to be applied. Default: -1 17 | 18 | Shape: 19 | - Input: :math:`(*)` where :math:`*` is one or more dimensions 20 | - Output: same shape as the input 21 | """ 22 | 23 | def __init__(self, dim: Optional[int] = -1, *args, **kwargs): 24 | super().__init__(dim=dim) 25 | -------------------------------------------------------------------------------- /corenet/modeling/layers/stochastic_depth.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor 7 | from torchvision.ops import StochasticDepth as StochasticDepthTorch 8 | 9 | 10 | class StochasticDepth(StochasticDepthTorch): 11 | """ 12 | Implements the Stochastic Depth `"Deep Networks with Stochastic Depth" 13 | `_ used for randomly dropping residual 14 | branches of residual architectures. 15 | """ 16 | 17 | def __init__(self, p: float, mode: str) -> None: 18 | super().__init__(p=p, mode=mode) 19 | -------------------------------------------------------------------------------- /corenet/modeling/layers/upsample.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | 12 | class UpSample(nn.Upsample): 13 | """ 14 | This layer upsamples a given input tensor. 15 | 16 | Args: 17 | size (Optional[Union[int, Tuple[int, ...]]): Output spatial size. Default: None 18 | scale_factor (Optional[float]): Scale each spatial dimension of the input by this factor. Default: None 19 | mode (Optional[str]): Upsampling algorithm (``'nearest'``, ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. 
Default: ``'nearest'`` 20 | align_corners (Optional[bool]): if ``True``, the corner pixels of the input and output tensors are aligned, and thus preserving the values at 21 | those pixels. This only has effect when :attr:`mode` is ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``. 22 | Default: ``None`` 23 | 24 | Shape: 25 | - Input: :math:`(N, C, W_{in})` or :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` 26 | - Output: :math:`(N, C, W_{out})` or :math:`(N, C, H_{out}, W_{out})` or :math:`(N, C, D_{out}, H_{out}, W_{out})` 27 | """ 28 | 29 | def __init__( 30 | self, 31 | size: Optional[Union[int, Tuple[int, ...]]] = None, 32 | scale_factor: Optional[float] = None, 33 | mode: Optional[str] = "nearest", 34 | align_corners: Optional[bool] = None, 35 | *args, 36 | **kwargs 37 | ) -> None: 38 | super().__init__( 39 | size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners 40 | ) 41 | -------------------------------------------------------------------------------- /corenet/modeling/matcher_det/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.modeling.matcher_det.base_matcher import BaseMatcher 9 | from corenet.utils import logger 10 | from corenet.utils.registry import Registry 11 | 12 | # register BOX Matcher 13 | MATCHER_REGISTRY = Registry( 14 | "matcher", 15 | base_class=BaseMatcher, 16 | lazy_load_dirs=["corenet/modeling/matcher_det"], 17 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 18 | ) 19 | 20 | 21 | def arguments_box_matcher(parser: argparse.ArgumentParser): 22 | group = parser.add_argument_group("Matcher", "Matcher") 23 | group.add_argument( 24 | "--matcher.name", 25 | type=str, 26 | help="Name of the matcher. Matcher matches anchors with GT box coordinates", 27 | ) 28 | 29 | # add matcher specific arguments 30 | parser = MATCHER_REGISTRY.all_arguments(parser) 31 | return parser 32 | 33 | 34 | def build_matcher(opts, *args, **kwargs): 35 | matcher_name = getattr(opts, "matcher.name", None) 36 | # We registered the base class using a special `name` (i.e., `__base__`) 37 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 38 | # be used. Therefore, we raise an error for such cases 39 | if matcher_name == "__base__": 40 | logger.error("__base__ can't be used as a matcher name. Please check.") 41 | 42 | matcher = MATCHER_REGISTRY[matcher_name](opts, *args, **kwargs) 43 | return matcher 44 | -------------------------------------------------------------------------------- /corenet/modeling/matcher_det/base_matcher.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | import argparse 7 | 8 | 9 | class BaseMatcher(object): 10 | """ 11 | Base class for matching anchor boxes and labels for the task of object detection 12 | """ 13 | 14 | def __init__(self, opts, *args, **kwargs) -> None: 15 | super(BaseMatcher, self).__init__() 16 | self.opts = opts 17 | 18 | @classmethod 19 | def add_arguments(cls, parser: argparse.ArgumentParser): 20 | """Add class-specific arguments""" 21 | return parser 22 | 23 | def __call__(self, *args, **kwargs): 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /corenet/modeling/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/modeling/misc/__init__.py -------------------------------------------------------------------------------- /corenet/modeling/models/audio_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/modeling/models/audio_classification/base_audio_classification.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.modeling.models import MODEL_REGISTRY, BaseAnyNNModel 9 | 10 | 11 | @MODEL_REGISTRY.register(name="__base__", type="audio_classification") 12 | class BaseAudioClassification(BaseAnyNNModel): 13 | """Base class for audio classification. 14 | 15 | Args: 16 | opts: Command-line arguments 17 | """ 18 | 19 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 20 | super().__init__(opts, *args, **kwargs) 21 | 22 | @classmethod 23 | def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 24 | """Add model specific arguments""" 25 | if cls != BaseAudioClassification: 26 | # Don't re-register arguments in subclasses that don't override `add_arguments()`. 27 | return parser 28 | group = parser.add_argument_group(title=cls.__name__) 29 | group.add_argument( 30 | "--model.audio-classification.name", 31 | type=str, 32 | default=None, 33 | help="Name of the audio classification model. Defaults to None.", 34 | ) 35 | group.add_argument( 36 | "--model.audio-classification.pretrained", 37 | type=str, 38 | default=None, 39 | help="Path of the pretrained backbone. Defaults to None.", 40 | ) 41 | return parser 42 | -------------------------------------------------------------------------------- /corenet/modeling/models/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /corenet/modeling/models/classification/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/modeling/models/classification/config/__init__.py -------------------------------------------------------------------------------- /corenet/modeling/models/classification/config/mobilenetv1.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import math 7 | from typing import Dict 8 | 9 | from corenet.utils.math_utils import make_divisible 10 | 11 | 12 | def get_configuration(opts) -> Dict: 13 | width_mult = getattr(opts, "model.classification.mobilenetv1.width_multiplier", 1.0) 14 | 15 | def scale_channels(in_channels): 16 | return make_divisible(int(math.ceil(in_channels * width_mult)), 16) 17 | 18 | config = { 19 | "conv1_out": scale_channels(32), 20 | "layer1": {"out_channels": scale_channels(64), "stride": 1, "repeat": 1}, 21 | "layer2": { 22 | "out_channels": scale_channels(128), 23 | "stride": 2, 24 | "repeat": 1, 25 | }, 26 | "layer3": { 27 | "out_channels": scale_channels(256), 28 | "stride": 2, 29 | "repeat": 1, 30 | }, 31 | "layer4": { 32 | "out_channels": scale_channels(512), 33 | "stride": 2, 34 | "repeat": 5, 35 | }, 36 | "layer5": { 37 | "out_channels": scale_channels(1024), 38 | "stride": 2, 39 | "repeat": 1, 40 | }, 41 | } 42 | return config 43 | -------------------------------------------------------------------------------- /corenet/modeling/models/classification/config/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Dict 7 | 8 | 9 | def get_configuration(opts) -> Dict: 10 | 11 | mobilenetv2_config = { 12 | "layer1": { 13 | "expansion_ratio": 1, 14 | "out_channels": 16, 15 | "num_blocks": 1, 16 | "stride": 1, 17 | }, 18 | "layer2": { 19 | "expansion_ratio": 6, 20 | "out_channels": 24, 21 | "num_blocks": 2, 22 | "stride": 2, 23 | }, 24 | "layer3": { 25 | "expansion_ratio": 6, 26 | "out_channels": 32, 27 | "num_blocks": 3, 28 | "stride": 2, 29 | }, 30 | "layer4": { 31 | "expansion_ratio": 6, 32 | "out_channels": 64, 33 | "num_blocks": 4, 34 | "stride": 2, 35 | }, 36 | "layer4_a": { 37 | "expansion_ratio": 6, 38 | "out_channels": 96, 39 | "num_blocks": 3, 40 | "stride": 1, 41 | }, 42 | "layer5": { 43 | "expansion_ratio": 6, 44 | "out_channels": 160, 45 | "num_blocks": 3, 46 | "stride": 2, 47 | }, 48 | "layer5_a": { 49 | "expansion_ratio": 6, 50 | "out_channels": 320, 51 | "num_blocks": 1, 52 | "stride": 1, 53 | }, 54 | } 55 | return mobilenetv2_config 56 | -------------------------------------------------------------------------------- /corenet/modeling/models/classification/config/mobileone.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import argparse 6 | from typing import Dict 7 | 8 | from corenet.utils import logger 9 | 10 | 11 | def get_configuration(opts: argparse.Namespace) -> Dict: 12 | """Get configuration of MobileOne models.""" 13 | variant = getattr(opts, "model.classification.mobileone.variant") 14 | config = dict() 15 | 16 | if variant == "s0": 17 | config = { 18 | "num_blocks_per_stage": [2, 8, 10, 1], 19 | "width_multipliers": (0.75, 1.0, 1.0, 2.0), 20 | "num_conv_branches": 4, 21 | "use_se": False, 22 | } 23 | elif variant == "s1": 24 | config = { 25 | "num_blocks_per_stage": [2, 8, 10, 1], 26 | "width_multipliers": (1.5, 1.5, 2.0, 2.5), 27 | "num_conv_branches": 1, 28 | "use_se": False, 29 | } 30 | elif variant == "s2": 31 | config = { 32 | "num_blocks_per_stage": [2, 8, 10, 1], 33 | "width_multipliers": (1.5, 2.0, 2.5, 4.0), 34 | "num_conv_branches": 1, 35 | "use_se": False, 36 | } 37 | elif variant == "s3": 38 | config = { 39 | "num_blocks_per_stage": [2, 8, 10, 1], 40 | "width_multipliers": (2.0, 2.5, 3.0, 4.0), 41 | "num_conv_branches": 1, 42 | "use_se": False, 43 | } 44 | elif variant == "s4": 45 | config = { 46 | "num_blocks_per_stage": [2, 8, 10, 1], 47 | "width_multipliers": (3.0, 3.5, 3.5, 4.0), 48 | "num_conv_branches": 1, 49 | "use_se": True, 50 | } 51 | else: 52 | logger.error( 53 | "MobileOne supported variants: `s0`, `s1`, `s2`, `s3` and `s4`. Please specify variant using " 54 | "--model.classification.mobileone.variant flag. Got: {}".format(variant) 55 | ) 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /corenet/modeling/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from collections import namedtuple 7 | 8 | DetectionPredTuple = namedtuple( 9 | typename="DetectionPredTuple", 10 | field_names=("labels", "scores", "boxes", "masks"), 11 | defaults=(None, None, None, None), 12 | ) 13 | -------------------------------------------------------------------------------- /corenet/modeling/models/detection/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/modeling/models/detection/utils/__init__.py -------------------------------------------------------------------------------- /corenet/modeling/models/language_modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/modeling/models/multi_modal_img_text/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/modeling/models/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /corenet/modeling/models/segmentation/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/modeling/models/video_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /corenet/modeling/modules/base_module.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Any 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | 12 | class BaseModule(nn.Module): 13 | """Base class for all modules""" 14 | 15 | def __init__(self, *args, **kwargs): 16 | super(BaseModule, self).__init__() 17 | 18 | def forward(self, x: Any, *args, **kwargs) -> Any: 19 | raise NotImplementedError 20 | 21 | def __repr__(self): 22 | return "{}".format(self.__class__.__name__) 23 | -------------------------------------------------------------------------------- /corenet/modeling/modules/efficientnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch import Tensor, nn 7 | 8 | from corenet.modeling.layers import StochasticDepth 9 | from corenet.modeling.modules import InvertedResidualSE 10 | 11 | 12 | class EfficientNetBlock(InvertedResidualSE): 13 | """ 14 | This class implements a variant of the inverted residual block with squeeze-excitation unit, 15 | as described in `MobileNetv3 `_ paper. This variant 16 | includes stochastic depth, as used in `EfficientNet `_ paper. 17 | 18 | Args: 19 | stochastic_depth_prob: float, 20 | For other arguments, refer to the parent class. 21 | 22 | Shape: 23 | - Input: :math:`(N, C_{in}, H_{in}, W_{in})` 24 | - Output: :math:`(N, C_{out}, H_{out}, W_{out})` 25 | """ 26 | 27 | def __init__(self, stochastic_depth_prob: float, *args, **kwargs) -> None: 28 | super().__init__(*args, **kwargs) 29 | self.stochastic_depth = StochasticDepth(p=stochastic_depth_prob, mode="row") 30 | 31 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: 32 | y = self.block(x) 33 | if self.use_res_connect: 34 | # Pass the output through the stochastic layer module, potentially zeroing it. 35 | y = self.stochastic_depth(y) 36 | # residual connection 37 | y = y + x 38 | return y 39 | 40 | def __repr__(self) -> str: 41 | return ( 42 | super().__repr__()[:-1] 43 | + f", stochastic_depth_prob={self.stochastic_depth.p})" 44 | ) 45 | -------------------------------------------------------------------------------- /corenet/modeling/neural_augmentor/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
3 | # 4 | 5 | import argparse 6 | 7 | from corenet.modeling.neural_augmentor.neural_aug import ( 8 | BaseNeuralAugmentor, 9 | build_neural_augmentor, 10 | ) 11 | 12 | 13 | def arguments_neural_augmentor( 14 | parser: argparse.ArgumentParser, 15 | ) -> argparse.ArgumentParser: 16 | return BaseNeuralAugmentor.add_arguments(parser=parser) 17 | -------------------------------------------------------------------------------- /corenet/modeling/neural_augmentor/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/modeling/neural_augmentor/utils/__init__.py -------------------------------------------------------------------------------- /corenet/modeling/text_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.modeling.text_encoders.base_text_encoder import BaseTextEncoder 9 | from corenet.utils import logger 10 | from corenet.utils.registry import Registry 11 | 12 | TEXT_ENCODER_REGISTRY = Registry( 13 | "text_encoder", 14 | base_class=BaseTextEncoder, 15 | lazy_load_dirs=["corenet/modeling/text_encoders"], 16 | internal_dirs=["corenet/internal", "corenet/internal/projects/*"], 17 | ) 18 | 19 | 20 | def arguments_text_encoder(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 21 | """Register arguments of all text encoders.""" 22 | # add arguments for text_encoder 23 | parser = BaseTextEncoder.add_arguments(parser) 24 | 25 | # add class specific arguments 26 | parser = TEXT_ENCODER_REGISTRY.all_arguments(parser) 27 | return parser 28 | 29 | 30 | def build_text_encoder(opts, projection_dim: int, *args, **kwargs) -> BaseTextEncoder: 31 | """Helper function to build the text encoder from command-line arguments. 32 | 33 | Args: 34 | opts: Command-line arguments 35 | projection_dim: The dimensionality of the projection head after text encoder. 36 | 37 | Returns: 38 | Text encoder module. 39 | """ 40 | text_encoder_name = getattr(opts, "model.text.name") 41 | 42 | # We registered the base class using a special `name` (i.e., `__base__`) 43 | # in order to access the arguments defined inside those classes. However, these classes are not supposed to 44 | # be used. Therefore, we raise an error for such cases 45 | if text_encoder_name == "__base__": 46 | logger.error("__base__ can't be used as a projection name. Please check.") 47 | 48 | text_encoder = TEXT_ENCODER_REGISTRY[text_encoder_name]( 49 | opts, projection_dim, *args, **kwargs 50 | ) 51 | return text_encoder 52 | -------------------------------------------------------------------------------- /corenet/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/options/__init__.py -------------------------------------------------------------------------------- /corenet/options/errors.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from corenet.constants import is_test_env 7 | 8 | 9 | class UnrecognizedYamlConfigEntry(Warning): 10 | # TODO: consider converting UnrecognizedYamlConfigEntry Warning to an Exception. 11 | def __init__(self, key: str) -> None: 12 | message = ( 13 | f"Yaml config key '{key}' was not recognized by argparser. If you think that you have already added " 14 | f"argument in corenet/options/opts.py file, then check for typos. If not, then please add it to corenet/options/opts.py." 15 | ) 16 | super().__init__(message) 17 | 18 | if is_test_env(): 19 | # Currently, we only raise an exception in test environment. 20 | raise ValueError(message) 21 | -------------------------------------------------------------------------------- /corenet/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/third_party/__init__.py -------------------------------------------------------------------------------- /corenet/third_party/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/third_party/data/__init__.py -------------------------------------------------------------------------------- /corenet/third_party/data/text_tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/third_party/data/text_tokenizer/__init__.py -------------------------------------------------------------------------------- /corenet/third_party/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/third_party/modeling/__init__.py -------------------------------------------------------------------------------- /corenet/train_eval_pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from corenet.train_eval_pipelines.base import ( 7 | TRAIN_EVAL_PIPELINE_REGISTRY, 8 | BaseTrainEvalPipeline, 9 | ) 10 | from corenet.train_eval_pipelines.default_train_eval import DefaultTrainEvalPipeline 11 | -------------------------------------------------------------------------------- /corenet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/corenet/utils/__init__.py -------------------------------------------------------------------------------- /corenet/utils/activation_checkpointing_wrapper.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | 7 | from functools import partial 8 | from typing import Callable, List, Union 9 | 10 | import torch 11 | 12 | 13 | def activation_checkpointing( 14 | model: torch.nn.Module, 15 | submodule_class: Union[List[Callable], Callable], 16 | ) -> None: 17 | """ 18 | Applies activation checkpointing to `module_to_checkpoint`, a sub-module(s) inside 'model'. 
19 | 20 | Args: 21 | model: The model whose submodules should be wrapped with activation checkpointing. 22 | submodule_class: Submodule class to be wrapped with activation checkpointing. 23 | 24 | Usage:: 25 | model = nn.Sequential( 26 | nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10) 27 | ) 28 | module_to_checkpoint = nn.Linear 29 | # checkpoint activations 30 | activation_checkpointing(model, module_to_checkpoint) 31 | """ 32 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 33 | CheckpointImpl, 34 | apply_activation_checkpointing, 35 | checkpoint_wrapper, 36 | ) 37 | 38 | non_reentrant_wrapper = partial( 39 | checkpoint_wrapper, 40 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 41 | ) 42 | 43 | if isinstance(submodule_class, list): 44 | for m in submodule_class: 45 | check_fn = lambda submodule: isinstance(submodule, m) 46 | apply_activation_checkpointing( 47 | model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn 48 | ) 49 | else: 50 | check_fn = lambda submodule: isinstance(submodule, submodule_class) 51 | apply_activation_checkpointing( 52 | model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn 53 | ) 54 | -------------------------------------------------------------------------------- /corenet/utils/check.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import inspect 7 | import sys 8 | import types 9 | from typing import Any, Union 10 | 11 | 12 | def check( 13 | value: Any, on_failure: Union[str, Exception, types.FunctionType] = "Check failed" 14 | ) -> Any: 15 | """ 16 | Checks if value is truthy and raises an exception if not. 17 | 18 | This is a replacement for assert, with the following advantages: 19 | - Cannot be disabled by the -O flag 20 | - Can raise any exception type 21 | - Returns the checked value for concise code 22 | 23 | on_failure can be: 24 | - A string, in which case a AssertionError is raised with that message. 25 | - A constructed exception to be raised. 26 | - A lambda returning any of the above, so that the message/exception 27 | doesn't need to be constructed if the check succeeds. If the lambda 28 | takes an argument it will be the value. 29 | """ 30 | if value: 31 | return value 32 | 33 | if isinstance(on_failure, types.FunctionType): 34 | nparams = len(inspect.signature(on_failure).parameters) 35 | 36 | if nparams == 0: 37 | on_failure = on_failure() 38 | elif nparams == 1: 39 | on_failure = on_failure(value) 40 | else: 41 | raise ValueError("Expect at most 1 element lambda") 42 | 43 | if not isinstance(on_failure, Exception): 44 | on_failure = AssertionError(str(on_failure)) 45 | 46 | # This used to pop the call stack from the exception traceback, 47 | # so that it would appear to come from the check() call itself, 48 | # but that seems to no longer work in python3.10 49 | check_failed = on_failure 50 | 51 | raise check_failed 52 | -------------------------------------------------------------------------------- /corenet/utils/context_managers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import os 6 | from contextlib import contextmanager 7 | from typing import ContextManager 8 | 9 | 10 | @contextmanager 11 | def context_env_vars(**env: str) -> ContextManager[None]: 12 | """ 13 | Temporarily sets the environment variables within its context. 14 | 15 | Example usage: 16 | ``` 17 | os.environ["X"] = "2" 18 | with context_env_vars(X="3"): 19 | print(os.environ["X"]) # prints 3 20 | print(os.environ["X"]) # prints 2 21 | ``` 22 | """ 23 | original_values = {} 24 | try: 25 | for key, value in env.items(): 26 | original_values[key] = os.environ.get(key, None) 27 | if value is None: 28 | os.environ.pop(key, None) 29 | else: 30 | os.environ[key] = value 31 | yield 32 | finally: 33 | for key, value in original_values.items(): 34 | if value is None: 35 | os.environ.pop(key, None) 36 | else: 37 | os.environ[key] = value 38 | 39 | 40 | @contextmanager 41 | def context_tensor_threads(num_cpu_threads: int) -> ContextManager[None]: 42 | """ 43 | Temporarily instructs numpy and torch to use `num_cpu_threads` cpu threads for processing tensors 44 | and arrays within the context. 45 | """ 46 | num_cpu_threads = str(num_cpu_threads) 47 | with context_env_vars( 48 | MKL_NUM_THREADS=num_cpu_threads, 49 | OMP_NUM_THREADS=num_cpu_threads, 50 | NUMEXPR_NUM_THREADS=num_cpu_threads, 51 | ): 52 | yield 53 | -------------------------------------------------------------------------------- /corenet/utils/dict_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Collection, Dict, Optional 6 | 7 | 8 | def filter_keys( 9 | d: Dict, 10 | whitelist: Optional[Collection[str]] = None, 11 | ) -> Dict: 12 | """Returns a copy of the input dict @d, with a subset of keys that are in 13 | @whitelist. 14 | 15 | Args: 16 | d: Input dictionary that will be copied with a subset of keys. 17 | whitelist: List of keys to keep in the output (if exist in input dict). 18 | """ 19 | 20 | return {key: d[key] for key in whitelist if key in d} 21 | -------------------------------------------------------------------------------- /corenet/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import tempfile 7 | from typing import Optional 8 | 9 | 10 | def make_temp_file( 11 | suffix: Optional[str] = None, 12 | prefix: Optional[str] = "corenet-tmp-", 13 | dir: Optional[str] = None, 14 | ) -> str: 15 | """Create a temporary file and return its path.""" 16 | tmp_file = tempfile.NamedTemporaryFile( 17 | delete=False, 18 | suffix=suffix, 19 | prefix=prefix, 20 | dir=dir, 21 | ) 22 | tmp_file.close() 23 | return tmp_file.name 24 | -------------------------------------------------------------------------------- /corenet/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional, Union 7 | 8 | 9 | def make_divisible( 10 | v: Union[float, int], 11 | divisor: Optional[int] = 8, 12 | min_value: Optional[Union[float, int]] = None, 13 | ) -> Union[float, int]: 14 | """ 15 | This function is taken from the original tf repo.
16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | def bound_fn( 34 | min_val: Union[float, int], max_val: Union[float, int], value: Union[float, int] 35 | ) -> Union[float, int]: 36 | return max(min_val, min(max_val, value)) 37 | -------------------------------------------------------------------------------- /corenet/utils/object_utils_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from corenet.utils.object_utils import apply_recursively, flatten_to_dict 7 | 8 | 9 | def test_apply_on_values(): 10 | d = { 11 | "top1": 1.112311, 12 | "prob_hist": {"max": [0.10003, 0.3, 0.5, 0.09997]}, 13 | "accuracy_per_class": [0.8286, 0.9124], 14 | } 15 | 16 | new_d = apply_recursively(d, lambda x: round(x, 2)) 17 | 18 | assert str(new_d["top1"]) == "1.11" 19 | assert str(new_d["prob_hist"]["max"][0]) == "0.1" 20 | 21 | 22 | def test_flatten_to_dict(): 23 | original = { 24 | "top1": 1.112311, 25 | "prob_hist": {"max": [0.10003, 0.3, 0.5, 0.09997]}, 26 | "accuracy_per_class": [0.8286, 0.9124], 27 | } 28 | flattened = { 29 | "metric/top1": 1.112311, 30 | "metric/prob_hist/max_0": 0.10003, 31 | "metric/prob_hist/max_1": 0.3, 32 | "metric/prob_hist/max_2": 0.5, 33 | "metric/prob_hist/max_3": 0.09997, 34 | "metric/accuracy_per_class_0": 0.8286, 35 | "metric/accuracy_per_class_1": 0.9124, 36 | } 37 | assert flatten_to_dict(original, "metric") == flattened 38 | -------------------------------------------------------------------------------- /corenet/utils/registry_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from corenet.utils.registry import Registry 7 | 8 | 9 | def test_functional_registry() -> None: 10 | reg = Registry("registry_name") 11 | reg.register("awesome_dict")(dict) 12 | 13 | assert "awesome_dict" in reg 14 | assert "awesome_dict(name=hello)" in reg 15 | 16 | obj = reg["awesome_dict(name=hello, type=fifo)"]() 17 | 18 | assert obj == {"name": "hello", "type": "fifo"} 19 | 20 | 21 | def test_basic_registration() -> None: 22 | my_registry = Registry("registry_name") 23 | 24 | @my_registry.register("awesome_class_or_func") 25 | def my_awesome_class_or_func(param): 26 | pass 27 | 28 | assert "awesome_class_or_func" in my_registry 29 | assert "awesome_class_or_func(param=value)" in my_registry 30 | -------------------------------------------------------------------------------- /corenet/utils/resources.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | try: 7 | from corenet.internal.utils.resources import cpu_count 8 | except ImportError: 9 | from multiprocessing import cpu_count 10 | 11 | __all__ = ["cpu_count"] 12 | -------------------------------------------------------------------------------- /mlx_examples/clip/README.md: -------------------------------------------------------------------------------- 1 | # MLX port of CLIP 2 | 3 | This example converts CoreNet's CLIP model implementation to 4 | [MLX](https://github.com/ml-explore/mlx)'s CLIP example, with some customized modifications. MLX is a machine learning framework that provides native Apple Silicon hardware support. 5 | 6 | ## Conversion 7 | 8 | To convert an example CoreNet CLIP model to the MLX CLIP example using the files in this directory: 9 | 10 | ```bash 11 | cd mlx_examples/clip/ 12 | 13 | # Install required dependencies 14 | # We assume that the main requirements.txt is already installed. 15 | pip install -r requirements.txt 16 | 17 | # Convert the model 18 | python main_clip_to_mlx.py \ 19 | --common.config-file "../../projects/range_augment/clip/clip_vit_base.yaml" \ 20 | --model.multi-modal-image-text.pretrained https://docs-assets.developer.apple.com/ml-research/models/cvnets-v2/examples/range_augment/clip/clip_vit_base_16.pt \ 21 | --common.results-loc results/mlx_model/ 22 | 23 | # Try example inference 24 | python clip.py 25 | ``` 26 | 27 | ## Benchmarking results 28 | 29 | We benchmark against PyTorch using the prompts `["a photo of cat", "a photo of dog"]` 30 | and the `assets/{cat,dog}.jpeg` images as inputs. The results on an M2 Ultra are as follows: 31 | 32 | 33 | | Model | PyTorch time, 100 iters (s) | MLX time, 100 iters (s) | Speedup (%) | 34 | | :-----| :----------------------------- | :------------------------- | :---------- | 35 | | FP16 Base variant | 2.7322 | 1.0743 | 60.68% | 36 | | FP16 Huge variant | 4.9098 | 4.3189 | 12.04% | 37 | -------------------------------------------------------------------------------- /mlx_examples/clip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/mlx_examples/clip/__init__.py -------------------------------------------------------------------------------- /mlx_examples/clip/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.40.0 2 | huggingface_hub==0.21.4 3 | -r ../requirements.txt 4 | -------------------------------------------------------------------------------- /mlx_examples/clip/results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/mlx_examples/clip/results/.gitkeep -------------------------------------------------------------------------------- /mlx_examples/open_elm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/mlx_examples/open_elm/__init__.py -------------------------------------------------------------------------------- /mlx_examples/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx==0.10.0 ; sys_platform == 'darwin' 2 | sentencepiece==0.2.0 3 | safetensors==0.4.2 4 | --------------------------------------------------------------------------------
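The benchmark table in `mlx_examples/clip/README.md` above reports wall-clock time for 100 forward passes in PyTorch and MLX; the Speedup column is consistent with `(pytorch_time - mlx_time) / pytorch_time * 100` (for instance, (2.7322 - 1.0743) / 2.7322 ≈ 60.68% for the Base variant). Below is a minimal timing sketch of that style of measurement. It is an illustration only: the `encode` callable, the warm-up count, and the stand-in workloads are assumptions and do not correspond to functions defined in `clip.py` or `main_clip_to_mlx.py`.

```python
import time
from typing import Any, Callable


def time_n_iters(encode: Callable[[], Any], n_iters: int = 100, warmup: int = 5) -> float:
    """Return wall-clock seconds for `n_iters` calls to `encode`.

    `encode` is assumed to run one full forward pass (text + image) and to
    block until the computation has finished; for lazily evaluated frameworks
    (such as MLX arrays), it should force evaluation before returning.
    """
    for _ in range(warmup):
        # Warm-up iterations (excluded from timing) amortize one-time costs
        # such as kernel compilation and cache warm-up.
        encode()
    start = time.perf_counter()
    for _ in range(n_iters):
        encode()
    return time.perf_counter() - start


if __name__ == "__main__":
    # Hypothetical stand-in workloads; replace with the real CLIP forward passes.
    baseline_time = time_n_iters(lambda: sum(i * i for i in range(100_000)))
    candidate_time = time_n_iters(lambda: sum(i * i for i in range(50_000)))
    speedup = (baseline_time - candidate_time) / baseline_time * 100
    print(f"baseline: {baseline_time:.4f}s, candidate: {candidate_time:.4f}s, speedup: {speedup:.2f}%")
```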
/projects/byteformer/model_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/projects/byteformer/model_arch.png -------------------------------------------------------------------------------- /projects/catlip/README-multi-label-object-classification.md: -------------------------------------------------------------------------------- 1 | # Multi-label Object Classification using CatLIP 2 | 3 | Below are instructions for [training](#training-on-coco) a pre-trained CatLIP model on the COCO dataset and [evaluating](#evaluation) its accuracy. 4 | 5 | We also provide [pre-trained model weights](#pretrained-model-weights-on-coco) for different multi-label classification models. 6 | 7 | ## Training on COCO 8 | 9 | To finetune ViT-B, pretrained using CatLIP, on COCO using four A100 GPU, run the following command: 10 | 11 | ```bash 12 | export CFG_FILE=projects/catlip/multi_label_image_classification/vit_base.yaml 13 | corenet-train --common.config-file $CFG_FILE --common.results-loc classification_results 14 | ``` 15 | 16 | We assume that the training and validation data is located at `/mnt/vision_datasets/coco`. 17 | 18 | ## Evaluation 19 | 20 | To evaluate the finetuned `ViT-B` model on the validation set of the COCO, run the following command: 21 | 22 | ```bash 23 | export CFG_FILE=projects/catlip/multi_label_image_classification/vit_base.yaml 24 | export DATASET_PATH="/mnt/vision_datasets/coco" # change to the COCO validation path 25 | export MODEL_WEIGHTS=https://docs-assets.developer.apple.com/ml-research/models/corenet/v0.1.0/catlip/multi-label-classification/coco/vit_base.pt 26 | CUDA_VISIBLE_DEVICES=0 corenet-eval --common.config-file $CFG_FILE --common.override-kwargs dataset.root_val=$DATASET_PATH model.classification.pretrained=$MODEL_WEIGHTS model.resume_exclude_scopes='' 27 | ``` 28 | 29 | This should give 30 | ``` 31 | 'micro': 0.9118, 'macro': 0.8806, 'weighted': 0.8907 32 | ``` 33 | 34 | ## Pretrained Model Weights on COCO 35 | 36 | | Model | Macro mAP | Pretrained weights | 37 | | ---- | ---- | ---- | 38 | | ViT-B/16 | 88.06 | [Link](https://docs-assets.developer.apple.com/ml-research/models/corenet/v0.1.0/catlip/multi-label-classification/coco/vit_base.pt) | 39 | | ViT-L/16 | 90.75 | [Link](https://docs-assets.developer.apple.com/ml-research/models/corenet/v0.1.0/catlip/multi-label-classification/coco/vit_large.pt) | 40 | -------------------------------------------------------------------------------- /projects/kv-prediction/model_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/projects/kv-prediction/model_arch.png -------------------------------------------------------------------------------- /projects/kv-prediction/triviaqa-template.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | task: triviaqa-fixed 3 | dataset_path: trivia_qa 4 | dataset_name: rc.nocontext 5 | output_type: generate_until 6 | training_split: train 7 | validation_split: validation 8 | doc_to_text: "Question: {{question}}\nAnswer:" 9 | doc_to_target: "{{answer.aliases}}" 10 | should_decontaminate: true 11 | doc_to_decontamination_query: question 12 | fewshot_delimiter: "\n" 13 | generation_kwargs: 14 | until: 15 | - "\n" 16 | - "." 
17 | - "," 18 | do_sample: false 19 | temperature: 0.0 20 | filter_list: 21 | - name: remove_whitespace 22 | filter: 23 | - function: remove_whitespace 24 | - function: take_first 25 | target_delimiter: " " 26 | metric_list: 27 | - metric: exact_match 28 | aggregation: mean 29 | higher_is_better: true 30 | ignore_case: true 31 | ignore_punctuation: true 32 | metadata: 33 | version: 3.0 34 | -------------------------------------------------------------------------------- /projects/openelm/instruction_tuning/openelm-instruct.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | # Model arguments 3 | model_name_or_path: OpenELM-500M 4 | torch_dtype: null 5 | use_flash_attention_2: false 6 | 7 | chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 8 | # Data training arguments 9 | # For definitions, see: src/h4/training/config.py 10 | dataset_mixer: 11 | csarron/argilla-ultrafeedback-binarized-preferences-cleaned: 1.0 12 | dataset_splits: 13 | - train 14 | - test 15 | preprocessing_num_workers: 16 16 | 17 | # DPOTrainer arguments 18 | bf16: true 19 | beta: 0.01 20 | do_eval: true 21 | evaluation_strategy: steps 22 | eval_steps: 100 23 | gradient_accumulation_steps: 2 24 | gradient_checkpointing: true 25 | gradient_checkpointing_kwargs: 26 | use_reentrant: False 27 | hub_model_id: OpenELM-500M-dpo 28 | learning_rate: 5.0e-5 29 | log_level: info 30 | logging_steps: 10 31 | lr_scheduler_type: cosine 32 | max_length: 1024 33 | max_prompt_length: 512 34 | num_train_epochs: 3 35 | optim: adamw_torch 36 | output_dir: data/OpenELM-500M-dpo 37 | per_device_train_batch_size: 8 38 | per_device_eval_batch_size: 8 39 | push_to_hub: false 40 | save_strategy: "steps" 41 | save_steps: 100 42 | save_total_limit: 1 43 | seed: 42 44 | warmup_ratio: 0.1 45 | -------------------------------------------------------------------------------- /projects/range_augment/README.md: -------------------------------------------------------------------------------- 1 | # RangeAugment: Efficient Online Augmentation with Range Learning 2 | 3 | [RangeAugment](https://arxiv.org/abs/2212.10553) is an automatic augmentation method that allows us to learn `model- and task-specific` magnitude range of each augmentation operation. 4 | 5 | We provide training and evaluation code along with pretrained models and configuration files for the following tasks: 6 | 7 | 1. [Image Classification on the ImageNet dataset](./README-classification.md) 8 | 2. [Semantic segmentation on the ADE20k and the PASCAL VOC datasets](./README-segmentation.md) 9 | 3. [Object detection on the MS-COCO dataset](./README-object-detection.md) 10 | 4. [Contrastive Learning using Image-Text pairs](./README-clip.md) 11 | 5. [Distillation on the ImageNet dataset](./README-distillation.md) 12 | 13 | ***Note***: In the [codebase](../../corenet/modeling/neural_augmentor), we refer RangeAugment as Neural Augmentor (or NA). 
14 | 15 | 16 | ## Citation 17 | 18 | If you find our work useful, please cite: 19 | 20 | ```BibTex 21 | @article{mehta2022rangeaugment, 22 | title={RangeAugment: Efficient Online Augmentation with Range Learning}, 23 | author = {Mehta, Sachin and Naderiparizi, Saeid and Faghri, Fartash and Horton, Maxwell and Chen, Lailin and Farhadi, Ali and Tuzel, Oncel and Rastegari, Mohammad}, 24 | journal={arXiv preprint arXiv:2212.10553}, 25 | year={2022}, 26 | url={https://arxiv.org/abs/2212.10553}, 27 | } 28 | 29 | @inproceedings{mehta2022cvnets, 30 | author = {Mehta, Sachin and Abdolhosseini, Farzad and Rastegari, Mohammad}, 31 | title = {CVNets: High Performance Library for Computer Vision}, 32 | year = {2022}, 33 | booktitle = {Proceedings of the 30th ACM International Conference on Multimedia}, 34 | series = {MM '22} 35 | } 36 | ``` 37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | skip_gitignore = true 4 | 5 | [tool.black] 6 | extend-exclude = '.history' 7 | 8 | [tool.pytest.ini_options] 9 | junit_family = 'xunit2' 10 | 11 | # Add-opts documentation: 12 | # "-p no:warnings" instructs pytest to avoid modifying warnings.filters, as we have a 13 | # custom implementation for filtering warnings in corenet/__init__.py. In the CI, the 14 | # unexpected warnings are automatically converted to errors by corenet/__init__.py. 15 | # "--junit-xml" generates execution metadata that is used for visualizing test results. 16 | addopts = '-p no:warnings --junit-xml=./build/test-results/junit_reports/junit.xml' 17 | 18 | markers = 'skip_ci: Mark a test to be skipped in CI to avoid known issues like download failure.' 19 | -------------------------------------------------------------------------------- /requirements-optional.txt: -------------------------------------------------------------------------------- 1 | ffmpeg-python==0.2.0 2 | # Installing decord on Mac is tricky. Syntax: https://pip.pypa.io/en/stable/reference/requirement-specifiers/ 3 | decord==0.6.0 ; sys_platform == 'linux' 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | psutil==5.9.8 2 | ujson==5.9.0 3 | scikit-learn==1.4.1.post1 4 | scikit-image==0.22.0 5 | pyyaml==6.0.1 6 | 7 | # requirement for Pytorch, Torchvision, TorchText. 
8 | # This section must be synchronized with the Dockerfile for 9 | # the image used in CI 10 | torch==2.3.0 11 | torchvision==0.18.0 12 | torchtext==0.18.0 13 | torchaudio==2.3.0 14 | torchdata==0.7.1 15 | 16 | # dependency for coremltools 17 | coremltools==7.1 18 | 19 | # dependency for MSCOCO dataset 20 | pycocotools==2.0.7 21 | 22 | # dependency for cityscape evaluation 23 | cityscapesscripts==2.2.2 24 | 25 | # added as a dependency to reproduce 3rd party models 26 | pytorchvideo==0.1.5 27 | 28 | # PyAV for video decoding 29 | av==12.0.0 30 | 31 | # FVCore for FLOP calculation 32 | fvcore==0.1.5.post20221221 33 | 34 | # black for reformatting 35 | black==24.4.0 36 | isort==5.13.2 37 | 38 | # testing 39 | pytest==8.1.1 40 | pytest-mock==3.14.0 41 | pytest-xdist==3.5.0 42 | pytest-timeout==2.3.1 43 | 44 | ftfy==6.2.0 45 | 46 | # for hdf5 reading 47 | h5py==3.10.0 48 | 49 | # for reading byte data 50 | pybase64==1.3.2 51 | 52 | # For OpenAI's clip tokenizer 53 | regex==2023.12.25 54 | pyarrow==15.0.2 55 | 56 | numpy==1.26.4 57 | scipy==1.13.0 58 | pandas==2.2.1 59 | tqdm==4.66.2 60 | setuptools==69.2.0 61 | 62 | boto3==1.28.30 63 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/collate_fns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/collate_fns/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/datasets/audio_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/audio_classification/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/coco.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | taskname: '+ ResNet-50 SSD' 4 | 5 | common: 6 | run_label: "train" 7 | accum_freq: 1 8 | accum_after_epoch: -1 9 | log_freq: 100 10 | auto_resume: true 11 | mixed_precision: true 12 | 13 | dataset: 14 | root_train: "tests/data/coco" 15 | root_val: "tests/data/coco" 16 | category: "classification" 17 | train_batch_size0: 2 18 | val_batch_size0: 2 19 | eval_batch_size0: 1 20 | workers: 0 21 | persistent_workers: false 22 | pin_memory: true 23 | name: "mock_coco" 24 | 25 | image_augmentation: 26 | # training related parameters 27 | random_resized_crop: 28 | enable: true 29 | interpolation: "bilinear" 30 | random_horizontal_flip: 31 | enable: true 32 | # validation related parameters 33 | resize: 34 | enable: true 35 | size: 64 36 | interpolation: "bilinear" 37 | center_crop: 38 | enable: true 39 | size: 64 40 | 41 | sampler: 42 | name: "batch_sampler" 43 | bs: 44 | crop_size_width: 64 45 | crop_size_height: 64 46 | 47 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/image_classification_dataset.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_train: "tests/data/datasets/classification/dummy_images/training" 4 | root_val: "tests/data/datasets/classification/dummy_images/validation" 5 | collate_fn_name_train: "image_classification_data_collate_fn" 6 | collate_fn_name_val: "image_classification_data_collate_fn" 7 | collate_fn_name_test: "image_classification_data_collate_fn" 8 | name: "dummy" 9 | category: "classification" 10 | train_batch_size0: 2 11 | val_batch_size0: 4 12 | eval_batch_size0: 4 13 | workers: 8 14 | persistent_workers: true 15 | pin_memory: true 16 | image_augmentation: 17 | # training related parameters 18 | random_resized_crop: 19 | enable: true 20 | interpolation: "bilinear" 21 | random_horizontal_flip: 22 | enable: true 23 | auto_augment: 24 | enable: true 25 | cutmix: 26 | alpha: 1.0 27 | enable: true 28 | p: 1.0 29 | mixup: 30 | alpha: 0.2 31 | enable: true 32 | p: 1.0 33 | # validation related parameters 34 | resize: 35 | enable: true 36 | size: 232 37 | interpolation: "bilinear" 38 | center_crop: 39 | enable: true 40 | size: 224 41 | 42 | sampler: 43 | name: "batch_sampler" 44 | bs: 45 | crop_size_width: 256 46 | crop_size_height: 256 47 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_train: "/mnt/imagenet/training" 4 | root_val: "/mnt/imagenet/validation" 5 | name: "imagenet" 6 | category: "classification" 7 | train_batch_size0: 2 8 | val_batch_size0: 2 9 | eval_batch_size0: 2 10 | workers: 8 11 | persistent_workers: true 12 | pin_memory: true 13 | 14 | image_augmentation: 15 | # training related parameters 16 | random_resized_crop: 17 | enable: true 18 | interpolation: "bilinear" 19 | random_horizontal_flip: 20 | enable: true 21 | auto_augment: 22 | enable: true 23 | cutmix: 24 | alpha: 1.0 25 | enable: true 26 | p: 1.0 27 | mixup: 28 | alpha: 0.2 29 | enable: true 30 | p: 1.0 31 | # validation related parameters 32 | resize: 
33 | enable: true 34 | size: 232 35 | interpolation: "bilinear" 36 | center_crop: 37 | enable: true 38 | size: 224 39 | 40 | sampler: 41 | name: "variable_batch_sampler" 42 | vbs: 43 | crop_size_width: 224 44 | crop_size_height: 224 45 | max_n_scales: 5 46 | min_crop_size_width: 128 47 | max_crop_size_width: 320 48 | min_crop_size_height: 128 49 | max_crop_size_height: 320 50 | check_scale: 32 51 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_val: "/mnt/vision_datasets/imagenet-a-1.0.0/data/raw/" 4 | name: "imagenet_a" 5 | category: "classification" 6 | train_batch_size0: 2 7 | val_batch_size0: 2 8 | eval_batch_size0: 2 9 | workers: 8 10 | persistent_workers: true 11 | pin_memory: true 12 | 13 | model: 14 | classification: 15 | n_classes: 1000 16 | 17 | image_augmentation: 18 | # training related parameters 19 | random_resized_crop: 20 | enable: true 21 | interpolation: "bilinear" 22 | random_horizontal_flip: 23 | enable: true 24 | auto_augment: 25 | enable: true 26 | cutmix: 27 | alpha: 1.0 28 | enable: true 29 | p: 1.0 30 | mixup: 31 | alpha: 0.2 32 | enable: true 33 | p: 1.0 34 | # validation related parameters 35 | resize: 36 | enable: true 37 | size: 232 38 | interpolation: "bilinear" 39 | center_crop: 40 | enable: true 41 | size: 224 42 | 43 | sampler: 44 | name: "variable_batch_sampler" 45 | vbs: 46 | crop_size_width: 224 47 | crop_size_height: 224 48 | max_n_scales: 5 49 | min_crop_size_width: 128 50 | max_crop_size_width: 320 51 | min_crop_size_height: 128 52 | max_crop_size_height: 320 53 | check_scale: 32 54 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_val: "/mnt/vision_datasets/imagenet-r-1.0.0/data/raw/" 4 | name: "imagenet_r" 5 | category: "classification" 6 | train_batch_size0: 2 7 | val_batch_size0: 2 8 | eval_batch_size0: 2 9 | workers: 8 10 | persistent_workers: true 11 | pin_memory: true 12 | 13 | model: 14 | classification: 15 | n_classes: 1000 16 | 17 | image_augmentation: 18 | # training related parameters 19 | random_resized_crop: 20 | enable: true 21 | interpolation: "bilinear" 22 | random_horizontal_flip: 23 | enable: true 24 | auto_augment: 25 | enable: true 26 | cutmix: 27 | alpha: 1.0 28 | enable: true 29 | p: 1.0 30 | mixup: 31 | alpha: 0.2 32 | enable: true 33 | p: 1.0 34 | # validation related parameters 35 | resize: 36 | enable: true 37 | size: 232 38 | interpolation: "bilinear" 39 | center_crop: 40 | enable: true 41 | size: 224 42 | 43 | sampler: 44 | name: "variable_batch_sampler" 45 | vbs: 46 | crop_size_width: 224 47 | crop_size_height: 224 48 | max_n_scales: 5 49 | min_crop_size_width: 128 50 | max_crop_size_width: 320 51 | min_crop_size_height: 128 52 | max_crop_size_height: 320 53 | check_scale: 32 54 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_val: "/mnt/vision_datasets/imagenet-sketch-1.0.0/data/raw/" 4 | name: "imagenet_sketch" 5 | category: "classification" 6 | 
train_batch_size0: 2 7 | val_batch_size0: 2 8 | eval_batch_size0: 2 9 | workers: 8 10 | persistent_workers: true 11 | pin_memory: true 12 | 13 | model: 14 | classification: 15 | n_classes: 1000 16 | 17 | image_augmentation: 18 | # training related parameters 19 | random_resized_crop: 20 | enable: true 21 | interpolation: "bilinear" 22 | random_horizontal_flip: 23 | enable: true 24 | auto_augment: 25 | enable: true 26 | cutmix: 27 | alpha: 1.0 28 | enable: true 29 | p: 1.0 30 | mixup: 31 | alpha: 0.2 32 | enable: true 33 | p: 1.0 34 | # validation related parameters 35 | resize: 36 | enable: true 37 | size: 232 38 | interpolation: "bilinear" 39 | center_crop: 40 | enable: true 41 | size: 224 42 | 43 | sampler: 44 | name: "variable_batch_sampler" 45 | vbs: 46 | crop_size_width: 224 47 | crop_size_height: 224 48 | max_n_scales: 5 49 | min_crop_size_width: 128 50 | max_crop_size_width: 320 51 | min_crop_size_height: 128 52 | max_crop_size_height: 320 53 | check_scale: 32 54 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_configs/wordnet_tagged_classification.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | taskname: '+ CatLIP ViT-B/16 [DataComp]' 4 | 5 | _anchor_vocab_size: &_anchor_vocab_size 10 6 | 7 | common: 8 | run_label: "train" 9 | log_freq: 500 10 | auto_resume: true 11 | mixed_precision: true 12 | mixed_precision_dtype: "bfloat16" 13 | grad_clip: 1.0 14 | save_all_checkpoints: true 15 | save_interval_freq: 5000 16 | 17 | dataset: 18 | # root_train does not matter for img_text_tar dataset because dataset is information is expected 19 | # to be contained in metadata file. 20 | root_train: "" 21 | disable_val: true 22 | train_batch_size0: 2 23 | workers: 0 24 | persistent_workers: true 25 | pin_memory: true 26 | name: "wordnet_tagged_classification" 27 | category: "classification" 28 | wordnet_tagged_classification: 29 | metadata_file: ".corenet_data_cache/metadata.pkl" 30 | vocab_file: ".corenet_data_cache/vocab.pkl" 31 | vocab_size: *_anchor_vocab_size 32 | 33 | image_augmentation: 34 | # training related augmentations 35 | random_resized_crop: 36 | enable: true 37 | interpolation: "bilinear" 38 | random_horizontal_flip: 39 | enable: true 40 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/training/class1/dummy_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/training/class1/dummy_image1.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/training/class1/dummy_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/training/class1/dummy_image2.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/training/class2/dummy_image1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/training/class2/dummy_image1.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/training/class2/dummy_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/training/class2/dummy_image2.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/validation/class1/dummy_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/validation/class1/dummy_image1.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/validation/class1/dummy_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/validation/class1/dummy_image2.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/validation/class2/dummy_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/validation/class2/dummy_image1.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/dummy_images/validation/class2/dummy_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/classification/dummy_images/validation/class2/dummy_image2.jpg -------------------------------------------------------------------------------- /tests/data/datasets/classification/mock_coco.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from corenet.data.datasets import DATASET_REGISTRY 12 | from corenet.data.datasets.classification.coco import COCOClassification 13 | 14 | 15 | @DATASET_REGISTRY.register(name="mock_coco", type="classification") 16 | class MockCOCOClassification(COCOClassification): 17 | @staticmethod 18 | def read_image_pil(path: str) -> Optional[Image.Image]: 19 | """Mock the init logic for read_image_pil function. 20 | 21 | Instead of reading a PIL image at location specified by `path`, a random PIL 22 | image is returned. 
23 | """ 24 | im_arr = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8) 25 | return Image.fromarray(im_arr).convert("RGB") 26 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/test_mock_coco.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from corenet.data.loader.dataloader import CoreNetDataLoader 7 | from corenet.data.sampler import build_sampler 8 | from tests.configs import get_config 9 | from tests.data.datasets.classification.mock_coco import MockCOCOClassification 10 | 11 | 12 | def test_coco_dataset() -> None: 13 | """Test for COCO classification dataset.""" 14 | config_file_path = "tests/data/datasets/classification/dummy_configs/coco.yaml" 15 | opts = get_config(config_file=config_file_path) 16 | 17 | dataset = MockCOCOClassification(opts) 18 | 19 | train_sampler = build_sampler(opts, n_data_samples=len(dataset), is_training=True) 20 | 21 | train_loader = CoreNetDataLoader( 22 | dataset=dataset, 23 | batch_sampler=train_sampler, 24 | batch_size=1, 25 | num_workers=0, 26 | ) 27 | 28 | for batch in train_loader: 29 | assert batch.keys() == {"samples", "targets", "sample_id"} 30 | # bounds from the config file 31 | assert list(batch["samples"].shape) == [ 32 | 2, 33 | 3, 34 | 64, 35 | 64, 36 | ], "The output shape should be [2, 3, 64, 64]." 37 | assert list(batch["targets"].shape) == [ 38 | 2, 39 | 80, 40 | ], "Batch size should be 2 and number of classes should be 80." 41 | assert ( 42 | batch["sample_id"].dim() == 1 43 | ), "Expecting sample_id's in [batch, ] format" 44 | -------------------------------------------------------------------------------- /tests/data/datasets/classification/test_wordnet_tagged_classification.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import sys 7 | 8 | import pytest 9 | import torch 10 | 11 | from tests.configs import get_config 12 | from tests.data.datasets.classification.mock_wordnet_tagged_classification import ( 13 | MockWordnetTaggedClassificationDataset, 14 | ) 15 | 16 | 17 | @pytest.mark.parametrize("image_size", [16, 32]) 18 | def test_wordnet_tagged_classification_dataset(image_size: int) -> None: 19 | """Test for WordnetTaggedClassificationDataset dataset.""" 20 | 21 | if "nltk" in sys.modules: 22 | config_file = "tests/data/datasets/classification/dummy_configs/wordnet_tagged_classification.yaml" 23 | opts = get_config(config_file=config_file) 24 | 25 | dataset = MockWordnetTaggedClassificationDataset( 26 | opts, is_training=True, is_evaluation=False 27 | ) 28 | 29 | sample_index = 0 30 | data_item = dataset.__getitem__((image_size, image_size, sample_index)) 31 | assert "samples" in data_item 32 | assert "targets" in data_item 33 | assert list(data_item["samples"].shape) == [3, image_size, image_size] 34 | assert list(data_item["targets"].shape) == [10] 35 | 36 | exptected_target_label = torch.tensor([1, 0, 1, 0, 0, 0, 0, 0, 0, 0]) 37 | assert torch.all( 38 | data_item["targets"] 39 | == exptected_target_label.to( 40 | dtype=data_item["targets"].dtype, device=data_item["targets"].device 41 | ) 42 | ) 43 | -------------------------------------------------------------------------------- /tests/data/datasets/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/detection/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/detection/mock_coco_mask_rcnn.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Optional 7 | 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from corenet.data.datasets import DATASET_REGISTRY 12 | from corenet.data.datasets.detection.coco_mask_rcnn import COCODetectionMaskRCNN 13 | 14 | 15 | @DATASET_REGISTRY.register(name="mock_coco_mask_rcnn", type="detection") 16 | class MockCOCODetectionMaskRCNN(COCODetectionMaskRCNN): 17 | @staticmethod 18 | def read_image_pil(path: str) -> Optional[Image.Image]: 19 | """Mock the init logic for read_image_pil function. 20 | 21 | Instead of reading a PIL image at location specified by `path`, a random PIL 22 | image is returned. 23 | """ 24 | im_arr = np.random.randint(low=0, high=255, size=(64, 64, 3), dtype=np.uint8) 25 | return Image.fromarray(im_arr).convert("RGB") 26 | -------------------------------------------------------------------------------- /tests/data/datasets/detection/mock_coco_ssd.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from typing import Optional 7 | 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from corenet.data.datasets import DATASET_REGISTRY 12 | from corenet.data.datasets.detection.coco_ssd import COCODetectionSSD 13 | 14 | 15 | @DATASET_REGISTRY.register(name="mock_coco_ssd", type="detection") 16 | class MockCOCODetectionSSD(COCODetectionSSD): 17 | @staticmethod 18 | def read_image_pil(path: str) -> Optional[Image.Image]: 19 | """Mock the init logic for read_image_pil function. 20 | 21 | Instead of reading a PIL image at location specified by `path`, a random PIL 22 | image is returned. 23 | """ 24 | im_arr = np.random.randint(low=0, high=255, size=(64, 64, 3), dtype=np.uint8) 25 | return Image.fromarray(im_arr).convert("RGB") 26 | -------------------------------------------------------------------------------- /tests/data/datasets/language_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/language_modeling/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/language_modeling/dummy_commonsense_170k.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | root_train: "" 4 | category: language_modeling 5 | name: commonsense_170k 6 | language_modeling: 7 | shuffle_data: false 8 | sequence_length: 5 9 | min_tokens_per_text: 0 10 | min_characters_per_text: 0 11 | commonsense_170k: 12 | path: "" 13 | 14 | text_tokenizer: 15 | name: "openai_clip" 16 | pad_token: "pad" 17 | -------------------------------------------------------------------------------- /tests/data/datasets/language_modeling/dummy_lm_dataset.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | dataset: 4 | root_train: "" 5 | disable_val: true 6 | workers: 4 7 | # dataset details 8 | category: "language_modeling" 9 | name: "mock_general_lm" 10 | language_modeling: 11 | sequence_length: 10 12 | min_tokens_per_text: 0 13 | min_characters_per_text: 0 14 | shuffle_data: true 15 | general_lm: 16 | data_state_save_interval: 0 17 | reader_chunk_size: 1 18 | train_data_info: [ 19 | { 20 | "file_name": ".corenet_data_cache/sample.jsonl", 21 | "text_key": "text", 22 | "file_id_range": [0, 1], 23 | }, 24 | { 25 | "file_name": ".corenet_data_cache/sample.json.gz", 26 | "text_key": "text", 27 | "file_id_range": [0, 1], 28 | }, 29 | ] 30 | 31 | text_tokenizer: 32 | name: "openai_clip" 33 | pad_token: "pad" 34 | -------------------------------------------------------------------------------- /tests/data/datasets/language_modeling/mock_general_lm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import gzip 8 | import json 9 | 10 | from corenet.constants import DATA_CACHE_DIR 11 | from corenet.data.datasets import DATASET_REGISTRY 12 | from corenet.data.datasets.language_modeling.general_lm import GeneralLMDataset 13 | 14 | 15 | def _generate_dummy_json_data() -> None: 16 | data = [ 17 | { 18 | "text": "Hello world, CoreNet serves as a versatile research library catering to a wide array of purposes. 
It has been used for small- and large-scale training, with numerous research papers leveraging its functionalities and contributing to various domains of study.", 19 | } 20 | ] * 12 21 | 22 | with open(f"{DATA_CACHE_DIR}/sample.jsonl", "w") as outfile: 23 | for entry in data: 24 | print(json.dumps(entry), file=outfile) 25 | 26 | 27 | def _generate_dummy_json_gz_data() -> None: 28 | data = [{"text": " !"}] * 2 29 | 30 | with gzip.open(f"{DATA_CACHE_DIR}/sample.json.gz", "w") as outfile: 31 | for text in data: 32 | json_str = json.dumps(text) + "\n" 33 | json_bytes = json_str.encode("utf-8") 34 | outfile.write(json_bytes) 35 | 36 | 37 | @DATASET_REGISTRY.register(name="mock_general_lm", type="language_modeling") 38 | class MockImgGeneralLMDataset(GeneralLMDataset): 39 | """A wrapper around GeneralLMDataset that generates dummy data for CI/CD. 40 | 41 | Args: 42 | opts: Command-line arguments. 43 | """ 44 | 45 | def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None: 46 | _generate_dummy_json_data() 47 | _generate_dummy_json_gz_data() 48 | super().__init__(opts, *args, **kwargs) 49 | -------------------------------------------------------------------------------- /tests/data/datasets/language_modeling/test_commonsense_170k.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import json 7 | import tempfile 8 | 9 | import yaml 10 | 11 | from corenet.data.datasets.language_modeling import commonsense_170k 12 | from corenet.options.utils import flatten_yaml_as_dict 13 | from tests.configs import get_config 14 | from tests.data.datasets.language_modeling import test_general_lm 15 | 16 | 17 | def write_data(filename: str) -> None: 18 | data = [ 19 | { 20 | "instruction": "Please answer the following question with true or false. Question: is the sky blue?", 21 | "input": "", 22 | "output": "the correct answer is true", 23 | "answer": "true", 24 | } 25 | ] * 5 26 | # Make input non-empty for one data point. 27 | data[0]["input"] = "This is an example input." 
28 | with open(filename, "w+") as f: 29 | json.dump(data, f) 30 | 31 | 32 | def test_commonsense_170k_dataset() -> None: 33 | """Test for CommonSense170k dataset.""" 34 | sequence_length = 5 35 | with tempfile.NamedTemporaryFile() as tmp: 36 | write_data(tmp.name) 37 | config_file = ( 38 | "tests/data/datasets/language_modeling/dummy_commonsense_170k.yaml" 39 | ) 40 | opts = get_config(config_file=config_file) 41 | setattr(opts, "dataset.language_modeling.sequence_length", sequence_length) 42 | setattr(opts, "dataset.language_modeling.commonsense_170k.path", tmp.name) 43 | 44 | dataset = commonsense_170k.CommonSense170k(opts) 45 | max_iterations = 12 46 | 47 | test_general_lm._iterate_and_test_dataset( 48 | dataset, 49 | max_iterations=max_iterations, 50 | expected_sequence_length=sequence_length, 51 | ) 52 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/multi_modal_img_text/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/dummy_img_text_tar_dataset.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | dataset: 4 | # The training path in 'img_text_tar' dataset is inferred from the metadata file path. 5 | root_train: "" 6 | disable_val: true 7 | 8 | name: "img_text_tar" 9 | category: "multi_modal_image_text" 10 | multi_modal_img_text: 11 | img_text_tar: 12 | metadata_file: ".corenet_data_cache/metadata.pkl" 13 | 14 | text_tokenizer: 15 | name: "openai_clip" 16 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/test_img_text_tar_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | 7 | import pytest 8 | 9 | from tests.configs import get_config 10 | from tests.data.datasets.multi_modal_img_text.mock_img_text_tar_dataset import ( 11 | MockImgTextTarDataset, 12 | ) 13 | 14 | 15 | @pytest.mark.parametrize("image_size", [16, 32]) 16 | @pytest.mark.parametrize("context_length", [12, 77]) 17 | def test_img_text_dataset(image_size: int, context_length: int) -> None: 18 | """Test for ImgTextTarDataset dataset.""" 19 | 20 | config_file = ( 21 | "tests/data/datasets/multi_modal_img_text/dummy_img_text_tar_dataset.yaml" 22 | ) 23 | opts = get_config(config_file=config_file) 24 | setattr(opts, "dataset.multi_modal_img_text.context_length", context_length) 25 | 26 | dataset = MockImgTextTarDataset(opts, is_training=True, is_evaluation=False) 27 | 28 | sample_index = 0 29 | data_item = dataset.__getitem__((image_size, image_size, sample_index)) 30 | assert "samples" in data_item 31 | assert "targets" in data_item 32 | assert data_item["targets"] == -1 33 | assert "image" in data_item["samples"] 34 | assert list(data_item["samples"]["image"].shape) == [3, image_size, image_size] 35 | assert list(data_item["samples"]["text"].shape) == [context_length] 36 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot_image_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/datasets/multi_modal_img_text/zero_shot_image_classification/__init__.py -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot_image_classification/dummy_configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | category: "multi_modal_image_text" 4 | multi_modal_img_text: 5 | zero_shot_img_cls_dataset_name: "imagenet" 6 | context_length: 77 7 | img_text_tar: 8 | metadata_file: "PATH_OF_METADATA_FILE" 9 | 10 | 11 | image_augmentation: 12 | # training related parameters 13 | random_resized_crop: 14 | enable: true 15 | interpolation: "bilinear" 16 | random_horizontal_flip: 17 | enable: true 18 | auto_augment: 19 | enable: true 20 | cutmix: 21 | alpha: 1.0 22 | enable: true 23 | p: 1.0 24 | mixup: 25 | alpha: 0.2 26 | enable: true 27 | p: 1.0 28 | # validation related parameters 29 | resize: 30 | enable: true 31 | size: 232 32 | interpolation: "bilinear" 33 | center_crop: 34 | enable: true 35 | size: 224 36 | 37 | sampler: 38 | name: "variable_batch_sampler" 39 | vbs: 40 | crop_size_width: 224 41 | crop_size_height: 224 42 | max_n_scales: 5 43 | min_crop_size_width: 128 44 | max_crop_size_width: 320 45 | min_crop_size_height: 128 46 | max_crop_size_height: 320 47 | check_scale: 32 48 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot_image_classification/dummy_configs/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | category: "multi_modal_image_text" 4 | multi_modal_img_text: 5 | zero_shot_img_cls_dataset_name: "imagenet_a" 6 | context_length: 77 7 | img_text_tar: 8 | metadata_file: "PATH_OF_METADATA_FILE" 9 | 10 | 11 | image_augmentation: 12 | # training related parameters 13 | random_resized_crop: 14 | enable: true 15 | interpolation: "bilinear" 16 
| random_horizontal_flip: 17 | enable: true 18 | auto_augment: 19 | enable: true 20 | cutmix: 21 | alpha: 1.0 22 | enable: true 23 | p: 1.0 24 | mixup: 25 | alpha: 0.2 26 | enable: true 27 | p: 1.0 28 | # validation related parameters 29 | resize: 30 | enable: true 31 | size: 232 32 | interpolation: "bilinear" 33 | center_crop: 34 | enable: true 35 | size: 224 36 | 37 | sampler: 38 | name: "variable_batch_sampler" 39 | vbs: 40 | crop_size_width: 224 41 | crop_size_height: 224 42 | max_n_scales: 5 43 | min_crop_size_width: 128 44 | max_crop_size_width: 320 45 | min_crop_size_height: 128 46 | max_crop_size_height: 320 47 | check_scale: 32 48 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot_image_classification/dummy_configs/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | category: "multi_modal_image_text" 4 | multi_modal_img_text: 5 | zero_shot_img_cls_dataset_name: "imagenet_r" 6 | context_length: 77 7 | img_text_tar: 8 | metadata_file: "PATH_OF_METADATA_FILE" 9 | 10 | 11 | image_augmentation: 12 | # training related parameters 13 | random_resized_crop: 14 | enable: true 15 | interpolation: "bilinear" 16 | random_horizontal_flip: 17 | enable: true 18 | auto_augment: 19 | enable: true 20 | cutmix: 21 | alpha: 1.0 22 | enable: true 23 | p: 1.0 24 | mixup: 25 | alpha: 0.2 26 | enable: true 27 | p: 1.0 28 | # validation related parameters 29 | resize: 30 | enable: true 31 | size: 232 32 | interpolation: "bilinear" 33 | center_crop: 34 | enable: true 35 | size: 224 36 | 37 | sampler: 38 | name: "variable_batch_sampler" 39 | vbs: 40 | crop_size_width: 224 41 | crop_size_height: 224 42 | max_n_scales: 5 43 | min_crop_size_width: 128 44 | max_crop_size_width: 320 45 | min_crop_size_height: 128 46 | max_crop_size_height: 320 47 | check_scale: 32 48 | -------------------------------------------------------------------------------- /tests/data/datasets/multi_modal_img_text/zero_shot_image_classification/dummy_configs/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | dataset: 3 | category: "multi_modal_image_text" 4 | multi_modal_img_text: 5 | zero_shot_img_cls_dataset_name: "imagenet_sketch" 6 | context_length: 77 7 | img_text_tar: 8 | metadata_file: "PATH_OF_METADATA_FILE" 9 | 10 | 11 | image_augmentation: 12 | # training related parameters 13 | random_resized_crop: 14 | enable: true 15 | interpolation: "bilinear" 16 | random_horizontal_flip: 17 | enable: true 18 | auto_augment: 19 | enable: true 20 | cutmix: 21 | alpha: 1.0 22 | enable: true 23 | p: 1.0 24 | mixup: 25 | alpha: 0.2 26 | enable: true 27 | p: 1.0 28 | # validation related parameters 29 | resize: 30 | enable: true 31 | size: 232 32 | interpolation: "bilinear" 33 | center_crop: 34 | enable: true 35 | size: 224 36 | 37 | sampler: 38 | name: "variable_batch_sampler" 39 | vbs: 40 | crop_size_width: 224 41 | crop_size_height: 224 42 | max_n_scales: 5 43 | min_crop_size_width: 128 44 | max_crop_size_width: 320 45 | min_crop_size_height: 128 46 | max_crop_size_height: 320 47 | check_scale: 32 48 | -------------------------------------------------------------------------------- /tests/data/datasets/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. 
All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/datasets/segmentation/dummy_ade20k_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | dataset: 4 | root_train: "/mnt/vision_datasets/ADEChallengeData2016/" 5 | root_val: "/mnt/vision_datasets/ADEChallengeData2016/" 6 | name: "ade20k" 7 | category: "segmentation" 8 | train_batch_size0: 4 9 | val_batch_size0: 4 10 | eval_batch_size0: 2 11 | workers: 4 12 | persistent_workers: false 13 | pin_memory: false 14 | image_augmentation: 15 | random_crop: 16 | enable: true 17 | seg_class_max_ratio: 0.75 18 | pad_if_needed: true 19 | mask_fill: 0 # background idx is 0 20 | random_horizontal_flip: 21 | enable: true 22 | resize: 23 | enable: true 24 | size: [64, 64] 25 | interpolation: "bilinear" 26 | random_short_size_resize: 27 | enable: true 28 | interpolation: "bilinear" 29 | short_side_min: 32 30 | short_side_max: 64 31 | max_img_dim: 64 32 | photo_metric_distort: 33 | enable: true 34 | random_rotate: 35 | enable: true 36 | angle: 10 37 | mask_fill: 0 # background idx is 0 38 | random_gaussian_noise: 39 | enable: true 40 | sampler: 41 | name: "batch_sampler" 42 | bs: 43 | crop_size_width: 64 44 | crop_size_height: 64 45 | evaluation: 46 | segmentation: 47 | resize_input_images: false 48 | -------------------------------------------------------------------------------- /tests/data/datasets/segmentation/dummy_cocostuff_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | dataset: 4 | root_train: "/mnt/vision_datasets/cocostuff/" 5 | root_val: "/mnt/vision_datasets/cocostuff/" 6 | name: "coco_stuff" 7 | category: "segmentation" 8 | train_batch_size0: 4 9 | val_batch_size0: 4 10 | eval_batch_size0: 2 11 | workers: 4 12 | persistent_workers: false 13 | pin_memory: false 14 | image_augmentation: 15 | random_crop: 16 | enable: true 17 | seg_class_max_ratio: 0.75 18 | pad_if_needed: true 19 | mask_fill: 255 # Same as the ignore index value in the loss function 20 | random_horizontal_flip: 21 | enable: true 22 | resize: 23 | enable: true 24 | size: [64, 64] 25 | interpolation: "bilinear" 26 | random_short_size_resize: 27 | enable: true 28 | interpolation: "bilinear" 29 | short_side_min: 32 30 | short_side_max: 64 31 | max_img_dim: 64 32 | photo_metric_distort: 33 | enable: true 34 | random_rotate: 35 | enable: true 36 | angle: 10 37 | mask_fill: 255 # Same as the ignore index value in the loss function 38 | random_gaussian_noise: 39 | enable: true 40 | sampler: 41 | name: "batch_sampler" 42 | bs: 43 | crop_size_width: 64 44 | crop_size_height: 64 45 | evaluation: 46 | segmentation: 47 | resize_input_images: false 48 | -------------------------------------------------------------------------------- /tests/data/datasets/test_image_pil.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | 8 | import pytest 9 | import torch 10 | 11 | from corenet.data.transforms import image_pil 12 | 13 | 14 | def test_to_tensor() -> None: 15 | parser = argparse.ArgumentParser() 16 | parser = image_pil.ToTensor.add_arguments(parser) 17 | opts = parser.parse_args([]) 18 | 19 | to_tensor = image_pil.ToTensor(opts=opts) 20 | 21 | H, W, C = 2, 2, 3 22 | num_masks = 2 23 | data = { 24 | "image": torch.rand([H, W, C]), 25 | "mask": torch.randint(0, 1, [num_masks, H, W]), 26 | } 27 | 28 | output = to_tensor(data) 29 | 30 | assert output["image"].shape == (H, W, C) 31 | assert output["mask"].shape == (num_masks, H, W) 32 | 33 | 34 | def test_to_tensor_bad_mask() -> None: 35 | parser = argparse.ArgumentParser() 36 | parser = image_pil.ToTensor.add_arguments(parser) 37 | opts = parser.parse_args([]) 38 | 39 | to_tensor = image_pil.ToTensor(opts=opts) 40 | 41 | H, W, C = 2, 2, 3 42 | num_categories = 2 43 | data = { 44 | "image": torch.rand([H, W, C]), 45 | "mask": torch.randint(0, 1, [num_categories, 1, H, W]), 46 | } 47 | 48 | with pytest.raises(SystemExit): 49 | to_tensor(data) 50 | -------------------------------------------------------------------------------- /tests/data/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/dummy_silent_video.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/dummy_silent_video.mov -------------------------------------------------------------------------------- /tests/data/dummy_video.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/dummy_video.mov -------------------------------------------------------------------------------- /tests/data/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/io/__init__.py -------------------------------------------------------------------------------- /tests/data/io/test_transfer_clients.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | import os 8 | 9 | import pytest 10 | 11 | from corenet.options.opts import get_training_arguments 12 | from corenet.utils.download_utils import get_local_path 13 | 14 | 15 | @pytest.mark.skip_ci 16 | @pytest.mark.parametrize( 17 | "file_path", 18 | [ 19 | # Downloading of below file has been tested on Oct 30, 2023. 20 | # To avoid CI/CD breaking, we skip these tests during CI/CD. 
21 | "https://github.com/apple/ml-cvnets/blob/main/examples/range_augment/classification/mobilenet_v1.yaml", 22 | "http://farm4.staticflickr.com/3217/2975157083_4567dde5d5_z.jpg", 23 | ], 24 | ) 25 | def test_client(file_path: str): 26 | opts = get_training_arguments(args=[]) 27 | 28 | local_path = get_local_path( 29 | opts=opts, 30 | path=file_path, 31 | max_retries=1, 32 | ) 33 | assert os.path.isfile(local_path) 34 | -------------------------------------------------------------------------------- /tests/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/data/samplers/test_batch_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | sampler: 3 | name: "batch_sampler" 4 | num_repeats: 1 5 | truncated_repeat_aug_sampler: false 6 | bs: 7 | crop_size_width: 224 8 | crop_size_height: 224 9 | -------------------------------------------------------------------------------- /tests/data/samplers/test_chain_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | # dummy configuration for testing chain sampler 3 | # We over-ride different options inside the test to study different cases of chain sampler 4 | sampler: 5 | name: "chain_sampler" 6 | chain_sampler_mode: "sequential" 7 | chain_sampler: 8 | - task_name: "task_1" 9 | train_batch_size0: 128 10 | val_batch_size0: 100 11 | sampler_config: 12 | name: "variable_batch_sampler" 13 | vbs: 14 | crop_size_width: 224 15 | crop_size_height: 224 16 | max_n_scales: 25 17 | min_crop_size_width: 128 18 | max_crop_size_width: 320 19 | min_crop_size_height: 128 20 | max_crop_size_height: 320 21 | check_scale: 16 22 | -------------------------------------------------------------------------------- /tests/data/samplers/test_multi_scale_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | sampler: 3 | name: "multi_scale_sampler" 4 | num_repeats: 1 5 | truncated_repeat_aug_sampler: false 6 | msc: 7 | crop_size_width: 224 8 | crop_size_height: 224 9 | max_n_scales: 5 10 | min_crop_size_width: 128 11 | max_crop_size_width: 320 12 | min_crop_size_height: 128 13 | max_crop_size_height: 320 14 | check_scale: 32 15 | -------------------------------------------------------------------------------- /tests/data/samplers/test_variable_batch_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | sampler: 3 | name: "variable_batch_sampler" 4 | vbs: 5 | crop_size_width: 224 6 | crop_size_height: 224 7 | max_n_scales: 5 8 | min_crop_size_width: 128 9 | max_crop_size_width: 320 10 | min_crop_size_height: 128 11 | max_crop_size_height: 320 12 | check_scale: 32 13 | -------------------------------------------------------------------------------- /tests/data/samplers/test_video_clip_batch_sampler_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | sampler: 3 | name: "video_clip_batch_sampler" 4 | vbs: 5 | crop_size_width: 224 6 | crop_size_height: 224 7 | check_scale: 32 8 | vcbs: 9 | num_frames_per_clip: 8 10 | video_fps: 8 11 | audio_fps: 16000 12 | 
max_num_clips_per_batch: 2 13 | num_clips_per_second_train: 2 14 | num_clips_per_second_val: 2 15 | num_samples_per_clip: 2 16 | video_fps_num_scales: 2 17 | -------------------------------------------------------------------------------- /tests/data/text_tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/text_tokenizer/__init__.py -------------------------------------------------------------------------------- /tests/data/text_tokenizer/test_clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from argparse import Namespace 7 | 8 | import torch 9 | 10 | from corenet.data.text_tokenizer.clip_tokenizer import ClipTokenizer 11 | 12 | 13 | def test_clip_tokenizer(): 14 | """Test for 'ClipTokenizer'.""" 15 | opts = Namespace() 16 | 17 | setattr( 18 | opts, 19 | "text_tokenizer.clip.merges_path", 20 | "http://download.pytorch.org/models/text/clip_merges.bpe", 21 | ) 22 | setattr( 23 | opts, 24 | "text_tokenizer.clip.encoder_json_path", 25 | "http://download.pytorch.org/models/text/clip_encoder.json", 26 | ) 27 | 28 | tokenizer = ClipTokenizer(opts=opts) 29 | out = tokenizer("the quick brown fox jumped over the lazy dog") 30 | 31 | expected_data = [ 32 | 49406, # Start token id 33 | 518, 34 | 3712, 35 | 2866, 36 | 3240, 37 | 16901, 38 | 962, 39 | 518, 40 | 10753, 41 | 1929, 42 | 49407, # end token id 43 | ] 44 | expected_out = torch.tensor(expected_data, dtype=out.dtype) 45 | torch.testing.assert_close(actual=out, expected=expected_out) 46 | assert tokenizer.sot_token == "<|startoftext|>" 47 | assert tokenizer.eot_token == "<|endoftext|>" 48 | assert tokenizer.sot_token_id == 49406 49 | assert tokenizer.eot_token_id == 49407 50 | -------------------------------------------------------------------------------- /tests/data/text_tokenizer/test_openai_clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from argparse import Namespace 7 | 8 | import torch 9 | 10 | from corenet.third_party.data.text_tokenizer.openai_clip_tokenizer import ( 11 | OpenAIClipTokenizer, 12 | ) 13 | 14 | 15 | def test_openai_clip_tokenizer(): 16 | """Test for OpenAIClipTokenizer.""" 17 | opts = Namespace() 18 | 19 | setattr( 20 | opts, 21 | "text_tokenizer.openai_clip.bpe_path", 22 | "https://github.com/openai/CLIP/raw/a1d071733d7111c9c014f024669f959182114e33/clip/bpe_simple_vocab_16e6.txt.gz", 23 | ) 24 | tokenizer = OpenAIClipTokenizer(opts) 25 | out = tokenizer("the quick brown fox jumped over the lazy dog") 26 | 27 | expected_data = [ 28 | 49406, # Start token id 29 | 518, 30 | 3712, 31 | 2866, 32 | 3240, 33 | 16901, 34 | 962, 35 | 518, 36 | 10753, 37 | 1929, 38 | 49407, # end token id 39 | ] 40 | expected_out = torch.tensor(expected_data, dtype=out.dtype) 41 | torch.testing.assert_close(actual=out, expected=expected_out) 42 | assert tokenizer.sot_token == "<|startoftext|>" 43 | assert tokenizer.eot_token == "<|endoftext|>" 44 | assert tokenizer.sot_token_id == 49406 45 | assert tokenizer.eot_token_id == 49407 46 | -------------------------------------------------------------------------------- /tests/data/video_reader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/data/video_reader/__init__.py -------------------------------------------------------------------------------- /tests/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/engine/__init__.py -------------------------------------------------------------------------------- /tests/loss_fns/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /tests/loss_fns/language_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/loss_fns/language_modeling/__init__.py -------------------------------------------------------------------------------- /tests/loss_fns/test_class_weighting.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
3 | 4 | import torch 5 | 6 | from corenet.loss_fn.utils.class_weighting import compute_class_weights 7 | 8 | 9 | def test_class_weighting(): 10 | # Test the class weighting method. 11 | targets = torch.tensor([1, 1, 1, 2, 2, 3], dtype=torch.long) 12 | n_classes = 4 13 | norm_val = 1.0 14 | 15 | weights = compute_class_weights( 16 | target=targets, n_classes=n_classes, norm_val=norm_val 17 | ) 18 | weights = torch.round(weights, decimals=2) 19 | 20 | expected_weights = torch.tensor([0.0, 2.47, 3.48, 6.49]) 21 | 22 | torch.testing.assert_allclose(actual=weights, expected=expected_weights) 23 | -------------------------------------------------------------------------------- /tests/loss_fns/test_contrastive_loss.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | import pytest 7 | import torch 8 | 9 | from corenet.loss_fn.multi_modal_img_text.contrastive_loss_clip import ( 10 | ContrastiveLossClip, 11 | ) 12 | 13 | 14 | @pytest.mark.parametrize("batch_size", [1, 2]) 15 | @pytest.mark.parametrize("projection_dim", [256, 512]) 16 | def test_contrastive_loss_in_out(batch_size: int, projection_dim: int) -> None: 17 | # These tests check whether the input and output formats are correct. 18 | parser = argparse.ArgumentParser() 19 | parser = ContrastiveLossClip.add_arguments(parser) 20 | 21 | opts = parser.parse_args([]) 22 | criteria = ContrastiveLossClip(opts) 23 | 24 | image_features = torch.randn(size=(batch_size, projection_dim)) 25 | text_features = torch.randn(size=(batch_size, projection_dim)) 26 | 27 | input_sample = None 28 | targets = None 29 | 30 | prediction = {"image": image_features, "text": text_features} 31 | 32 | loss_output = criteria(input_sample, prediction, targets) 33 | expected_output_keys = {"total_loss", "image_loss", "text_loss", "logit_scale"} 34 | assert expected_output_keys.issubset(loss_output.keys()) 35 | 36 | for loss_name, loss_val in loss_output.items(): 37 | if loss_name == "logit_scale" and isinstance(loss_val, (float, int)): 38 | loss_val = torch.tensor(loss_val) 39 | assert isinstance( 40 | loss_val, torch.Tensor 41 | ), "Loss should be an instance of torch.Tensor" 42 | assert loss_val.dim() == 0, "Loss value should be a scalar" 43 | -------------------------------------------------------------------------------- /tests/loss_fns/test_neural_aug.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | 6 | import pytest 7 | import torch 8 | 9 | from corenet.loss_fn.neural_augmentation import NeuralAugmentation 10 | 11 | 12 | @pytest.mark.parametrize("batch_size", [1, 2]) 13 | def test_neural_aug_loss_in_out(batch_size: int) -> None: 14 | # These tests check whether the input and output formats are correct.
15 | # build configuration 16 | parser = argparse.ArgumentParser() 17 | parser = NeuralAugmentation.add_arguments(parser) 18 | 19 | opts = parser.parse_args([]) 20 | setattr(opts, "scheduler.max_epochs", 20) 21 | 22 | # build loss function 23 | neural_aug_loss_fn = NeuralAugmentation(opts) 24 | pred_tensor = { 25 | "augmented_tensor": torch.zeros( 26 | size=(batch_size, 3, 224, 224), dtype=torch.float 27 | ) 28 | } 29 | 30 | # Three input cases: 31 | # Case 1: Input image is a tensor 32 | # Case 2: Input is a dictionary, with image as a mandatory key and value as a batch of input image tensor 33 | # Case 3: Input is a dictionary, with image as a mandatory key and value as a list of input image tensor 34 | input_case_1 = torch.randint(low=0, high=1, size=(batch_size, 3, 224, 224)) 35 | input_case_2 = { 36 | "image": torch.randint(low=0, high=1, size=(batch_size, 3, 224, 224)) 37 | } 38 | input_case_3 = { 39 | "image": [torch.randint(low=0, high=1, size=(1, 3, 224, 224))] * batch_size 40 | } 41 | 42 | for inp in [input_case_1, input_case_2, input_case_3]: 43 | loss = neural_aug_loss_fn(inp, pred_tensor) 44 | assert isinstance( 45 | loss, torch.Tensor 46 | ), "Loss should be an instance of torch.Tensor" 47 | assert loss.dim() == 0, "Loss value should be a scalar" 48 | -------------------------------------------------------------------------------- /tests/loss_fns/test_neural_aug_compatibility.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import re 5 | import sys 6 | from pathlib import Path 7 | 8 | sys.path.append("..") 9 | 10 | from tests.configs import get_config 11 | from tests.modeling.test_model import exclude_yaml_from_test 12 | 13 | 14 | def test_neural_aug_backward_compatibility(config_file: str): 15 | opts = get_config(config_file=config_file) 16 | 17 | opts_dict = vars(opts) 18 | for k, v in opts_dict.items(): 19 | if isinstance(v, str) and re.search(".*_with_na$", v): 20 | raise DeprecationWarning( 21 | "We deprecated the usage of _with_na loss functions. " 22 | "Please see projects/range_augment for examples." 23 | ) 24 | 25 | 26 | def pytest_generate_tests(metafunc): 27 | configs = [ 28 | str(x) 29 | for x in Path("config").rglob("**/*.yaml") 30 | if not exclude_yaml_from_test(x) 31 | ] 32 | configs += [ 33 | str(x) 34 | for x in Path("projects").rglob("**/*.yaml") 35 | if not exclude_yaml_from_test(x) 36 | ] 37 | metafunc.parametrize("config_file", configs) 38 | -------------------------------------------------------------------------------- /tests/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/metrics/test_image_text_retrieval_metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import pytest 7 | import torch 8 | 9 | from corenet.metrics.stats import Statistics 10 | from tests.configs import default_training_opts 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "batch_size, num_captions, hidden_dim, text_dim", 15 | [ 16 | (1, 1, 8, 2), 17 | (2, 5, 4, 3), 18 | ], 19 | ) 20 | def test_image_text_retrieval( 21 | batch_size: int, num_captions: int, hidden_dim: int, text_dim: int 22 | ) -> None: 23 | stats = Statistics( 24 | opts=default_training_opts(), metric_names=["image_text_retrieval"] 25 | ) 26 | for _ in range(3): 27 | image_emb = torch.randn(batch_size, hidden_dim) 28 | text_emb = torch.randn(batch_size, num_captions, hidden_dim) 29 | if text_dim == 2: 30 | text_emb = text_emb.reshape(-1, hidden_dim) 31 | stats.update({"image": image_emb, "text": text_emb}, {}, {}) 32 | 33 | metrics = stats._compute_avg_statistics_all() 34 | img_text_metrics = metrics["image_text_retrieval"] 35 | 36 | parent_keys = ["text2image", "image2text"] 37 | child_keys = ["recall@1", "recall@5", "recall@10", "mean_rank", "median_rank"] 38 | for parent_key in parent_keys: 39 | assert parent_key in img_text_metrics 40 | for child_key in child_keys: 41 | assert child_key in img_text_metrics[parent_key] 42 | -------------------------------------------------------------------------------- /tests/metrics/test_iou.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Callable 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from corenet.metrics.stats import Statistics 12 | from tests.metrics.base import transform_args 13 | 14 | 15 | def test_gather_iou_metrics(transform_args: Callable): 16 | # [Batch, num_classes, height, width] 17 | # in this example, [1, 2, 2, 3] 18 | prediction = torch.tensor( 19 | [ 20 | [ 21 | [[0.2, 0.8, 0.2], [0.9, 0.2, 0.1]], 22 | [[0.8, 0.2, 0.8], [0.1, 0.8, 0.9]], # spatial dms 23 | ] # classes 24 | ] # batch 25 | ) 26 | 27 | target = torch.tensor([[[0, 0, 0], [0, 1, 1]]]) 28 | 29 | metric_names, stats_args = transform_args(["iou"], prediction, target) 30 | 31 | expected_inter = np.array([2.0, 2.0]) 32 | expected_union = np.array([4.0, 4.0]) 33 | 34 | expected_iou = np.mean(expected_inter / expected_union) * 100 35 | 36 | stats = Statistics(opts=None, metric_names=metric_names) 37 | stats.update(*stats_args) 38 | 39 | np.testing.assert_equal( 40 | actual=stats.avg_statistics(metric_names[0]), desired=expected_iou 41 | ) 42 | -------------------------------------------------------------------------------- /tests/metrics/test_probability_histogram.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from typing import Callable 7 | 8 | import numpy as np 9 | 10 | from corenet.metrics.stats import Statistics 11 | from tests.configs import default_training_opts 12 | from tests.metrics.base import sample_classification_outputs, transform_args 13 | 14 | 15 | def test_probability_histogram(transform_args: Callable): 16 | metric_names, stats_args = transform_args( 17 | ["prob_hist"], *sample_classification_outputs() 18 | ) 19 | 20 | stats = Statistics(opts=default_training_opts(), metric_names=metric_names) 21 | stats.update(*stats_args) 22 | 23 | # max values -> 0.91, 0.81, 0.51 24 | max_conf_hist = stats.avg_statistics(metric_names[0], "max") 25 | np.testing.assert_almost_equal( 26 | max_conf_hist, 27 | [0, 0, 0, 0, 0, 0.33, 0, 0, 0.33, 0.33], 28 | decimal=2, 29 | ) 30 | 31 | # target values -> 0.05, 0.16, 0.51 32 | target_conf_hist = stats.avg_statistics(metric_names[0], "target") 33 | np.testing.assert_almost_equal( 34 | target_conf_hist, 35 | [0.33, 0.33, 0, 0, 0, 0.33, 0, 0, 0, 0], 36 | decimal=2, 37 | ) 38 | -------------------------------------------------------------------------------- /tests/metrics/test_psnr.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import math 7 | from typing import Callable 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from corenet.metrics.stats import Statistics 13 | from tests.metrics.base import transform_args 14 | 15 | 16 | def test_gather_psnr_metrics(transform_args: Callable): 17 | # Test for case 1 18 | inp_tensor = torch.randn((3, 2), dtype=torch.float) 19 | target_tensor = inp_tensor 20 | 21 | # Ideally, the PSNR should be infinite when input and target are the same, because error between 22 | # signal and noise is 0. However, we add a small eps value (error of 1e-10) in the computation 23 | # for numerical stability. Therefore, PSNR will not be infinite. 24 | expected_psnr = 10.0 * math.log10(255.0**2 / 1e-10) 25 | 26 | metric_names, stats_args = transform_args(["psnr"], inp_tensor, target_tensor) 27 | 28 | stats = Statistics(opts=None, metric_names=metric_names) 29 | stats.update(*stats_args) 30 | 31 | np.testing.assert_almost_equal( 32 | stats.avg_statistics(metric_names[0]), expected_psnr, decimal=2 33 | ) 34 | -------------------------------------------------------------------------------- /tests/metrics/test_topk_accuracy.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from typing import Callable 7 | 8 | import numpy as np 9 | 10 | from corenet.metrics.stats import Statistics 11 | from tests.metrics.base import sample_classification_outputs, transform_args 12 | 13 | 14 | def test_gather_top_k_metrics(transform_args: Callable): 15 | metric_names, stats_args = transform_args( 16 | ["top1", "top5"], *sample_classification_outputs() 17 | ) 18 | 19 | stats = Statistics(opts=None, metric_names=metric_names) 20 | stats.update(*stats_args) 21 | top1_acc = round(stats.avg_statistics(metric_names[0]), 2) 22 | top5_acc = round(stats.avg_statistics(metric_names[1]), 2) 23 | 24 | np.testing.assert_almost_equal(top1_acc, 33.33, decimal=2) 25 | np.testing.assert_almost_equal(top5_acc, 100.00, decimal=2) 26 | -------------------------------------------------------------------------------- /tests/metrics/test_vqa_preset_score_metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | """Test for metrics/vqa_score.py.""" 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from corenet.metrics.stats import Statistics 11 | 12 | 13 | def test_vqa_preset_score() -> None: 14 | predictions = { 15 | "logits": torch.tensor( 16 | [ 17 | [0, 0, 1], 18 | [0, 0, 1], 19 | [0, 0, 1], 20 | ], 21 | dtype=torch.float, 22 | ) 23 | } 24 | targets = torch.tensor( 25 | [ 26 | [0, 0, 1], 27 | [0, 1, 0], 28 | [0, 0.5, 0.5], 29 | ], 30 | dtype=torch.float, 31 | ) 32 | 33 | stats = Statistics(opts=None, metric_names=["vqa_preset_score"]) 34 | stats.update(predictions, targets) 35 | score = round(stats.avg_statistics("vqa_preset_score", "bbox"), 2) 36 | 37 | np.testing.assert_almost_equal(score, 50.0) 38 | -------------------------------------------------------------------------------- /tests/misc/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | -------------------------------------------------------------------------------- /tests/misc/dummy_clip_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | common: 4 | debug_mode: true 5 | 6 | dataset: 7 | name: "img_text_tar" 8 | category: "multi_modal_image_text" 9 | 10 | model: 11 | activation_checkpointing: true 12 | freeze_modules: [ 13 | "image_encoder.transformer", 14 | "image_encoder.cls_token", 15 | "image_encoder.patch_emb", 16 | "image_encoder.post_transformer_norm", 17 | "image_encoder.pos_embed", 18 | ] 19 | multi_modal_image_text: 20 | name: "clip" 21 | clip: 22 | projection_dim: 128 23 | classification: 24 | name: "vit" 25 | vit: 26 | mode: "tiny" 27 | norm_layer: "layer_norm_fp32" 28 | image_projection_head: 29 | name: "simple_projection_nc2nc" 30 | text: 31 | name: "transformer" 32 | vocab_size: 200 33 | context_length: 77 34 | transformer: 35 | causal_masking: true 36 | model_dim: 128 37 | n_transformer_layers: 1 38 | ffn_multiplier_per_layer: 4.0 39 | n_heads_per_layer: 8 40 | norm_layer: "layer_norm_fp32" 41 | activation: 42 | name: "gelu" 43 | -------------------------------------------------------------------------------- /tests/misc/dummy_linear_probe_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | common: 4 | debug_mode: true 5 | 6 | dataset: 7 | name: "imagenet" 8 | category: "classification" 9 | 10 | model: 11 | resume_exclude_scopes: 12 | - "image_encoder.neural_augmentor." 13 | - "neural_augmentor." 14 | - "text_encoder." 15 | - "logit_scale" 16 | - "image_encoder.classifier.proj" 17 | ignore_missing_scopes: ["classifier."] 18 | learn_augmentation: 19 | mode: None 20 | rename_scopes_map: [["image_encoder.", ""]] 21 | freeze_modules: "^((?!classifier).)*$" 22 | classification: 23 | name: "vit" 24 | vit: 25 | mode: "base" 26 | dropout: 0.2 27 | activation: 28 | name: "gelu" 29 | activation: 30 | name: "gelu" 31 | normalization: 32 | name: "batch_norm" 33 | momentum: 0.1 34 | layer: 35 | conv_init: "kaiming_normal" 36 | linear_init: "trunc_normal" 37 | linear_init_std_dev: 0.02 38 | -------------------------------------------------------------------------------- /tests/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/__init__.py -------------------------------------------------------------------------------- /tests/modeling/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/layers/__init__.py -------------------------------------------------------------------------------- /tests/modeling/layers/normalization_layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/layers/normalization_layers/__init__.py -------------------------------------------------------------------------------- /tests/modeling/layers/normalization_layers/test_rms_norm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import torch 7 | 8 | from corenet.modeling.layers.normalization.rms_norm import RMSNorm 9 | 10 | 11 | def test_rms_norm() -> None: 12 | in_features = 16 13 | norm_layer = RMSNorm(num_features=in_features) 14 | 15 | inputs = [ 16 | # 3D inputs (e.g., Transformers) 17 | torch.randn(size=(2, 4, in_features)), 18 | # 4D inputs (e.g., CNNs) 19 | torch.randn(size=(2, 4, 5, in_features)), 20 | # 2D inputs (e.g., Linear) 21 | torch.randn(size=(2, in_features)), 22 | ] 23 | for inp in inputs: 24 | out = norm_layer(inp) 25 | assert out.shape == inp.shape 26 | # check if there are any NaNs in the output. 27 | assert not torch.any(torch.isnan(out)) 28 | -------------------------------------------------------------------------------- /tests/modeling/layers/test_pos_embeddings.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import sys 7 | 8 | import numpy as np 9 | import pytest 10 | 11 | sys.path.append("..") 12 | 13 | from corenet.modeling.layers.positional_embedding import PositionalEmbedding 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "is_learnable, input_seq_len, sequence_first, padding_idx", 18 | [ 19 | (True, 34, True, None), 20 | (False, 128, False, 0), 21 | (False, 192, True, 0), 22 | ], 23 | ) 24 | def test_pos_embedding( 25 | is_learnable: bool, input_seq_len: int, sequence_first: bool, padding_idx: int 26 | ): 27 | num_embeddings = 128 28 | pos_embedding = PositionalEmbedding( 29 | opts=None, 30 | num_embeddings=num_embeddings, 31 | embedding_dim=512, 32 | padding_idx=padding_idx, 33 | is_learnable=is_learnable, 34 | sequence_first=sequence_first, 35 | ) 36 | seq_dim = 0 if sequence_first else 1 37 | 38 | out = pos_embedding(input_seq_len) 39 | np.testing.assert_equal(out.shape[seq_dim], input_seq_len) 40 | -------------------------------------------------------------------------------- /tests/modeling/layers/test_rotary_embeddings.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import pytest 7 | import torch 8 | 9 | from corenet.modeling.layers import RotaryEmbedding 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "model_dim,n_queries,n_keys,n_groups", 14 | [ 15 | (18, 5, 5, 1), 16 | (18, 5, 6, 4), 17 | ], 18 | ) 19 | def test_rotary_embedding( 20 | model_dim: int, n_queries: int, n_keys: int, n_groups: int 21 | ) -> None: 22 | """Test for RoPE embeddings.""" 23 | rope_embedding = RotaryEmbedding( 24 | model_dim=model_dim, 25 | # setting max_seq_length to the same as number of queries. 26 | # When n_keys > n_queries, then cos and sine embeddings are re-computed. 
27 | max_seq_length=n_queries, 28 | ) 29 | 30 | batch_size = 2 31 | n_query_heads = 16 32 | # When n_groups != 1, RoPE with GQA is tested 33 | n_key_heads = n_query_heads // n_groups 34 | 35 | query_tensor = torch.randn( 36 | size=(batch_size, n_query_heads, n_queries, model_dim), 37 | dtype=torch.bfloat16, 38 | device=torch.device("cpu"), 39 | ) 40 | key_tensor = torch.randn( 41 | size=(batch_size, n_key_heads, n_keys, model_dim), 42 | dtype=torch.bfloat16, 43 | device=torch.device("cpu"), 44 | ) 45 | 46 | query_tensor_with_rope, key_tensor_with_rope = rope_embedding( 47 | query_tensor, key_tensor 48 | ) 49 | assert rope_embedding._cached_seq_length == n_keys 50 | assert query_tensor.shape == query_tensor_with_rope.shape 51 | assert key_tensor.shape == key_tensor_with_rope.shape 52 | 53 | assert torch.isnan(query_tensor_with_rope).to(torch.bool).sum() == 0 54 | assert torch.isnan(key_tensor_with_rope).to(torch.bool).sum() == 0 55 | -------------------------------------------------------------------------------- /tests/modeling/models/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /tests/modeling/models/audio_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/models/audio_classification/__init__.py -------------------------------------------------------------------------------- /tests/modeling/models/audio_classification/test_base_audio_classification.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | from corenet.modeling.models.audio_classification import base_audio_classification 9 | 10 | 11 | def test_base_audio_classification_adds_arguments() -> None: 12 | opts = argparse.Namespace() 13 | model = base_audio_classification.BaseAudioClassification(opts) 14 | 15 | parser = argparse.ArgumentParser() 16 | model.add_arguments(parser) 17 | assert hasattr(parser.parse_args([]), "model.audio_classification.name") 18 | -------------------------------------------------------------------------------- /tests/modeling/models/audio_classification/test_byteformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | 8 | from corenet.modeling.models.audio_classification import audio_byteformer 9 | from corenet.modeling.models.classification import byteformer as image_byteformer 10 | from tests.modeling.models.classification import test_byteformer 11 | 12 | 13 | def test_audio_byteformer() -> None: 14 | # Make sure it matches the image classification network. 15 | opts = test_byteformer.get_opts() 16 | 17 | byteformer1 = image_byteformer.ByteFormer(opts) 18 | byteformer2 = audio_byteformer.AudioByteFormer(opts) 19 | 20 | # Make their state_dicts match. 
21 | byteformer2.load_state_dict(byteformer1.state_dict()) 22 | 23 | batch_size, sequence_length = 2, 32 24 | 25 | x = torch.randint(0, 128, [batch_size, sequence_length]) 26 | 27 | assert torch.all(byteformer1(x) == byteformer2({"audio": x})) 28 | -------------------------------------------------------------------------------- /tests/modeling/models/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/models/classification/__init__.py -------------------------------------------------------------------------------- /tests/modeling/models/classification/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/models/classification/config/__init__.py -------------------------------------------------------------------------------- /tests/modeling/models/classification/config/test_byteformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | import pytest 9 | 10 | from corenet.modeling.models.classification.config import byteformer 11 | 12 | 13 | @pytest.mark.parametrize("mode", ["tiny", "small", "base", "huge"]) 14 | def test_get_configuration(mode) -> None: 15 | opts = argparse.Namespace() 16 | setattr(opts, "model.classification.byteformer.mode", mode) 17 | setattr(opts, "model.classification.byteformer.dropout", 0.0) 18 | setattr(opts, "model.classification.byteformer.norm_layer", "layer_norm") 19 | byteformer.get_configuration(opts) 20 | -------------------------------------------------------------------------------- /tests/modeling/models/classification/config/vit_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | dataset: 4 | name: "imagenet" 5 | category: "classification" 6 | 7 | model: 8 | classification: 9 | name: "vit" 10 | vit: 11 | mode: "test" 12 | norm_layer: "layer_norm_fp32" 13 | activation: 14 | name: "gelu" 15 | -------------------------------------------------------------------------------- /tests/modeling/models/language_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/models/language_modeling/__init__.py -------------------------------------------------------------------------------- /tests/modeling/models/language_modeling/config/gpt_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | dataset: 4 | name: "general_lm" 5 | category: "language_modeling" 6 | 7 | model: 8 | language_modeling: 9 | name: "general_gpt" 10 | general_gpt: 11 | model_name: "gpt-test" 12 | vocab_size: 100 13 | max_context_length: 100 14 | -------------------------------------------------------------------------------- /tests/modeling/models/language_modeling/config/kv_prediction_config.yaml: -------------------------------------------------------------------------------- 1 | # pytest: disable 2 | 3 | dataset: 4 | name: "general_lm" 5 | category: "language_modeling" 6 | 7 | model: 8 | language_modeling: 9 | name: 
"kv_prediction" 10 | kv_prediction: 11 | auxkv_num_layers_to_basekv_num_layers: [0, 0, 1] 12 | base_model: 13 | - model: 14 | language_modeling: 15 | name: "layer_pruned_general_gpt" 16 | general_gpt: 17 | model_name: "gpt-test-base" 18 | auxiliary_model: 19 | - model: 20 | language_modeling: 21 | name: "layer_pruned_general_gpt" 22 | general_gpt: 23 | model_name: "gpt-test-aux" 24 | general_gpt: 25 | vocab_size: 100 26 | max_context_length: 100 27 | -------------------------------------------------------------------------------- /tests/modeling/models/test_activation_checkpointing_wrapper.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | 7 | from typing import List, Tuple, Union 8 | 9 | import pytest 10 | import torch 11 | from torch import nn 12 | 13 | from corenet.utils.activation_checkpointing_wrapper import activation_checkpointing 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "activation_checkpointing_module_and_count", 18 | [ 19 | # _checkpoint_wrapped_module is added for each trainable parameter (e.g., weight and bias) in a layer. 20 | (nn.Linear, 3), 21 | (nn.Conv1d, 2), 22 | ([nn.Linear, nn.Conv1d], 5), 23 | ], 24 | ) 25 | def test_activation_checkpointing( 26 | activation_checkpointing_module_and_count: Tuple[ 27 | Union[torch.nn.Module, List[torch.nn.Module]], int 28 | ] 29 | ): 30 | 31 | ( 32 | activation_checkpoint_module, 33 | expected_activation_checkpoinitng_layers, 34 | ) = activation_checkpointing_module_and_count 35 | # dummy model 36 | model = torch.nn.Sequential( 37 | nn.Linear(10, 10, bias=False), 38 | nn.Conv1d(10, 10, kernel_size=1), 39 | nn.Linear(10, 10), 40 | nn.AvgPool1d(kernel_size=1), 41 | ) 42 | 43 | activation_checkpointing(model, submodule_class=activation_checkpoint_module) 44 | 45 | num_ckpt_modules = 0 46 | for p_name, _ in model.named_parameters(): 47 | if p_name.find("_checkpoint_wrapped_module") > -1: 48 | num_ckpt_modules += 1 49 | 50 | assert num_ckpt_modules == expected_activation_checkpoinitng_layers 51 | -------------------------------------------------------------------------------- /tests/modeling/models/test_neural_aug_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import sys 7 | 8 | import pytest 9 | 10 | from corenet.modeling.neural_augmentor.utils.neural_aug_utils import * 11 | 12 | 13 | @pytest.mark.parametrize("noise_var", [0.0001, 0.01, 0.1]) 14 | def test_random_noise(noise_var): 15 | in_channels = 3 16 | in_height = 224 17 | in_width = 224 18 | x = torch.ones(size=(1, in_channels, in_width, in_height), dtype=torch.float) 19 | 20 | aug_out = random_noise(x, variance=torch.tensor(noise_var, dtype=torch.float)) 21 | 22 | torch.testing.assert_allclose(actual=x.shape, expected=aug_out.shape) 23 | 24 | 25 | @pytest.mark.parametrize("magnitude", [0.1, 1.0, 2.0]) 26 | def test_random_brightness(magnitude): 27 | in_channels = 3 28 | in_height = 224 29 | in_width = 224 30 | x = torch.ones(size=(1, in_channels, in_width, in_height), dtype=torch.float) 31 | 32 | aug_out = random_brightness(x, magnitude=torch.tensor(magnitude, dtype=torch.float)) 33 | 34 | torch.testing.assert_allclose(actual=x.shape, expected=aug_out.shape) 35 | 36 | 37 | @pytest.mark.parametrize("magnitude", [0.1, 1.0, 2.0]) 38 | def test_random_contrast(magnitude): 39 | in_channels = 3 40 | in_height = 224 41 | in_width = 224 42 | x = torch.ones(size=(1, in_channels, in_width, in_height), dtype=torch.float) 43 | 44 | aug_out = random_contrast(x, magnitude=torch.tensor(magnitude, dtype=torch.float)) 45 | 46 | torch.testing.assert_allclose(actual=x.shape, expected=aug_out.shape) 47 | -------------------------------------------------------------------------------- /tests/modeling/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/modeling/modules/__init__.py -------------------------------------------------------------------------------- /tests/modeling/modules/test_transformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import argparse 7 | 8 | import torch 9 | 10 | from corenet.modeling.modules import TransformerEncoder 11 | 12 | 13 | def get_opts() -> argparse.Namespace: 14 | opts = argparse.Namespace() 15 | setattr(opts, "model.normalization.groups", None) 16 | setattr(opts, "model.normalization.momentum", 0.9) 17 | setattr(opts, "model.activation.name", "relu") 18 | setattr(opts, "model.activation.inplace", False) 19 | setattr(opts, "model.activation.neg_slope", False) 20 | return opts 21 | 22 | 23 | def ensure_equal_in_range(t: torch.Tensor, start: int, end: int) -> None: 24 | """ 25 | Ensure values of @t are equal from @start to @end, but not after @end. 26 | 27 | The tensor can have any number of dimensions greater than 0. The first 28 | dimension is the dimension indexed by @start and @end. 29 | 30 | Args: 31 | t: The tensor to check. 32 | start: The start index. 33 | end: The end index. 34 | """ 35 | prototype = t[start] 36 | assert torch.all((prototype - t[start:end]).abs() < 1e-3) 37 | assert torch.all(prototype != t[end:]) 38 | 39 | 40 | def test_masked_attention() -> None: 41 | opts = get_opts() 42 | 43 | B, N, C = 2, 64 + 2, 8 44 | t = TransformerEncoder(opts, embed_dim=C, ffn_latent_dim=4 * C) 45 | prototype = torch.randn([C]) 46 | x = torch.ones([B, N, C]) 47 | x[:, :] = prototype 48 | 49 | key_padding_mask = torch.zeros([B, N]) 50 | key_padding_mask[0, 63:] = float("-inf") 51 | # Mask the @x values at the masked positions. 
52 | x[0, 63:] = 0 53 | 54 | y = t(x, key_padding_mask=key_padding_mask) 55 | 56 | prototype = y[0, 0] 57 | assert torch.all(prototype == y[0, :63]) 58 | assert torch.all(prototype != y[0, 63:]) 59 | 60 | prototype = y[1, 0] 61 | assert torch.all(prototype == y[1, :]) 62 | -------------------------------------------------------------------------------- /tests/optims/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/optims/__init__.py -------------------------------------------------------------------------------- /tests/optims/scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/optims/scheduler/__init__.py -------------------------------------------------------------------------------- /tests/options/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/options/test_parse_args.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | from typing import Any, Dict, List, Tuple, Union 7 | 8 | import pytest 9 | 10 | from corenet.options.parse_args import JsonValidator 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "expected_type,valid,valid_parsed,invalid", 15 | [ 16 | (None, "null", None, "1"), 17 | (int, "1", 1, "1.0"), 18 | (float, "1.0", 1.0, '"1"'), 19 | (float, "1", 1.0, "s"), 20 | (bool, "true", True, "null"), 21 | (List[int], "[1, 2,3]", [1, 2, 3], "{1: 2}"), 22 | (List[int], "[]", [], '["s"]'), 23 | (Tuple[int, int], "[1, 2]", (1, 2), "[1, 2, 3]"), 24 | (Dict[str, Tuple[int, float]], '{"x": [1, 2]}', {"x": (1, 2.0)}, '{"x": "y"}'), 25 | (Union[Tuple[int, Any], int], "[1,null]", (1, None), "[null,1]"), 26 | (Union[Tuple[int, int], int], "1", 1, '"1"'), 27 | ], 28 | ) 29 | def test_json_validator( 30 | expected_type: type, valid: str, valid_parsed: Any, invalid: str 31 | ): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--x", type=JsonValidator(expected_type)) 34 | 35 | class ArgparseFailure(Exception): 36 | pass 37 | 38 | def _exit(status, message): 39 | raise ArgparseFailure(f"Unexpected argparse failure: {message}") 40 | 41 | parser.exit = ( 42 | _exit # override exit to raise exception, rather than invoking sys.exit() 43 | ) 44 | 45 | opts = parser.parse_args([f"--x={valid}"]) 46 | assert opts.x == valid_parsed 47 | assert repr(opts.x) == repr(valid_parsed) # check types 48 | 49 | with pytest.raises(ArgparseFailure): 50 | parser.parse_args([f"--x={invalid}"]) 51 | -------------------------------------------------------------------------------- /tests/options/test_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from tests.configs import get_config 10 | 11 | 12 | def test_load_config_file_produces_no_false_warnings() -> None: 13 | get_config() 14 | 15 | 16 | def test_load_config_file_produces_true_warning( 17 | tmp_path: Path, 18 | ) -> None: 19 | config_path = tmp_path.joinpath("config.yaml") 20 | config_path.write_text("an_invalid_key: 2") 21 | with pytest.raises(ValueError, match="an_invalid_key"): 22 | get_config(str(config_path)) 23 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | import re 7 | from typing import Tuple 8 | 9 | import pytest 10 | 11 | 12 | def unset_pretrained_models_from_opts(opts: argparse.Namespace) -> None: 13 | """Unset the argument corresponding to pretrained model path in opts during tests""" 14 | opts_as_dict = vars(opts) 15 | for k, v in opts_as_dict.items(): 16 | if is_pretrained_model_key(k): 17 | setattr(opts, k, None) 18 | 19 | 20 | def is_pretrained_model_key(key_name: str) -> bool: 21 | """Check if arguments corresponding to model have a pretrained key or not.""" 22 | return True if re.search(r".*model\..*\.pretrained$", key_name) else False 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "key_name_expected_output", 27 | [ 28 | ("model.classification.pretrained", True), 29 | ("model.segmentation.pretrained", True), 30 | ("model.video_classification.pretrained", True), 31 | ("teacher.model.classification.pretrained", True), 32 | ("loss.classification.pretrained", False), 33 | ("model.classification.pretrained_dummy", False), 34 | ("model.classification.mypretrained", False), 35 | ("model.classification.my.pretrained", True), 36 | ], 37 | ) 38 | def test_is_pretrained_model_key(key_name_expected_output: Tuple[str, bool]): 39 | key_name = key_name_expected_output[0] 40 | expected_output = key_name_expected_output[1] 41 | assert is_pretrained_model_key(key_name) == expected_output 42 | -------------------------------------------------------------------------------- /tests/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tests/transforms/__init__.py -------------------------------------------------------------------------------- /tests/transforms/test_audio_bytes.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import argparse 7 | 8 | import pytest 9 | import torch 10 | 11 | from corenet.data.transforms import audio_bytes 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "format,encoding_dtype,num_samples,expected_length", 16 | [ 17 | ("wav", "float32", 4, 74), 18 | ("wav", "float32", 8, 90), 19 | ("wav", "int32", 8, 112), 20 | ("wav", "int16", 8, 60), 21 | ("wav", "uint8", 8, 52), 22 | ("mp3", None, 8, 216), 23 | ], 24 | ) 25 | def test_audio_save(format, encoding_dtype, num_samples, expected_length) -> None: 26 | opts = argparse.Namespace() 27 | setattr(opts, "audio_augmentation.torchaudio_save.encoding_dtype", encoding_dtype) 28 | setattr(opts, "audio_augmentation.torchaudio_save.format", format) 29 | setattr(opts, "audio_augmentation.torchaudio_save.backend", "sox") 30 | t = audio_bytes.TorchaudioSave(opts) 31 | 32 | x = { 33 | "samples": {"audio": torch.randn([2, num_samples])}, 34 | "metadata": {"audio_fps": 16}, 35 | } 36 | 37 | outputs = t(x)["samples"]["audio"] 38 | assert torch.all(0 <= outputs) 39 | assert torch.all(outputs <= 255) 40 | assert outputs.shape == (expected_length,) 41 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /tests/utils/test_check.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import unittest 6 | 7 | from corenet.utils.check import check 8 | 9 | 10 | class TestCheck(unittest.TestCase): 11 | 12 | def test_ok(self): 13 | check(True) 14 | check(1) 15 | check([1]) 16 | 17 | def test_fail(self): 18 | with self.assertRaises(AssertionError): 19 | check(False) 20 | with self.assertRaises(AssertionError): 21 | check(0) 22 | with self.assertRaises(AssertionError): 23 | check([]) 24 | 25 | def test_custom_raise(self): 26 | with self.assertRaisesRegex(AssertionError, "phooey"): 27 | check(False, "phooey") 28 | with self.assertRaisesRegex(ValueError, "phooey"): 29 | check(False, ValueError("phooey")) 30 | with self.assertRaisesRegex(AssertionError, "phooey"): 31 | check(False, lambda: "phooey") 32 | with self.assertRaisesRegex(ValueError, "phooey"): 33 | check(False, lambda: ValueError("phooey")) 34 | with self.assertRaisesRegex(AssertionError, "phooey: 0"): 35 | check(0, lambda x: f"phooey: {x}") 36 | with self.assertRaisesRegex(ValueError, "phooey: 0"): 37 | check(0, lambda x: ValueError(f"phooey: {x}")) 38 | -------------------------------------------------------------------------------- /tests/utils/test_common_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import os 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | from corenet.utils.common_utils import unwrap_model_fn 11 | 12 | 13 | def check_models( 14 | original_unwrapped_model: torch.nn.Module, model_after_unwrapping: torch.nn.Module 15 | ) -> None: 16 | """Helper function to test original and unwrapped models are the same.""" 17 | for layer_id in range(len(original_unwrapped_model)): 18 | # for unwrapped models, we should be able to index them 19 | assert repr(model_after_unwrapping[layer_id]) == repr( 20 | original_unwrapped_model[layer_id] 21 | ) 22 | 23 | 24 | def test_unwrap_model_fn(): 25 | """Test for unwrap_model_fn""" 26 | 27 | dummy_model = torch.nn.Sequential( 28 | torch.nn.Linear(10, 20), 29 | torch.nn.Linear(20, 40), 30 | ) 31 | 32 | # test DataParallel wrapping 33 | wrapped_model_dp = torch.nn.DataParallel(dummy_model) 34 | unwrapped_model_dp = unwrap_model_fn(wrapped_model_dp) 35 | check_models(dummy_model, unwrapped_model_dp) 36 | 37 | # Initialize the distributed environment 38 | os.environ["MASTER_ADDR"] = "localhost" 39 | os.environ["MASTER_PORT"] = "1234" 40 | dist.init_process_group(backend="gloo", rank=0, world_size=1) 41 | 42 | # test DDP wrapping 43 | wrapped_model_ddp = torch.nn.parallel.DistributedDataParallel(dummy_model) 44 | unwrapped_model_ddp = unwrap_model_fn(wrapped_model_ddp) 45 | 46 | check_models(dummy_model, unwrapped_model_ddp) 47 | # clean up DDP environment 48 | dist.destroy_process_group() 49 | -------------------------------------------------------------------------------- /tests/utils/test_dict_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from corenet.utils.dict_utils import filter_keys 6 | 7 | 8 | def test_extract_keys(): 9 | d = {"x": 2, "y": 3, "z": 4} 10 | 11 | assert filter_keys(d, ["x", "y"]) == {"x": 2, "y": 3} 12 | assert filter_keys(d, ["w"]) == {} 13 | -------------------------------------------------------------------------------- /tests/utils/test_download_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import Any, List 7 | 8 | import pytest 9 | 10 | from corenet.options.opts import get_training_arguments 11 | from corenet.utils.download_utils import download_assets_in_parallel 12 | 13 | 14 | def dummy_download_fn(index: int, local_dst_dir: str, args, kwargs) -> None: 15 | """Dummy download function. 16 | 17 | Tests if kwargs passed from 'download_assets_in_parallel' can be accessed inside 'dummy_download_fn'. 18 | """ 19 | dummy_kwarg_data = kwargs.get("dummy_kwarg") 20 | # Indexing should not raise an error. 21 | dummy_kwarg_data[index] 22 | 23 | 24 | @pytest.mark.parametrize("asset_names", [["a", "b", "c", "d", "e"], [1, 2, 3], [1]]) 25 | def test_download_assets_in_parallel(asset_names: List[Any]) -> None: 26 | """Test for download_assets_in_parallel function. 27 | 28 | Args: 29 | asset_names: A list of assets that are handled by 'download_func' in 'download_assets_in_parallel'. 
30 | """ 31 | function_kwargs = {"dummy_kwarg": asset_names} 32 | opts = get_training_arguments(parse_args=True, args=[]) 33 | 34 | record_indices = download_assets_in_parallel( 35 | opts=opts, 36 | local_dst_dir="trash/dummy_test", 37 | num_assets=len(asset_names), 38 | download_func=dummy_download_fn, 39 | **function_kwargs, 40 | ) 41 | assert len(record_indices) == len(asset_names) 42 | -------------------------------------------------------------------------------- /tests/utils/test_file_logger.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import tempfile 8 | 9 | import pytest 10 | import torch 11 | 12 | from corenet.utils.file_logger import FileLogger 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "metric_name, epoch1, value1, epoch2, value2", 17 | [("metric", 0, 1.0, 1, 2.0), ("metric2", 5, 1.0, 6, 2.0)], 18 | ) 19 | def test_file_logger( 20 | metric_name: str, epoch1: int, value1: float, epoch2: int, value2: float 21 | ) -> None: 22 | with tempfile.TemporaryDirectory() as tempdir: 23 | # Case 1: The file doesn't exist. 24 | filename = os.path.join(tempdir, "stats.pt") 25 | logger = FileLogger(filename) 26 | 27 | logger.add_scalar(metric_name, value1, epoch1) 28 | logger.close() 29 | assert os.path.exists(filename) 30 | 31 | a = torch.load(filename) 32 | assert a == {"epochs": {epoch1: {"metrics": {metric_name: value1}}}} 33 | 34 | # Case 2: The file does exist. 35 | logger = FileLogger(filename) 36 | logger.add_scalar(metric_name, value2, epoch2) 37 | logger.close() 38 | 39 | a = torch.load(filename) 40 | assert a == { 41 | "epochs": { 42 | epoch1: {"metrics": {metric_name: value1}}, 43 | epoch2: {"metrics": {metric_name: value2}}, 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/corenet/1aa3acd17b7b02ba51889b648b1182845c4323eb/tools/__init__.py -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = darwin, linux 3 | 4 | [testenv] 5 | deps = 6 | -r requirements.txt 7 | -r requirements-optional.txt 8 | -r mlx_examples/requirements.txt 9 | commands = 10 | make test-all extra_args="--ignore=corenet/internal --ignore=tests/internal --ignore=experimental" 11 | allowlist_externals = make 12 | 13 | 14 | [testenv:darwin] 15 | platform = darwin 16 | # Use Python 3.9 on macOS (Mac OS 14.4 system Python version) 17 | basepython = python3.9 18 | setenv = 19 | DYLD_LIBRARY_PATH=/opt/homebrew/lib 20 | deps = 21 | {[testenv]deps} 22 | commands = 23 | {[testenv]commands} 24 | 25 | [testenv:linux] 26 | platform = linux 27 | # Use Python 3.10 on Linux (Ubuntu 22.04 system Python version) 28 | basepython = python3.10 29 | deps = 30 | {[testenv]deps} 31 | commands = 32 | {[testenv]commands} 33 | 34 | --------------------------------------------------------------------------------
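Note on running the suite: the tox environments above delegate to "make test-all", whose full pytest invocation lives in the Makefile and is not shown here. As a minimal sketch only, the same test selection can be approximated directly through pytest's Python API, reusing the --ignore paths from the tox "commands" entry; the run_corenet_tests helper below is hypothetical and not part of the repository.

# Hypothetical sketch (not part of the repository): approximate the tox
# "commands" entry by calling pytest on the tests/ tree directly, skipping
# the same internal/experimental paths that tox.ini passes via extra_args.
# The real Makefile target may add further options (e.g. parallelism) that
# are not reproduced here.
import sys

import pytest


def run_corenet_tests() -> int:
    # pytest.main returns an exit code (0 on success), which we forward.
    return pytest.main(
        [
            "tests",
            "--ignore=corenet/internal",
            "--ignore=tests/internal",
            "--ignore=experimental",
        ]
    )


if __name__ == "__main__":
    sys.exit(run_corenet_tests())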