├── src ├── tests │ ├── __init__.py │ ├── data_loader_test.py │ └── check_tars.py └── laion_clap │ ├── evaluate │ ├── __init__.py │ ├── eval_dcase.py │ ├── eval_retrieval.py │ ├── eval_retrieval_main.py │ └── eval_zeroshot_classification.py │ ├── training │ ├── __init__.py │ ├── audioset_textmap.npy │ ├── scheduler.py │ ├── logger.py │ ├── infer_demo.py │ ├── zero_shot.py │ ├── distributed.py │ └── lp_train.py │ ├── clap_module │ ├── version.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── model_configs │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32.json │ │ ├── ViT-L-14.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50.json │ │ ├── RN50x4.json │ │ ├── RN50x16.json │ │ ├── RN101-quickgelu.json │ │ ├── RN50-quickgelu.json │ │ ├── PANN-6.json │ │ ├── HTSAT-tiny.json │ │ ├── PANN-10.json │ │ ├── PANN-14.json │ │ ├── HTSAT-base.json │ │ ├── HTSAT-large.json │ │ ├── HTSAT-tiny-win-1536.json │ │ ├── PANN-14-fmax-18k.json │ │ ├── PANN-14-fmax-8k-20s.json │ │ ├── PANN-14-win-1536.json │ │ └── PANN-14-tiny-transformer.json │ ├── __init__.py │ ├── transform.py │ ├── bert.py │ ├── linear_probe.py │ ├── timm_model.py │ ├── openai.py │ ├── pretrained.py │ ├── tokenizer.py │ ├── feature_fusion.py │ ├── factory.py │ └── utils.py │ ├── __init__.py │ ├── unit_test.py │ └── hook.py ├── assets ├── logo.PNG ├── audioclip-arch.png └── clap-zeroshot.PNG ├── MANIFEST.in ├── class_labels ├── UrbanSound8K_class_labels_indices.json ├── GTZAN_class_labels.json ├── ESC50_class_labels_indices.json ├── ESC50_class_labels_indices_space.json ├── FSD50k_class_labels_indices.json ├── VGGSound_class_labels_indices.json ├── audioset_class_labels_indices.json └── audioset_fsd50k_class_labels_indices.json ├── requirements.txt ├── experiment_scripts ├── zeroshot_esc50.sh ├── train-only-clotho.sh ├── esc50_api.py ├── eval_retrieval_freesound.sh ├── train-pann-roberta.sh ├── train-htsat-roberta.sh ├── finetune-esc50.sh ├── finetune-fsd50k.sh └── htsat-roberta-large-dataset-fusion.sh ├── pyproject.toml ├── .gitignore └── LICENSE /src/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/laion_clap/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/laion_clap/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.1' 2 | -------------------------------------------------------------------------------- /assets/logo.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/CLAP/main/assets/logo.PNG -------------------------------------------------------------------------------- /assets/audioclip-arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/CLAP/main/assets/audioclip-arch.png -------------------------------------------------------------------------------- /assets/clap-zeroshot.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vaibhavs10/CLAP/main/assets/clap-zeroshot.PNG -------------------------------------------------------------------------------- /src/laion_clap/training/audioset_textmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/CLAP/main/src/laion_clap/training/audioset_textmap.npy -------------------------------------------------------------------------------- /src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/CLAP/main/src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/laion_clap/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | dir_path = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(dir_path) 5 | from .hook import CLAP_Module -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include src/laion_clap/clap_module/model_configs *.json 2 | recursive-include src/laion_clap/clap_module bpe_simple_vocab_16e6.txt.gz 3 | recursive-include src/laion_clap/training audioset_textmap.npy 4 | -------------------------------------------------------------------------------- /class_labels/UrbanSound8K_class_labels_indices.json: -------------------------------------------------------------------------------- 1 | {"air conditioner": 0, "car horn": 1, "children playing": 2, "dog bark": 3, "drilling": 4, "engine idling": 5, "gun shot": 6, "jackhammer": 7, "siren": 8, "street music": 9} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | soundfile 2 | librosa 3 | torchlibrosa 4 | ftfy 5 | braceexpand 6 | webdataset 7 | wget 8 | wandb 9 | llvmlite 10 | scipy 11 | scikit-learn 12 | pandas 13 | h5py 14 | tqdm 15 | regex 16 | transformers<=4.30.2 17 | -------------------------------------------------------------------------------- /class_labels/GTZAN_class_labels.json: -------------------------------------------------------------------------------- 1 | { 2 | "blues": 0, 3 | "classical": 1, 4 | "country": 2, 5 | "disco": 3, 6 | "hiphop": 4, 7 | "jazz": 5, 8 | "metal": 6, 9 | "pop": 7, 10 | "reggae": 8, 11 | "rock": 9 12 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | 
"width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 
4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics 3 | from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from .tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/PANN-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 
512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/HTSAT-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- 
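[Editor's note] The model_configs JSONs above and below each pair an audio_cfg (encoder type, sample rate, clip length, STFT settings) with a text_cfg, and the experiment scripts select one by name via --amodel (e.g. HTSAT-tiny, PANN-14). Below is a minimal, hypothetical sketch — not part of the repository — of how such a config file can be read and interpreted, assuming the filename matches the --amodel name:

import json
from pathlib import Path

# Assumed location of the config files shown in this dump.
CONFIG_DIR = Path("src/laion_clap/clap_module/model_configs")

def load_model_cfg(amodel: str) -> dict:
    """Load the JSON config for an audio model name such as 'HTSAT-tiny'."""
    with open(CONFIG_DIR / f"{amodel}.json") as f:
        return json.load(f)

if __name__ == "__main__":
    cfg = load_model_cfg("PANN-14-fmax-8k-20s")
    audio = cfg["audio_cfg"]
    # clip_samples / sample_rate is the clip length the encoder expects:
    # 960000 / 48000 = 20 s for this config, versus 480000 / 48000 = 10 s in most others.
    print(audio["model_type"], audio["model_name"],
          audio["clip_samples"] / audio["sample_rate"], "seconds")

(In the repo these configs are consumed by create_model, exported from clap_module/factory.py via clap_module/__init__.py; the helper above only illustrates the structure of the files and is not the project's own loading code.)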
/src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /experiment_scripts/zeroshot_esc50.sh: -------------------------------------------------------------------------------- 1 | # run from CLAP directory 2 | python -m evaluate.eval_zeroshot_classification \ 3 | --dataset-type="webdataset" \ 4 | --precision="fp32" \ 5 | --batch-size=512 \ 6 | --workers=6 \ 7 | --amodel HTSAT-tiny \ 8 | --tmodel roberta \ 9 | --datasetnames "esc50_no_overlap" \ 10 | --remotedata \ 11 | --datasetinfos "train" \ 12 | --seed 3407 \ 13 | --logs ./logs \ 14 | --data-filling "repeatpad" \ 15 | --data-truncating "rand_trunc" \ 16 | --freeze-text \ 17 | --class-label-path="../class_labels/ESC50_class_labels_indices_space.json" \ 18 | --pretrained="/fsx/clap_logs/2023_02_18-00_03_45-model_HTSAT-tiny-lr_0.0001-b_96-j_6-p_fp32/checkpoints" 19 | 20 | -------------------------------------------------------------------------------- /src/laion_clap/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - 
warmup_length 20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | return _lr_adjuster -------------------------------------------------------------------------------- /class_labels/ESC50_class_labels_indices.json: -------------------------------------------------------------------------------- 1 | {"dog": 0, "rooster": 1, "pig": 2, "cow": 3, "frog": 4, "cat": 5, "hen": 6, "insects": 7, "sheep": 8, "crow": 9, "rain": 10, "sea_waves": 11, "crackling_fire": 12, "crickets": 13, "chirping_birds": 14, "water_drops": 15, "wind": 16, "pouring_water": 17, "toilet_flush": 18, "thunderstorm": 19, "crying_baby": 20, "sneezing": 21, "clapping": 22, "breathing": 23, "coughing": 24, "footsteps": 25, "laughing": 26, "brushing_teeth": 27, "snoring": 28, "drinking_sipping": 29, "door_wood_knock": 30, "mouse_click": 31, "keyboard_typing": 32, "door_wood_creaks": 33, "can_opening": 34, "washing_machine": 35, "vacuum_cleaner": 36, "clock_alarm": 37, "clock_tick": 38, "glass_breaking": 39, "helicopter": 40, "chainsaw": 41, "siren": 42, "car_horn": 43, "engine": 44, "train": 45, "church_bells": 46, "airplane": 47, "fireworks": 48, "hand_saw": 49} -------------------------------------------------------------------------------- /class_labels/ESC50_class_labels_indices_space.json: -------------------------------------------------------------------------------- 1 | {"dog": 0, "rooster": 1, "pig": 2, "cow": 3, "frog": 4, "cat": 5, "hen": 6, "insects": 7, "sheep": 8, "crow": 9, "rain": 10, "sea waves": 11, "crackling fire": 12, "crickets": 13, "chirping birds": 14, "water drops": 15, "wind": 16, "pouring water": 17, "toilet flush": 18, "thunderstorm": 19, "crying baby": 20, "sneezing": 21, "clapping": 22, "breathing": 23, "coughing": 24, "footsteps": 25, "laughing": 26, "brushing teeth": 27, "snoring": 28, "drinking sipping": 29, "door wood knock": 30, "mouse click": 31, "keyboard typing": 32, "door wood creaks": 33, "can opening": 34, "washing machine": 35, "vacuum cleaner": 36, "clock alarm": 37, "clock tick": 38, "glass breaking": 39, "helicopter": 40, "chainsaw": 41, "siren": 42, "car horn": 43, "engine": 44, "train": 45, "church bells": 46, "airplane": 47, "fireworks": 48, "hand saw": 49} -------------------------------------------------------------------------------- /experiment_scripts/train-only-clotho.sh: -------------------------------------------------------------------------------- 1 | python -m laion_clap.training.main \ 2 | --save-frequency 5 \ 3 | --save-top-performance 3 \ 4 | --save-most-recent \ 5 | --dataset-type="webdataset" \ 6 | --datasetpath="" \ 7 | --precision="fp32" \ 8 | --batch-size=96 \ 9 | --lr=1e-4 \ 10 | --wd=0.0 \ 11 | --epochs=45 \ 12 | --workers=6 \ 13 | --use-bn-sync \ 14 | --amodel HTSAT-tiny \ 15 | --tmodel roberta \ 16 | --warmup 3200 \ 17 | --datasetnames "Clotho" \ 18 | --datasetinfos "train" \ 19 | --top-k-checkpoint-select-dataset="Clotho-test" \ 20 | --top-k-checkpoint-select-metric="mAP@10" \ 21 | --logs 'logs' \ 22 | --seed 3407 \ 23 | --gather-with-grad \ 24 | --optimizer "adam" \ 25 | --data-filling "repeatpad" \ 26 | --data-truncating "rand_trunc" \ 27 | --pretrained-audio '/HTSAT-fullset-imagenet-map=0.467.ckpt' \ 28 | --prefetch-factor 2 -------------------------------------------------------------------------------- /src/laion_clap/clap_module/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, 
RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 2 | CenterCrop 3 | 4 | 5 | def _convert_to_rgb(image): 6 | return image.convert('RGB') 7 | 8 | 9 | def image_transform( 10 | image_size: int, 11 | is_train: bool, 12 | mean=(0.48145466, 0.4578275, 0.40821073), 13 | std=(0.26862954, 0.26130258, 0.27577711) 14 | ): 15 | normalize = Normalize(mean=mean, std=std) 16 | if is_train: 17 | return Compose([ 18 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 19 | _convert_to_rgb, 20 | ToTensor(), 21 | normalize, 22 | ]) 23 | else: 24 | return Compose([ 25 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 26 | CenterCrop(image_size), 27 | _convert_to_rgb, 28 | ToTensor(), 29 | normalize, 30 | ]) 31 | -------------------------------------------------------------------------------- /src/laion_clap/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 3 | model = BertModel.from_pretrained("bert-base-uncased") 4 | text = "Replace me by any text you'd like." 5 | 6 | def bert_embeddings(text): 7 | # text = "Replace me by any text you'd like." 8 | encoded_input = tokenizer(text, return_tensors='pt') 9 | output = model(**encoded_input) 10 | return output 11 | 12 | from transformers import RobertaTokenizer, RobertaModel 13 | 14 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 15 | model = RobertaModel.from_pretrained('roberta-base') 16 | text = "Replace me by any text you'd like." 17 | def Roberta_embeddings(text): 18 | # text = "Replace me by any text you'd like." 19 | encoded_input = tokenizer(text, return_tensors='pt') 20 | output = model(**encoded_input) 21 | return output 22 | 23 | from transformers import BartTokenizer, BartModel 24 | 25 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 26 | model = BartModel.from_pretrained('facebook/bart-base') 27 | text = "Replace me by any text you'd like." 28 | def bart_embeddings(text): 29 | # text = "Replace me by any text you'd like." 
30 | encoded_input = tokenizer(text, return_tensors='pt') 31 | output = model(**encoded_input) 32 | return output -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | [project] 5 | name = "laion_clap" 6 | version = "1.1.4" 7 | authors = [ 8 | { name="Ke Chen", email="knutchen@ucsd.edu" }, 9 | { name="Yusong Wu" }, 10 | { name="Tianyu Zhang" }, 11 | { name="Yuchen Hui" } 12 | ] 13 | maintainers = [ 14 | { name="Ke Chen", email="knutchen@ucsd.edu" }, 15 | { name="Yusong Wu" }, 16 | { name="Tianyu Zhang" }, 17 | { name="Yuchen Hui" } 18 | ] 19 | description = "Contrastive Language-Audio Pretraining Model from LAION" 20 | license = {file = "LICENSE"} 21 | readme = "README.md" 22 | requires-python = ">=3.7" 23 | dependencies = [ 24 | "numpy==1.23.5", 25 | "soundfile", 26 | "librosa", 27 | "torchlibrosa", 28 | "ftfy", 29 | "braceexpand", 30 | "webdataset", 31 | "wget", 32 | "wandb", 33 | "llvmlite", 34 | "scipy", 35 | "scikit-learn", 36 | "pandas", 37 | "h5py", 38 | "tqdm", 39 | "regex", 40 | "transformers", 41 | "progressbar" 42 | ] 43 | classifiers = [ 44 | 'Development Status :: 3 - Alpha', 45 | 'Intended Audience :: Developers', 46 | 'Intended Audience :: Science/Research', 47 | 'License :: OSI Approved :: Apache Software License', 48 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 49 | ] 50 | 51 | 52 | [project.urls] 53 | "Homepage" = "https://github.com/LAION-AI/CLAP" 54 | "Bug Tracker" = "https://github.com/LAION-AI/CLAP/issues" -------------------------------------------------------------------------------- /experiment_scripts/esc50_api.py: -------------------------------------------------------------------------------- 1 | import laion_clap 2 | import glob 3 | import json 4 | import torch 5 | import numpy as np 6 | 7 | device = torch.device('cuda:0') 8 | 9 | # download https://drive.google.com/drive/folders/1scyH43eQAcrBz-5fAw44C6RNBhC3ejvX?usp=sharing and extract ./ESC50_1/test/0.tar to ./ESC50_1/test/ 10 | esc50_test_dir = './ESC50_1/test/*/' 11 | class_index_dict_path = '/fsx/yusong/CLAP/class_labels/ESC50_class_labels_indices_space.json' 12 | 13 | # Load the model 14 | model = laion_clap.CLAP_Module(enable_fusion=False, device=device) 15 | model.load_ckpt() 16 | 17 | # Get the class index dict 18 | class_index_dict = {v: k for v, k in json.load(open(class_index_dict_path)).items()} 19 | 20 | # Get all the data 21 | audio_files = sorted(glob.glob(esc50_test_dir + '**/*.flac', recursive=True)) 22 | json_files = sorted(glob.glob(esc50_test_dir + '**/*.json', recursive=True)) 23 | ground_truth_idx = [class_index_dict[json.load(open(jf))['tag'][0]] for jf in json_files] 24 | 25 | with torch.no_grad(): 26 | ground_truth = torch.tensor(ground_truth_idx).view(-1, 1) 27 | 28 | # Get text features 29 | all_texts = ["This is a sound of " + t for t in class_index_dict.keys()] 30 | text_embed = model.get_text_embedding(all_texts) 31 | audio_embed = model.get_audio_embedding_from_filelist(x=audio_files) 32 | 33 | ranking = torch.argsort(torch.tensor(audio_embed) @ torch.tensor(text_embed).t(), descending=True) 34 | preds = torch.where(ranking == ground_truth)[1] 35 | preds = preds.cpu().numpy() 36 | 37 | metrics = {} 38 | metrics[f"mean_rank"] = preds.mean() + 1 39 | metrics[f"median_rank"] = np.floor(np.median(preds)) + 1 40 | for k in [1, 5, 10]: 41 | 
metrics[f"R@{k}"] = np.mean(preds < k) 42 | # map@10 43 | metrics[f"mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) 44 | 45 | print( 46 | f"Zeroshot Classification Results: " 47 | + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 48 | ) 49 | -------------------------------------------------------------------------------- /experiment_scripts/eval_retrieval_freesound.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --comment clap 3 | #SBATCH --partition=g40423 4 | #SBATCH --job-name=mclap 5 | #SBATCH --nodes 3 6 | #SBATCH --ntasks-per-node 8 7 | #SBATCH --cpus-per-gpu=6 8 | #SBATCH --exclusive 9 | #SBATCH --output=%x_%j.out 10 | 11 | module load openmpi 12 | module load cuda/11.7 13 | export NCCL_PROTO=simple 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | export NCCL_DEBUG=info 18 | export OMPI_MCA_mtl_base_verbose=1 19 | export FI_EFA_ENABLE_SHM_TRANSFER=0 20 | export FI_PROVIDER=efa 21 | export FI_EFA_TX_MIN_CREDITS=64 22 | export NCCL_TREE_THRESHOLD=0 23 | 24 | # sent to sub script 25 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 26 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 27 | export MASTER_PORT=12802 28 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 29 | 30 | echo go $COUNT_NODE 31 | echo $HOSTNAMES 32 | 33 | source /fsx/yusong/clap/bin/activate 34 | cd /fsx/yusong/CLAP/src 35 | export TRANSFORMERS_CACHE=/fsx/yusong/transformers_cache 36 | 37 | srun --comment clap --cpu_bind=v --accel-bind=gn python -m evaluate.eval_retrieval_main \ 38 | --save-frequency 5 \ 39 | --save-top-performance 3 \ 40 | --save-most-recent \ 41 | --dataset-type="webdataset" \ 42 | --precision="fp32" \ 43 | --warmup 0 \ 44 | --batch-size=512 \ 45 | --wd=0.0 \ 46 | --epochs=50 \ 47 | --workers=6 \ 48 | --use-bn-sync \ 49 | --freeze-text \ 50 | --amodel HTSAT-tiny \ 51 | --tmodel roberta \ 52 | --report-to "wandb" \ 53 | --wandb-notes "10.17-freesound-dataset-4#" \ 54 | --datasetnames "freesound_no_overlap_noesc50" \ 55 | --datasetinfos "train" \ 56 | --seed 3407 \ 57 | --remotedata \ 58 | --logs /fsx/clap_logs \ 59 | --gather-with-grad \ 60 | --openai-model-cache-dir /fsx/yusong/transformers_cache \ 61 | --data-filling "repeatpad" \ 62 | --data-truncating "rand_trunc" \ 63 | --pretrained="/fsx/clap_logs/2022_10_17-02_08_21-model_HTSAT-tiny-lr_0.0001-b_96-j_6-p_fp32/checkpoints" -------------------------------------------------------------------------------- /experiment_scripts/train-pann-roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --comment clap 3 | #SBATCH --partition=g40423 4 | #SBATCH --job-name=mclap 5 | #SBATCH --nodes 3 6 | #SBATCH --ntasks-per-node 8 7 | #SBATCH --cpus-per-gpu=6 8 | #SBATCH --exclusive 9 | #SBATCH --output=%x_%j.out 10 | 11 | module load openmpi 12 | module load cuda/11.7 13 | export NCCL_PROTO=simple 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | export NCCL_DEBUG=info 18 | export OMPI_MCA_mtl_base_verbose=1 19 | export FI_EFA_ENABLE_SHM_TRANSFER=0 20 | export FI_PROVIDER=efa 21 | export FI_EFA_TX_MIN_CREDITS=64 22 | export NCCL_TREE_THRESHOLD=0 23 | 24 | # sent to sub script 25 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 26 | export MASTER_ADDR=$(scontrol show hostnames 
"$SLURM_JOB_NODELIST" | head -n 1) 27 | export MASTER_PORT=12802 28 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 29 | 30 | echo go $COUNT_NODE 31 | echo $HOSTNAMES 32 | 33 | source /fsx/yusong/clap/bin/activate 34 | cd /fsx/yusong/CLAP/src 35 | export TRANSFORMERS_CACHE=/fsx/yusong/transformers_cache 36 | 37 | srun --comment clap --cpu_bind=v --accel-bind=gn python -m training.main \ 38 | --save-frequency 5 \ 39 | --save-top-performance 3 \ 40 | --save-most-recent \ 41 | --dataset-type="webdataset" \ 42 | --precision="fp32" \ 43 | --batch-size=96 \ 44 | --lr=1e-4 \ 45 | --wd=0.0 \ 46 | --epochs=45 \ 47 | --workers=6 \ 48 | --use-bn-sync \ 49 | --amodel PANN-14 \ 50 | --tmodel roberta \ 51 | --warmup 500 \ 52 | --report-to "wandb" \ 53 | --wandb-notes "10.16-clap-dataset-1#-pann-roberta" \ 54 | --datasetnames "Clotho" "audiocaps" \ 55 | --datasetinfos "train" "unbalanced_train" \ 56 | --top-k-checkpoint-select-dataset="Clotho-test" \ 57 | --top-k-checkpoint-select-metric="mAP@10" \ 58 | --openai-model-cache-dir /fsx/yusong/transformers_cache \ 59 | --logs /fsx/clap_logs \ 60 | --seed 3407 \ 61 | --remotedata \ 62 | --gather-with-grad \ 63 | --optimizer "adam" \ 64 | --data-filling "repeatpad" \ 65 | --data-truncating "rand_trunc" \ 66 | --pretrained-audio /fsx/yusong/audio_pretrained_model/PANN-fullset-map=0.439.ckpt 67 | -------------------------------------------------------------------------------- /experiment_scripts/train-htsat-roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --comment clap 3 | #SBATCH --partition=g40423 4 | #SBATCH --job-name=mclap 5 | #SBATCH --nodes 3 6 | #SBATCH --ntasks-per-node 8 7 | #SBATCH --cpus-per-gpu=6 8 | #SBATCH --exclusive 9 | #SBATCH --output=%x_%j.out 10 | 11 | module load openmpi 12 | module load cuda/11.7 13 | export NCCL_PROTO=simple 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | export NCCL_DEBUG=info 18 | export OMPI_MCA_mtl_base_verbose=1 19 | export FI_EFA_ENABLE_SHM_TRANSFER=0 20 | export FI_PROVIDER=efa 21 | export FI_EFA_TX_MIN_CREDITS=64 22 | export NCCL_TREE_THRESHOLD=0 23 | 24 | # sent to sub script 25 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 26 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 27 | export MASTER_PORT=12802 28 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 29 | 30 | echo go $COUNT_NODE 31 | echo $HOSTNAMES 32 | 33 | source /fsx/yusong/clap/bin/activate 34 | cd /fsx/yusong/CLAP/src 35 | export TRANSFORMERS_CACHE=/fsx/yusong/transformers_cache 36 | 37 | srun --comment clap --cpu_bind=v --accel-bind=gn python -m training.main \ 38 | --save-frequency 5 \ 39 | --save-top-performance 3 \ 40 | --save-most-recent \ 41 | --dataset-type="webdataset" \ 42 | --precision="fp32" \ 43 | --batch-size=96 \ 44 | --lr=1e-4 \ 45 | --wd=0.0 \ 46 | --epochs=45 \ 47 | --workers=6 \ 48 | --use-bn-sync \ 49 | --amodel HTSAT-tiny \ 50 | --tmodel roberta \ 51 | --warmup 3200 \ 52 | --report-to "wandb" \ 53 | --wandb-notes "10.16-clap-dataset-1#-htsat-roberta" \ 54 | --datasetnames "Clotho" "audiocaps" \ 55 | --datasetinfos "train" "unbalanced_train" \ 56 | --top-k-checkpoint-select-dataset="Clotho-test" \ 57 | --top-k-checkpoint-select-metric="mAP@10" \ 58 | --openai-model-cache-dir /fsx/yusong/transformers_cache \ 59 | --logs /fsx/clap_logs \ 60 | --seed 3407 \ 61 | --remotedata \ 62 | 
--gather-with-grad \ 63 | --optimizer "adam" \ 64 | --data-filling "repeatpad" \ 65 | --data-truncating "rand_trunc" \ 66 | --pretrained-audio /fsx/yusong/audio_pretrained_model/HTSAT-fullset-imagenet-map=0.467.ckpt 67 | -------------------------------------------------------------------------------- /experiment_scripts/finetune-esc50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --comment clap 3 | #SBATCH --partition=g40423 4 | #SBATCH --job-name=mclap 5 | #SBATCH --nodes 3 6 | #SBATCH --ntasks-per-node 8 7 | #SBATCH --cpus-per-gpu=6 8 | #SBATCH --exclusive 9 | #SBATCH --output=%x_%j.out 10 | 11 | module load openmpi 12 | module load cuda/11.7 13 | export NCCL_PROTO=simple 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | export NCCL_DEBUG=info 18 | export OMPI_MCA_mtl_base_verbose=1 19 | export FI_EFA_ENABLE_SHM_TRANSFER=0 20 | export FI_PROVIDER=efa 21 | export FI_EFA_TX_MIN_CREDITS=64 22 | export NCCL_TREE_THRESHOLD=0 23 | 24 | # sent to sub script 25 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 26 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 27 | export MASTER_PORT=12802 28 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 29 | 30 | echo go $COUNT_NODE 31 | echo $HOSTNAMES 32 | 33 | source /fsx/yusong/clap/bin/activate 34 | cd /fsx/yusong/CLAP/src 35 | export TRANSFORMERS_CACHE=/fsx/yusong/transformers_cache 36 | 37 | srun --comment clap --cpu_bind=v --accel-bind=gn python -m evaluate.eval_linear_probe \ 38 | --save-frequency 50 \ 39 | --save-top-performance 3 \ 40 | --save-most-recent \ 41 | --dataset-type="webdataset" \ 42 | --precision="fp32" \ 43 | --warmup 0 \ 44 | --batch-size=160 \ 45 | --lr=1e-4 \ 46 | --wd=0.1 \ 47 | --epochs=100 \ 48 | --workers=4 \ 49 | --use-bn-sync \ 50 | --freeze-text \ 51 | --amodel PANN-14 \ 52 | --tmodel roberta \ 53 | --report-to "wandb" \ 54 | --wandb-notes "10.14-finetune-esc50" \ 55 | --datasetnames "esc50" \ 56 | --datasetinfos "train" \ 57 | --seed 3407 \ 58 | --remotedata \ 59 | --logs /fsx/clap_logs \ 60 | --gather-with-grad \ 61 | --lp-loss="ce" \ 62 | --lp-metrics="acc" \ 63 | --lp-lr=1e-4 \ 64 | --lp-mlp \ 65 | --class-label-path="../class_labels/ESC50_class_labels_indices_space.json" \ 66 | --openai-model-cache-dir /fsx/yusong/transformers_cache \ 67 | --pretrained="/fsx/clap_logs/2022_10_14-04_05_14-model_PANN-14-lr_0.0001-b_160-j_6-p_fp32/checkpoints" \ 68 | --data-filling "repeatpad" \ 69 | --data-truncating "rand_trunc" \ 70 | --optimizer "adam" -------------------------------------------------------------------------------- /experiment_scripts/finetune-fsd50k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --comment clap 3 | #SBATCH --partition=g40423 4 | #SBATCH --job-name=mclap 5 | #SBATCH --nodes 3 6 | #SBATCH --ntasks-per-node 8 7 | #SBATCH --cpus-per-gpu=6 8 | #SBATCH --exclusive 9 | #SBATCH --output=%x_%j.out 10 | 11 | module load openmpi 12 | module load cuda/11.7 13 | export NCCL_PROTO=simple 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | export NCCL_DEBUG=info 18 | export OMPI_MCA_mtl_base_verbose=1 19 | export FI_EFA_ENABLE_SHM_TRANSFER=0 20 | export FI_PROVIDER=efa 21 | export FI_EFA_TX_MIN_CREDITS=64 22 | export NCCL_TREE_THRESHOLD=0 23 | 24 | # sent to sub script 25 | export 
HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 26 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 27 | export MASTER_PORT=12802 28 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 29 | 30 | echo go $COUNT_NODE 31 | echo $HOSTNAMES 32 | 33 | source /fsx/yusong/clap/bin/activate 34 | cd /fsx/yusong/CLAP/src 35 | export TRANSFORMERS_CACHE=/fsx/yusong/transformers_cache 36 | 37 | srun --comment clap --cpu_bind=v --accel-bind=gn python -m evaluate.eval_linear_probe \ 38 | --save-frequency 50 \ 39 | --save-top-performance 3 \ 40 | --save-most-recent \ 41 | --dataset-type="webdataset" \ 42 | --precision="fp32" \ 43 | --warmup 0 \ 44 | --batch-size=160 \ 45 | --lr=1e-4 \ 46 | --wd=0.1 \ 47 | --epochs=100 \ 48 | --workers=4 \ 49 | --use-bn-sync \ 50 | --freeze-text \ 51 | --amodel PANN-14 \ 52 | --tmodel roberta \ 53 | --report-to "wandb" \ 54 | --wandb-notes "10.14-finetune-fsd50k" \ 55 | --datasetnames "fsd50k_class_label" \ 56 | --datasetinfos "train" \ 57 | --seed 3407 \ 58 | --remotedata \ 59 | --logs /fsx/clap_logs \ 60 | --gather-with-grad \ 61 | --lp-loss="bce" \ 62 | --lp-metrics="map" \ 63 | --lp-lr=1e-4 \ 64 | --lp-mlp \ 65 | --class-label-path="../class_labels/FSD50k_class_labels_indices.json" \ 66 | --openai-model-cache-dir /fsx/yusong/transformers_cache \ 67 | --pretrained="/fsx/clap_logs/2022_10_14-04_05_14-model_PANN-14-lr_0.0001-b_160-j_6-p_fp32/checkpoints" \ 68 | --data-filling "repeatpad" \ 69 | --data-truncating "rand_trunc" \ 70 | --optimizer "adam" -------------------------------------------------------------------------------- /src/tests/data_loader_test.py: -------------------------------------------------------------------------------- 1 | from laion_clap import create_model 2 | from laion_clap.training.data import get_data 3 | from laion_clap.training import parse_args 4 | import torch 5 | import os 6 | from tqdm import tqdm 7 | from laion_clap.training.distributed import is_master, world_info_from_env 8 | from laion_clap.utils import dataset_split 9 | 10 | 11 | def run_dataloader(): 12 | for i, batch in enumerate(tqdm(dataloader, total=data["train"].dataloader.num_samples // args.batch_size)): 13 | pass 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | args = parse_args() 19 | # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 
20 | args.amodel = args.amodel.replace("/", "-") 21 | device = torch.device('cpu') 22 | 23 | # discover initial world args early so we can log properly 24 | args.distributed = False 25 | args.local_rank, args.rank, args.world_size = world_info_from_env() 26 | 27 | if args.remotedata and is_master(args): 28 | for dataset_name in args.datasetnames: 29 | for split in dataset_split[dataset_name]: 30 | if not os.path.exists(f"./json_files/{dataset_name}/{split}"): 31 | os.makedirs(f"./json_files/{dataset_name}/{split}") 32 | os.system( 33 | f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" 34 | ) 35 | 36 | model, model_cfg = create_model( 37 | args.amodel, 38 | args.tmodel, 39 | args.pretrained, 40 | precision=args.precision, 41 | device=device, 42 | jit=args.torchscript, 43 | force_quick_gelu=args.force_quick_gelu, 44 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 45 | skip_params=True, 46 | pretrained_audio=args.pretrained_audio, 47 | pretrained_text=args.pretrained_text, 48 | enable_fusion=args.enable_fusion, 49 | fusion_type=args.fusion_type 50 | ) 51 | 52 | data = get_data(args, model_cfg) 53 | 54 | dataloader, sampler = data["train"].dataloader, data["train"].sampler 55 | 56 | print('dataset size:', data["train"].dataloader.num_samples) 57 | print('batch size:', args.batch_size) 58 | print('num batches:', data["train"].dataloader.num_samples // args.batch_size) 59 | 60 | run_dataloader() 61 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from .model import MLPLayers 5 | 6 | 7 | class LinearProbe(nn.Module): 8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): 9 | """ 10 | Args: 11 | model: nn.Module 12 | mlp: bool, if True, then use the MLP layer as the linear probe module 13 | freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe 14 | in_ch: int, the output channel from CLAP model 15 | out_ch: int, the output channel from linear probe (class_num) 16 | act: torch.nn.functional, the activation function before the loss function 17 | """ 18 | super().__init__() 19 | in_ch = 512 # NOTE: overrides the in_ch argument 20 | self.clap_model = model 21 | self.clap_model.text_branch = None # to save memory 22 | self.freeze = freeze 23 | if mlp: 24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) 25 | else: 26 | self.lp_layer = nn.Linear(in_ch, out_ch) 27 | 28 | if self.freeze: 29 | for param in self.clap_model.parameters(): 30 | param.requires_grad = False 31 | 32 | if act == 'None': 33 | self.act = None 34 | elif act == 'relu': 35 | self.act = nn.ReLU() 36 | elif act == 'elu': 37 | self.act = nn.ELU() 38 | elif act == 'prelu': 39 | self.act = nn.PReLU(num_parameters=in_ch) 40 | elif act == 'softmax': 41 | self.act = nn.Softmax(dim=-1) 42 | elif act == 'sigmoid': 43 | self.act = nn.Sigmoid() 44 | 45 | def forward(self, x, mix_lambda=None, device=None): 46 | """ 47 | Args: 48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list 49 | mix_lambda: torch.tensor [batch], the mixup lambda 50 | Returns: 51 | class_prob: torch.tensor [batch, class_num] 52 | 53 | """ 54 | # keep the frozen CLAP model in eval mode so batchnorm statistics are not updated 55 | if self.freeze: 56 | self.clap_model.eval() 57 | 58 | x = self.clap_model.audio_projection( 
self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"]) 60 | out = self.lp_layer(x) 61 | if self.act is not None: 62 | out = self.act(out) 63 | return out 64 | -------------------------------------------------------------------------------- /experiment_scripts/htsat-roberta-large-dataset-fusion.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --comment clap 3 | #SBATCH --partition=g40423 4 | #SBATCH --job-name=mclap 5 | #SBATCH --nodes 3 6 | #SBATCH --ntasks-per-node 8 7 | #SBATCH --cpus-per-gpu=6 8 | #SBATCH --exclusive 9 | #SBATCH --output=%x_%j.out 10 | 11 | module load openmpi 12 | module load cuda/11.7 13 | export NCCL_PROTO=simple 14 | export FI_EFA_FORK_SAFE=1 15 | export FI_LOG_LEVEL=1 16 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 17 | export NCCL_DEBUG=info 18 | export OMPI_MCA_mtl_base_verbose=1 19 | export FI_EFA_ENABLE_SHM_TRANSFER=0 20 | export FI_PROVIDER=efa 21 | export FI_EFA_TX_MIN_CREDITS=64 22 | export NCCL_TREE_THRESHOLD=0 23 | 24 | # sent to sub script 25 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 26 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 27 | export MASTER_PORT=12802 28 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 29 | 30 | echo go $COUNT_NODE 31 | echo $HOSTNAMES 32 | 33 | source /fsx/yusong/clap/bin/activate 34 | cd /fsx/yusong/CLAP/src 35 | export TRANSFORMERS_CACHE=/fsx/yusong/transformers_cache 36 | 37 | srun --comment clap --cpu_bind=v --accel-bind=gn python -m training.main \ 38 | --save-frequency 5 \ 39 | --save-top-performance 3 \ 40 | --save-most-recent \ 41 | --dataset-type="webdataset" \ 42 | --precision="fp32" \ 43 | --batch-size=96 \ 44 | --lr=1e-4 \ 45 | --wd=0.0 \ 46 | --epochs=45 \ 47 | --workers=6 \ 48 | --use-bn-sync \ 49 | --amodel HTSAT-tiny \ 50 | --tmodel roberta \ 51 | --warmup 3200 \ 52 | --report-to "wandb" \ 53 | --wandb-notes "10.16-clap-dataset-2#-htsat-roberta-fusion" \ 54 | --datasetnames "Clotho" "audiocaps" "BBCSoundEffects" "free_to_use_sounds" "paramount_motion" "sonniss_game_effects" "wesoundeffects" "freesound_no_overlap_noesc50" "audiostock" "epidemic_sound_effects" "fsd50k_class_label" "MACS" "WavText5K" \ 55 | --full-train-dataset "BBCSoundEffects" "free_to_use_sounds" "paramount_motion" "sonniss_game_effects" "wesoundeffects" "audiostock" "epidemic_sound_effects" "fsd50k_class_label" \ 56 | --exclude-eval-dataset "freesound_no_overlap_noesc50" "MACS" "WavText5K" "fsd50k_class_label" \ 57 | --datasetinfos "train" "unbalanced_train" \ 58 | --top-k-checkpoint-select-dataset="Clotho-test" \ 59 | --top-k-checkpoint-select-metric="mAP@10" \ 60 | --openai-model-cache-dir /fsx/yusong/transformers_cache \ 61 | --logs /fsx/clap_logs \ 62 | --seed 3407 \ 63 | --remotedata \ 64 | --gather-with-grad \ 65 | --optimizer "adam" \ 66 | --data-filling "repeatpad" \ 67 | --data-truncating "fusion" \ 68 | --enable-fusion \ 69 | --fusion-type "aff_2d" \ 70 | --pretrained-audio /fsx/yusong/audio_pretrained_model/HTSAT-fullset-imagenet-map=0.467.ckpt 71 | -------------------------------------------------------------------------------- /src/laion_clap/unit_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contrastive Language-Audio Pretraining Model from LAION 3 | -------------------------------------------------------- 4 | Paper: https://arxiv.org/abs/2211.06687 5 | Authors (equal contributions): Ke Chen, 
Yusong Wu, Tianyu Zhang, Yuchen Hui 6 | Support: LAION 7 | """ 8 | 9 | import numpy as np 10 | import librosa 11 | import torch 12 | import laion_clap 13 | 14 | # quantization 15 | def int16_to_float32(x): 16 | return (x / 32767.0).astype(np.float32) 17 | 18 | 19 | def float32_to_int16(x): 20 | x = np.clip(x, a_min=-1., a_max=1.) 21 | return (x * 32767.).astype(np.int16) 22 | 23 | model = laion_clap.CLAP_Module(enable_fusion=False) 24 | model.load_ckpt() 25 | 26 | # Directly get audio embeddings from audio files 27 | audio_file = [ 28 | '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', 29 | '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_long.wav' 30 | ] 31 | audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=False) 32 | print(audio_embed[:,-20:]) 33 | print(audio_embed.shape) 34 | 35 | # Get audio embeddings from audio data 36 | audio_data, _ = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) # sample rate should be 48000 37 | audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T) 38 | audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=False) 39 | print(audio_embed[:,-20:]) 40 | print(audio_embed.shape) 41 | 42 | # Directly get audio embeddings from audio files, but return torch tensor 43 | audio_file = [ 44 | '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', 45 | '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_long.wav' 46 | ] 47 | audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=True) 48 | print(audio_embed[:,-20:]) 49 | print(audio_embed.shape) 50 | 51 | # Get audio embeddings from audio data 52 | audio_data, _ = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) # sample rate should be 48000 53 | audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T) 54 | audio_data = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() # quantize before sending it to the model 55 | audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=True) 56 | print(audio_embed[:,-20:]) 57 | print(audio_embed.shape) 58 | 59 | # Get text embeddings from texts: 60 | text_data = ["I love the contrastive learning", "I love the pretrain model"] 61 | text_embed = model.get_text_embedding(text_data) 62 | print(text_embed) 63 | print(text_embed.shape) 64 | 65 | # Get text embeddings from texts, but return torch tensor: 66 | text_data = ["I love the contrastive learning", "I love the pretrain model"] 67 | text_embed = model.get_text_embedding(text_data, use_tensor=True) 68 | print(text_embed) 69 | print(text_embed.shape) 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | wandb/ 3 | models/ 4 | features/ 5 | results/ 6 | src/logs/ 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject 
date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | assets/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | sync.sh 137 | gpu1sync.sh 138 | .idea 139 | *.pdf 140 | **/._* 141 | **/*DS_* 142 | **.jsonl 143 | src/sbatch 144 | src/misc 145 | .vscode 146 | src/debug 147 | core.* 148 | 149 | # Allow 150 | !src/evaluation/misc/results_dbs/* 151 | src/test.py 152 | src/dev.ipynb 153 | src/sizes.json 154 | *.pt 155 | dev.ipynb 156 | src/json_files/* 157 | .vs/* 158 | run.sh 159 | Untitled.ipynb 160 | dev.ipynb 161 | src/dev.py 162 | ESC.csv 163 | .gitignore 164 | UrbanSound8K.csv 165 | -------------------------------------------------------------------------------- /src/laion_clap/training/infer_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | from clap_module import create_model 4 | from training.data import get_audio_features 5 | from training.data import int16_to_float32, float32_to_int16 6 | from transformers import RobertaTokenizer 7 | 8 | tokenize = RobertaTokenizer.from_pretrained('roberta-base') 9 | def tokenizer(text): 10 | result = tokenize( 11 | text, 12 | padding="max_length", 13 | truncation=True, 14 | max_length=77, 15 | return_tensors="pt", 16 | ) 17 | return {k: v.squeeze(0) for k, v in result.items()} 18 | 19 | def infer_text(): 20 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 21 | precision = 'fp32' 22 | amodel = 'HTSAT-tiny' # or 'PANN-14' 23 | tmodel = 'roberta' # the best text encoder in our training 24 | enable_fusion = True # False if you do not want to use the fusion model 25 | fusion_type = 'aff_2d' 26 | pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded. 
27 | 28 | model, model_cfg = create_model( 29 | amodel, 30 | tmodel, 31 | pretrained, 32 | precision=precision, 33 | device=device, 34 | enable_fusion=enable_fusion, 35 | fusion_type=fusion_type 36 | ) 37 | # load the text, can be a list (i.e. batch size) 38 | text_data = ["I love the contrastive learning", "I love the pretrain model"] 39 | # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 40 | text_data = tokenizer(text_data) 41 | model.eval() 42 | text_embed = model.get_text_embedding(text_data) 43 | text_embed = text_embed.detach().cpu().numpy() 44 | print(text_embed) 45 | print(text_embed.shape) 46 | 47 | def infer_audio(): 48 | 49 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 50 | precision = 'fp32' 51 | amodel = 'HTSAT-tiny' # or 'PANN-14' 52 | tmodel = 'roberta' # the best text encoder in our training 53 | enable_fusion = True # False if you do not want to use the fusion model 54 | fusion_type = 'aff_2d' 55 | pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded. 56 | 57 | model, model_cfg = create_model( 58 | amodel, 59 | tmodel, 60 | pretrained, 61 | precision=precision, 62 | device=device, 63 | enable_fusion=enable_fusion, 64 | fusion_type=fusion_type 65 | ) 66 | 67 | # load the waveform of the shape (T,), should resample to 48000 68 | audio_waveform, sr = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) 69 | # quantize 70 | audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) 71 | audio_waveform = torch.from_numpy(audio_waveform).float() 72 | audio_dict = {} 73 | 74 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 75 | audio_dict = get_audio_features( 76 | audio_dict, audio_waveform, 480000, 77 | data_truncating='fusion', 78 | data_filling='repeatpad', 79 | audio_cfg=model_cfg['audio_cfg'] 80 | ) 81 | model.eval() 82 | # can send a list to the model, to process many audio tracks in one time (i.e. batch size) 83 | audio_embed = model.get_audio_embedding([audio_dict]) 84 | audio_embed = audio_embed.detach().cpu().numpy() 85 | print(audio_embed) 86 | print(audio_embed.shape) 87 | 88 | 89 | 90 | if __name__ == "__main__": 91 | infer_text() 92 | # infer_audio() 93 | -------------------------------------------------------------------------------- /src/laion_clap/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | # NOTE: This script is currently not supported for CLAP. 
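# It evaluates ImageNet zero-shot classification with an image encoder (model.encode_image),
# which the audio-text CLAP model does not provide.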
2 | import logging 3 | from contextlib import suppress 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | from clap_module import tokenize 10 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template 11 | 12 | 13 | def zero_shot_classifier(model, classnames, templates, args): 14 | with torch.no_grad(): 15 | zeroshot_weights = [] 16 | for classname in tqdm(classnames): 17 | texts = [template(classname) for template in templates] # format with class 18 | texts = tokenize(texts).to(args.device) # tokenize 19 | if args.distributed and not args.horovod: 20 | class_embeddings = model.module.encode_text(texts) 21 | else: 22 | class_embeddings = model.encode_text(texts) 23 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 24 | class_embedding /= class_embedding.norm() 25 | zeroshot_weights.append(class_embedding) 26 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) 27 | return zeroshot_weights 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | pred = output.topk(max(topk), 1, True, True)[1].t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 34 | 35 | 36 | def run(model, classifier, dataloader, args): 37 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress 38 | with torch.no_grad(): 39 | top1, top5, n = 0., 0., 0. 40 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 41 | images = images.to(args.device) 42 | target = target.to(args.device) 43 | 44 | with autocast(): 45 | # predict 46 | if args.distributed and not args.horovod: 47 | image_features = model.module.encode_image(images) 48 | else: 49 | image_features = model.encode_image(images) 50 | image_features = F.normalize(image_features, dim=-1) 51 | logits = 100. 
* image_features @ classifier 52 | 53 | # measure accuracy 54 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 55 | top1 += acc1 56 | top5 += acc5 57 | n += images.size(0) 58 | 59 | top1 = (top1 / n) 60 | top5 = (top5 / n) 61 | return top1, top5 62 | 63 | 64 | def zero_shot_eval(model, data, epoch, args): 65 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 66 | return {} 67 | if args.zeroshot_frequency == 0: 68 | return {} 69 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 70 | return {} 71 | 72 | logging.info('Starting zero-shot imagenet.') 73 | 74 | logging.info('Building zero-shot classifier') 75 | classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) 76 | 77 | logging.info('Using classifier') 78 | results = {} 79 | if 'imagenet-val' in data: 80 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 81 | results['imagenet-zeroshot-val-top1'] = top1 82 | results['imagenet-zeroshot-val-top5'] = top5 83 | if 'imagenet-v2' in data: 84 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 85 | results['imagenetv2-zeroshot-val-top1'] = top1 86 | results['imagenetv2-zeroshot-val-top5'] = top5 87 | 88 | logging.info('Finished zero-shot imagenet.') 89 | 90 | return results 91 | -------------------------------------------------------------------------------- /class_labels/FSD50k_class_labels_indices.json: -------------------------------------------------------------------------------- 1 | {"Whispering": 0, "Gunshot, gunfire": 1, "Pour": 2, "Wind chime": 3, "Livestock, farm animals, working animals": 4, "Crackle": 5, "Waves, surf": 6, "Chicken, rooster": 7, "Chatter": 8, "Keyboard (musical)": 9, "Bark": 10, "Rail transport": 11, "Gong": 12, "Shatter": 13, "Ratchet, pawl": 14, "Clapping": 15, "Mallet percussion": 16, "Whoosh, swoosh, swish": 17, "Speech synthesizer": 18, "Respiratory sounds": 19, "Sliding door": 20, "Boat, Water vehicle": 21, "Boiling": 22, "Human voice": 23, "Drip": 24, "Thunderstorm": 25, "Male singing": 26, "Sneeze": 27, "Hi-hat": 28, "Guitar": 29, "Crying, sobbing": 30, "Speech": 31, "Slam": 32, "Crack": 33, "Yell": 34, "Drawer open or close": 35, "Run": 36, "Cheering": 37, "Splash, splatter": 38, "Tabla": 39, "Sigh": 40, "Packing tape, duct tape": 41, "Raindrop": 42, "Cymbal": 43, "Fill (with liquid)": 44, "Harp": 45, "Squeak": 46, "Zipper (clothing)": 47, "Tearing": 48, "Alarm": 49, "Skateboard": 50, "Wind instrument, woodwind instrument": 51, "Chink, clink": 52, "Wind": 53, "Ringtone": 54, "Microwave oven": 55, "Power tool": 56, "Dishes, pots, and pans": 57, "Musical instrument": 58, "Door": 59, "Domestic sounds, home sounds": 60, "Subway, metro, underground": 61, "Glockenspiel": 62, "Female speech, woman speaking": 63, "Coin (dropping)": 64, "Mechanical fan": 65, "Male speech, man speaking": 66, "Crowd": 67, "Screech": 68, "Animal": 69, "Human group actions": 70, "Telephone": 71, "Tools": 72, "Giggle": 73, "Crushing": 74, "Thump, thud": 75, "Hammer": 76, "Engine": 77, "Cupboard open or close": 78, "Glass": 79, "Writing": 80, "Clock": 81, "Plucked string instrument": 82, "Fowl": 83, "Water tap, faucet": 84, "Knock": 85, "Trickle, dribble": 86, "Rattle": 87, "Conversation": 88, "Accelerating, revving, vroom": 89, "Fixed-wing aircraft, airplane": 90, "Screaming": 91, "Walk, footsteps": 92, "Stream": 93, "Printer": 94, "Traffic noise, roadway noise": 95, "Motorcycle": 96, "Water": 97, "Scratching (performance technique)": 98, "Tap": 99, "Percussion": 100, 
"Chuckle, chortle": 101, "Motor vehicle (road)": 102, "Crow": 103, "Vehicle horn, car horn, honking": 104, "Bird vocalization, bird call, bird song": 105, "Drill": 106, "Race car, auto racing": 107, "Meow": 108, "Bass drum": 109, "Drum kit": 110, "Wild animals": 111, "Crash cymbal": 112, "Cough": 113, "Typing": 114, "Bowed string instrument": 115, "Computer keyboard": 116, "Vehicle": 117, "Train": 118, "Applause": 119, "Bicycle": 120, "Tick": 121, "Drum": 122, "Burping, eructation": 123, "Bicycle bell": 124, "Cowbell": 125, "Accordion": 126, "Toilet flush": 127, "Purr": 128, "Church bell": 129, "Cat": 130, "Insect": 131, "Engine starting": 132, "Chewing, mastication": 133, "Sink (filling or washing)": 134, "Dog": 135, "Bird": 136, "Finger snapping": 137, "Child speech, kid speaking": 138, "Wood": 139, "Music": 140, "Sawing": 141, "Bell": 142, "Fireworks": 143, "Crumpling, crinkling": 144, "Ocean": 145, "Gurgling": 146, "Fart": 147, "Mechanisms": 148, "Acoustic guitar": 149, "Singing": 150, "Boom": 151, "Bus": 152, "Cutlery, silverware": 153, "Liquid": 154, "Explosion": 155, "Gull, seagull": 156, "Thunder": 157, "Siren": 158, "Marimba, xylophone": 159, "Female singing": 160, "Tick-tock": 161, "Frog": 162, "Frying (food)": 163, "Buzz": 164, "Car passing by": 165, "Electric guitar": 166, "Gasp": 167, "Rattle (instrument)": 168, "Piano": 169, "Doorbell": 170, "Chime": 171, "Car": 172, "Fire": 173, "Trumpet": 174, "Truck": 175, "Hands": 176, "Domestic animals, pets": 177, "Chirp, tweet": 178, "Breathing": 179, "Cricket": 180, "Tambourine": 181, "Bass guitar": 182, "Idling": 183, "Scissors": 184, "Rain": 185, "Strum": 186, "Shout": 187, "Keys jangling": 188, "Camera": 189, "Hiss": 190, "Growling": 191, "Snare drum": 192, "Brass instrument": 193, "Bathtub (filling or washing)": 194, "Typewriter": 195, "Aircraft": 196, "Organ": 197, "Laughter": 198, "Harmonica": 199} -------------------------------------------------------------------------------- /src/tests/check_tars.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | import soundfile as sf 3 | import io 4 | import os 5 | import random 6 | import copy 7 | from tqdm import tqdm 8 | import shutil 9 | import argparse 10 | import traceback 11 | import logging 12 | import json 13 | from laion_clap import tokenize 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--tar-path", 20 | type=str, 21 | default=None, 22 | help="Path to the tars", 23 | ) 24 | parser.add_argument( 25 | "--start", 26 | type=int, 27 | default=0, 28 | help="start from tar-path + start", 29 | ) 30 | parser.add_argument( 31 | "--end", 32 | type=int, 33 | default=99999, 34 | help="end with tar-path + end", 35 | ) 36 | parser.add_argument( 37 | "--exclude", 38 | nargs='+', 39 | default=None, 40 | help="exclude tar-path + exclude", 41 | ) 42 | parser.add_argument( 43 | "--batch-size", 44 | type=int, 45 | default=1, 46 | ) 47 | parser.add_argument( 48 | "--order", 49 | default=False, 50 | action='store_true', 51 | help="if keep the search order accendingly", 52 | ) 53 | args = parser.parse_args() 54 | return args 55 | 56 | def log_and_continue(exn): 57 | """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" 58 | logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") 59 | return True 60 | 61 | def preprocess( 62 | sample, 63 | ): 64 | """ 65 | Preprocess a single sample for wdsdataloader. 
66 | """ 67 | audio_ext = "flac" 68 | text_ext = "json" 69 | audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) 70 | json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) 71 | sample["waveform"] = audio_data 72 | texts = json_dict_raw["text"] 73 | if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: 74 | texts = random.choice(texts) 75 | sample["raw_text"] = texts 76 | sample["text"] = tokenize(texts) 77 | return sample 78 | 79 | if __name__ == "__main__": 80 | args = parse_args() 81 | tar_path = args.tar_path 82 | idx_list = list(range(args.start, args.end)) 83 | if args.exclude != None: 84 | for x in args.exclude: 85 | idx_list.remove(x) 86 | if not args.order: 87 | random.shuffle(idx_list) 88 | if "aws" in tar_path: 89 | args.local = False 90 | if args.local: 91 | input_shards = [os.path.join(args.tar_path, str(i)+".tar") for i in idx_list] 92 | else: 93 | input_shards = [os.path.join(args.tar_path, str(i)+".tar -") for i in idx_list] 94 | pipeline = [wds.SimpleShardList(input_shards)] 95 | pipeline.extend( 96 | [ 97 | wds.split_by_node, 98 | wds.split_by_worker, 99 | wds.tarfile_to_samples(handler=log_and_continue), 100 | wds.map(preprocess), 101 | wds.to_tuple("__url__", "__key__", "waveform"), 102 | wds.batched(1), 103 | ] 104 | ) 105 | dataset = wds.DataPipeline(*pipeline) 106 | dataloader = wds.WebLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0) 107 | old_k = 0 108 | old_batch = None 109 | try: 110 | for k, batch in tqdm(enumerate(dataloader)): 111 | print("k:", k) 112 | print("batch:", batch) 113 | old_k = k 114 | old_batch = copy.deepcopy(batch) 115 | except: 116 | with open("check_tar_log.txt","a") as file: 117 | traceback.print_exc(file = file) 118 | print("old_k:", old_k) 119 | print("old_batch:", old_batch) 120 | pass 121 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.' 
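        # Assemble the projection head on top of the pooled trunk features:
        # either dropout + a single linear layer, or a small MLP, projecting to embed_dim.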
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | model_cfg, 26 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 27 | jit=True, 28 | cache_dir=os.path.expanduser("~/.cache/clip"), 29 | enable_fusion: bool = False, 30 | fusion_type: str = 'None' 31 | ): 32 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 33 | 34 | Parameters 35 | ---------- 36 | name : str 37 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
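    model_cfg : dict
        The CLAP model config used when the model is rebuilt from a state dict
    cache_dir : str
        Directory where a downloaded checkpoint is cached
    enable_fusion : bool
        Whether to build the CLAP model with feature fusion enabled
    fusion_type : str
        The fusion type passed through to the model builder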
42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLAP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if get_pretrained_url(name, 'openai'): 51 | model_path = download_pretrained(get_pretrained_url(name, 'openai'), root=cache_dir) 52 | elif os.path.isfile(name): 53 | model_path = name 54 | else: 55 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 56 | 57 | try: 58 | # loading JIT archive 59 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 60 | state_dict = None 61 | except RuntimeError: 62 | # loading saved state dict 63 | if jit: 64 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 65 | jit = False 66 | state_dict = torch.load(model_path, map_location="cpu") 67 | 68 | if not jit: 69 | try: 70 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type).to(device) 71 | except KeyError: 72 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 73 | model = build_model_from_openai_state_dict(sd, model_cfg, enable_fusion, fusion_type).to(device) 74 | 75 | if str(device) == "cpu": 76 | model.float() 77 | return model 78 | 79 | # patch the device names 80 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 81 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 82 | 83 | def patch_device(module): 84 | try: 85 | graphs = [module.graph] if hasattr(module, "graph") else [] 86 | except RuntimeError: 87 | graphs = [] 88 | 89 | if hasattr(module, "forward1"): 90 | graphs.append(module.forward1.graph) 91 | 92 | for graph in graphs: 93 | for node in graph.findAllNodes("prim::Constant"): 94 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 95 | node.copyAttributes(device_node) 96 | 97 | model.apply(patch_device) 98 | patch_device(model.encode_audio) 99 | patch_device(model.encode_text) 100 | 101 | # patch dtype to float32 on CPU 102 | if str(device) == "cpu": 103 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 104 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 105 | float_node = float_input.node() 106 | 107 | def patch_float(module): 108 | try: 109 | graphs = [module.graph] if hasattr(module, "graph") else [] 110 | except RuntimeError: 111 | graphs = [] 112 | 113 | if hasattr(module, "forward1"): 114 | graphs.append(module.forward1.graph) 115 | 116 | for graph in graphs: 117 | for node in graph.findAllNodes("aten::to"): 118 | inputs = list(node.inputs()) 119 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 120 | if inputs[i].node()["value"] == 5: 121 | inputs[i].node().copyAttributes(float_node) 122 | 123 | model.apply(patch_float) 124 | patch_float(model.encode_audio) 125 | patch_float(model.encode_text) 126 | model.float() 127 | 128 | model.audio_branch.audio_length = model.audio_cfg.audio_length 129 | return model 130 | -------------------------------------------------------------------------------- /src/laion_clap/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import socket 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 
| hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 30 | return True 31 | else: 32 | return False 33 | 34 | 35 | def is_using_distributed(): 36 | if 'WORLD_SIZE' in os.environ: 37 | return int(os.environ['WORLD_SIZE']) > 1 38 | if 'SLURM_NTASKS' in os.environ: 39 | return int(os.environ['SLURM_NTASKS']) > 1 40 | return False 41 | 42 | 43 | def world_info_from_env(): 44 | local_rank = 0 45 | for v in ('SLURM_LOCALID', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'LOCAL_RANK'): 46 | if v in os.environ: 47 | local_rank = int(os.environ[v]) 48 | break 49 | global_rank = 0 50 | for v in ('SLURM_PROCID', 'PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'RANK'): 51 | if v in os.environ: 52 | global_rank = int(os.environ[v]) 53 | break 54 | world_size = 1 55 | for v in ('SLURM_NTASKS', 'PMI_SIZE', 'OMPI_COMM_WORLD_SIZE', 'WORLD_SIZE'): 56 | if v in os.environ: 57 | world_size = int(os.environ[v]) 58 | break 59 | 60 | return local_rank, global_rank, world_size 61 | 62 | 63 | def init_distributed_device(args): 64 | # Distributed training = training on more than one GPU. 65 | # Works in both single and multi-node scenarios. 
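    # Fills in args.distributed, args.rank, args.local_rank, args.world_size and args.device,
    # covering the Horovod, SLURM, Open MPI and torchrun/torch.distributed.launch code paths below.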
66 | args.distributed = False 67 | args.world_size = 1 68 | args.rank = 0 # global rank 69 | args.local_rank = 0 70 | if args.horovod: 71 | assert hvd is not None, "Horovod is not installed" 72 | hvd.init() 73 | world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 74 | world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 75 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 76 | args.local_rank = local_rank 77 | args.rank = world_rank 78 | args.world_size = world_size 79 | # args.local_rank = int(hvd.local_rank()) 80 | # args.rank = hvd.rank() 81 | # args.world_size = hvd.size() 82 | args.distributed = True 83 | os.environ['LOCAL_RANK'] = str(args.local_rank) 84 | os.environ['RANK'] = str(args.rank) 85 | os.environ['WORLD_SIZE'] = str(args.world_size) 86 | print(f"Distributed training: local_rank={args.local_rank}, " 87 | f"rank={args.rank}, world_size={args.world_size}, " 88 | f"hostname={socket.gethostname()}, pid={os.getpid()}") 89 | elif is_using_distributed(): 90 | if 'SLURM_PROCID' in os.environ: 91 | # DDP via SLURM 92 | args.local_rank, args.rank, args.world_size = world_info_from_env() 93 | # SLURM var -> torch.distributed vars in case needed 94 | os.environ['LOCAL_RANK'] = str(args.local_rank) 95 | os.environ['RANK'] = str(args.rank) 96 | os.environ['WORLD_SIZE'] = str(args.world_size) 97 | torch.distributed.init_process_group( 98 | backend=args.dist_backend, 99 | init_method=args.dist_url, 100 | world_size=args.world_size, 101 | rank=args.rank, 102 | ) 103 | elif 'OMPI_COMM_WORLD_SIZE' in os.environ: # using Summit cluster 104 | world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 105 | world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 106 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 107 | args.local_rank = local_rank 108 | args.rank = world_rank 109 | args.world_size = world_size 110 | torch.distributed.init_process_group( 111 | backend=args.dist_backend, 112 | init_method=args.dist_url, 113 | world_size=args.world_size, 114 | rank=args.rank, 115 | ) 116 | else: 117 | # DDP via torchrun, torch.distributed.launch 118 | args.local_rank, _, _ = world_info_from_env() 119 | torch.distributed.init_process_group( 120 | backend=args.dist_backend, 121 | init_method=args.dist_url) 122 | args.world_size = torch.distributed.get_world_size() 123 | args.rank = torch.distributed.get_rank() 124 | args.distributed = True 125 | print(f"Distributed training: local_rank={args.local_rank}, " 126 | f"rank={args.rank}, world_size={args.world_size}, " 127 | f"hostname={socket.gethostname()}, pid={os.getpid()}") 128 | 129 | if torch.cuda.is_available(): 130 | if args.distributed and not args.no_set_device_rank: 131 | device = 'cuda:%d' % args.local_rank 132 | else: 133 | device = 'cuda:0' 134 | torch.cuda.set_device(device) 135 | else: 136 | device = 'cpu' 137 | args.device = device 138 | device = torch.device(device) 139 | return device 140 | -------------------------------------------------------------------------------- /src/laion_clap/evaluate/eval_dcase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.backends.cudnn as cudnn 4 | from open_clip import create_model 5 | from open_clip import tokenize 6 | import glob 7 | import json 8 | import librosa 9 | from tqdm import tqdm 10 | import numpy as np 11 | import os 12 | from laion_clap.training.params import parse_args 13 | 14 | 15 | def get_output_from_single_audio(audio, text, model, device): 16 | 17 | # 
audio_embedding = model.audio_infer(audio, hopsize=5 * 48000, key="embedding", device=device)['embedding'] 18 | # if audio_embedding.ndim > 1: 19 | # audio_embedding = audio_embedding.mean(dim=0, keepdim=True) 20 | # else: 21 | # audio_embedding = audio_embedding.unsqueeze(0) 22 | audio_features = model(audio, None, device) 23 | audio_features = F.normalize(audio_features, dim=-1) 24 | text_features = model(None, text, device=device) 25 | text_features = F.normalize(text_features, dim=-1) 26 | 27 | # CHANGE: before normalize or after 28 | audio_features_mlp = model.audio_transform(audio_features) 29 | text_features_mlp = model.text_transform(text_features) 30 | return audio_features, text_features, audio_features_mlp, text_features_mlp, model.logit_scale_a.exp(), model.logit_scale_t.exp() 31 | 32 | 33 | def get_metrics(text_to_audio_logits): 34 | metrics = {} 35 | 36 | # repeat ground truth 5 times because Clotho has 5 text for 1 audio 37 | ground_truth = torch.repeat_interleave(torch.arange(len(text_features) // 5), 5).view(-1, 1) 38 | 39 | ranking = torch.argsort(text_to_audio_logits, descending=True) 40 | preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread 41 | preds = preds.detach().cpu().numpy() 42 | metrics[f"mean_rank"] = preds.mean() + 1 43 | metrics[f"median_rank"] = np.floor(np.median(preds)) + 1 44 | for k in [1, 5, 10]: 45 | metrics[f"R@{k}"] = np.mean(preds < k) 46 | # map@10 47 | metrics[f"mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) 48 | return metrics 49 | 50 | 51 | if __name__ == '__main__': 52 | args = parse_args() 53 | 54 | model_path = args.pretrained 55 | 56 | clotho_test_preprocessed_dir = "/fsx/yusong/clotho_test_set/test" 57 | 58 | cudnn.benchmark = True 59 | cudnn.deterministic = False 60 | 61 | audio_features_ensemble_all = [] 62 | text_features_ensemble_all = [] 63 | audio_features_mlp_ensemble_all = [] 64 | text_features_mlp_ensemble_all = [] 65 | logit_scale_a_ensemble_all = [] 66 | logit_scale_t_ensemble_all = [] 67 | 68 | 69 | device = torch.device('cuda') 70 | model, clap_model_cfg = create_model( 71 | args.amodel, 72 | args.tmodel, 73 | args.pretrained, 74 | precision=args.precision, 75 | device=device, 76 | jit=args.torchscript, 77 | force_quick_gelu=args.force_quick_gelu, 78 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 79 | skip_params=False 80 | ) 81 | 82 | # load model 83 | checkpoint = torch.load(model_path, map_location=device) 84 | if "epoch" in checkpoint: 85 | # resuming a train checkpoint w/ epoch and optimizer state 86 | start_epoch = checkpoint["epoch"] 87 | sd = checkpoint["state_dict"] 88 | if next(iter(sd.items()))[0].startswith( 89 | "module" 90 | ): 91 | sd = {k[len("module."):]: v for k, v in sd.items()} 92 | model.load_state_dict(sd) 93 | else: 94 | # loading a bare (model only) checkpoint for fine-tune or evaluation 95 | model.load_state_dict(checkpoint) 96 | 97 | model.to(device) 98 | model.eval() 99 | for param in model.parameters(): 100 | param.requires_grad = False 101 | 102 | # take every 5th file because clotho has 5 texts for 1 audio 103 | test_file_list = sorted(glob.glob(f"{clotho_test_preprocessed_dir}/*.flac")) 104 | 105 | audio_features_all = [] 106 | text_features_all = [] 107 | audio_features_mlp_all = [] 108 | text_features_mlp_all = [] 109 | logit_scale_a_all = [] 110 | logit_scale_t_all = [] 111 | 112 | with torch.no_grad(): 113 | for file_path in tqdm(test_file_list): 114 | json_path = file_path.replace(".flac", 
".json") 115 | with open(json_path, "r") as f: 116 | json_data = json.load(f) 117 | audio, sr = librosa.load(file_path, sr=48000, mono=True) 118 | audio = torch.from_numpy(audio).to(device) 119 | audio = {'waveform': audio.unsqueeze(0), 'sample_rate': sr} 120 | text = json_data["text"] 121 | 122 | if args.tmodel == "transformer": 123 | from open_clip import tokenize 124 | text = tokenize(text) 125 | else: 126 | from laion_clap.training.data import tokenizer 127 | text = tokenizer(text, tmodel=args.tmodel) # 5 texts for each audio 128 | 129 | audio_features, text_features, audio_features_mlp, text_features_mlp, logit_scale_a, logit_scale_t = \ 130 | get_output_from_single_audio(audio, text, model, device) 131 | 132 | audio_features_all.append(audio_features.detach().cpu()) 133 | text_features_all.append(text_features.detach().cpu()) 134 | audio_features_mlp_all.append(audio_features_mlp.detach().cpu()) 135 | text_features_mlp_all.append(text_features_mlp.detach().cpu()) 136 | logit_scale_a_all.append(logit_scale_a.detach().cpu()) 137 | logit_scale_t_all.append(logit_scale_t.detach().cpu()) 138 | 139 | audio_features = torch.cat(audio_features_all) 140 | text_features = torch.cat(text_features_all) 141 | logit_scale_a = logit_scale_a_all[0] 142 | 143 | logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu() 144 | logits_per_text = logits_per_audio.t().detach().cpu() 145 | 146 | metrics = get_metrics( 147 | logits_per_text 148 | ) 149 | 150 | print(metrics) 151 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 18 | ) 19 | 20 | _RN101 = dict( 21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | 
openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 46 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 54 | ) 55 | 56 | _VITB16 = dict( 57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 58 | ) 59 | 60 | _VITL14 = dict( 61 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 62 | ) 63 | 64 | _PRETRAINED = { 65 | "RN50": _RN50, 66 | "RN50-quickgelu": _RN50_quickgelu, 67 | "RN101": _RN101, 68 | "RN101-quickgelu": _RN101_quickgelu, 69 | "RN50x4": _RN50x4, 70 | "RN50x16": _RN50x16, 71 | "ViT-B-32": _VITB32, 72 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 73 | "ViT-B-16": _VITB16, 74 | "ViT-L-14": _VITL14, 75 | } 76 | 77 | 78 | def list_pretrained(as_str: bool = False): 79 | """ returns list of pretrained models 80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 81 | """ 82 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 83 | 84 | 85 | def list_pretrained_tag_models(tag: str): 86 | """ return all models having the specified pretrain tag """ 87 | models = [] 88 | for k in _PRETRAINED.keys(): 89 | if tag in _PRETRAINED[k]: 90 | models.append(k) 91 | return models 92 | 93 | 94 | def list_pretrained_model_tags(model: str): 95 | """ return all pretrain tags for the specified model architecture """ 96 | tags = [] 97 | if model in _PRETRAINED: 98 | tags.extend(_PRETRAINED[model].keys()) 99 | return tags 100 | 101 | 102 | def get_pretrained_url(model: str, tag: str): 103 | if model not in _PRETRAINED: 104 | return '' 105 | model_pretrained = _PRETRAINED[model] 106 | if tag not in model_pretrained: 107 | return '' 108 | return model_pretrained[tag] 109 | 110 | 111 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 112 | os.makedirs(root, exist_ok=True) 113 | filename = os.path.basename(url) 114 | 115 | if 'openaipublic' in url: 116 | expected_sha256 = url.split("/")[-2] 117 | else: 118 | expected_sha256 = '' 119 | 120 | 
download_target = os.path.join(root, filename) 121 | 122 | if os.path.exists(download_target) and not os.path.isfile(download_target): 123 | raise RuntimeError(f"{download_target} exists and is not a regular file") 124 | 125 | if os.path.isfile(download_target): 126 | if expected_sha256: 127 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 128 | return download_target 129 | else: 130 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 131 | else: 132 | return download_target 133 | 134 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 135 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 136 | while True: 137 | buffer = source.read(8192) 138 | if not buffer: 139 | break 140 | 141 | output.write(buffer) 142 | loop.update(len(buffer)) 143 | 144 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 145 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 146 | 147 | return download_target 148 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a significant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 46 | Word is represented as tuple of symbols (symbols being variable-length strings). 
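    For example, ('h', 'e', 'l', 'l', 'o') gives {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}.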
47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'</w>' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | if not special_tokens: 80 | special_tokens = ['<start_of_text>', '<end_of_text>'] 81 | else: 82 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 83 | vocab.extend(special_tokens) 84 | self.encoder = dict(zip(vocab, range(len(vocab)))) 85 | self.decoder = {v: k for k, v in self.encoder.items()} 86 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 87 | self.cache = {t:t for t in special_tokens} 88 | special = "|".join(special_tokens) 89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 90 | 91 | self.vocab_size = len(self.encoder) 92 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 93 | 94 | def bpe(self, token): 95 | if token in self.cache: 96 | return self.cache[token] 97 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 98 | pairs = get_pairs(word) 99 | 100 | if not pairs: 101 | return token+'</w>' 102 | 103 | while True: 104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 105 | if bigram not in self.bpe_ranks: 106 | break 107 | first, second = bigram 108 | new_word = [] 109 | i = 0 110 | while i < len(word): 111 | try: 112 | j = word.index(first, i) 113 | new_word.extend(word[i:j]) 114 | i = j 115 | except: 116 | new_word.extend(word[i:]) 117 | break 118 | 119 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 120 | new_word.append(first+second) 121 | i += 2 122 | else: 123 | new_word.append(word[i]) 124 | i += 1 125 | new_word = tuple(new_word) 126 | word = new_word 127 | if len(word) == 1: 128 | break 129 | else: 130 | pairs = get_pairs(word) 131 | word = ' '.join(word) 132 | self.cache[token] = word 133 | return word 134 | 135 | def encode(self, text): 136 | bpe_tokens = [] 137 | text = whitespace_clean(basic_clean(text)).lower() 138 | for token in re.findall(self.pat, text): 139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 146 | return text 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | 151 | 152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 153 | """ 154 | Returns the tokenized representation of given input string(s) 155 | 156 | Parameters 157 | ---------- 158 | texts : Union[str, List[str]] 159 | An input string or a list 
of input strings to tokenize 160 | context_length : int 161 | The context length to use; all CLIP models use 77 as the context length 162 | 163 | Returns 164 | ------- 165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 166 | """ 167 | if isinstance(texts, str): 168 | texts = [texts] 169 | 170 | sot_token = _tokenizer.encoder["<start_of_text>"] 171 | eot_token = _tokenizer.encoder["<end_of_text>"] 172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 174 | 175 | for i, tokens in enumerate(all_tokens): 176 | if len(tokens) > context_length: 177 | tokens = tokens[:context_length] # Truncate 178 | result[i, :len(tokens)] = torch.tensor(tokens) 179 | 180 | return result 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. 
the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. 
Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /class_labels/VGGSound_class_labels_indices.json: -------------------------------------------------------------------------------- 1 | {"people crowd": 0, "playing mandolin": 1, "pumping water": 2, "horse neighing": 3, "airplane flyby": 4, "playing drum kit": 5, "pheasant crowing": 6, "duck quacking": 7, "wood thrush calling": 8, "dog bow-wow": 9, "arc welding": 10, "writing on blackboard with chalk": 11, "forging swords": 12, "swimming": 13, "bee, wasp, etc. 
buzzing": 14, "child singing": 15, "mouse clicking": 16, "playing trombone": 17, "telephone bell ringing": 18, "beat boxing": 19, "cattle mooing": 20, "lions roaring": 21, "ambulance siren": 22, "gibbon howling": 23, "people sniggering": 24, "playing clarinet": 25, "playing bassoon": 26, "playing bongo": 27, "playing electric guitar": 28, "playing badminton": 29, "bull bellowing": 30, "cat caterwauling": 31, "playing sitar": 32, "whale calling": 33, "snake hissing": 34, "people burping": 35, "francolin calling": 36, "fireworks banging": 37, "driving buses": 38, "people belly laughing": 39, "chicken clucking": 40, "playing double bass": 41, "canary calling": 42, "people battle cry": 43, "male singing": 44, "horse clip-clop": 45, "baby crying": 46, "cow lowing": 47, "reversing beeps": 48, "otter growling": 49, "cheetah chirrup": 50, "people running": 51, "ice cream truck, ice cream van": 52, "playing harpsichord": 53, "heart sounds, heartbeat": 54, "pig oinking": 55, "police radio chatter": 56, "cat hissing": 57, "wind chime": 58, "elk bugling": 59, "lions growling": 60, "fly, housefly buzzing": 61, "ferret dooking": 62, "railroad car, train wagon": 63, "church bell ringing": 64, "cat meowing": 65, "wind rustling leaves": 66, "bouncing on trampoline": 67, "mouse squeaking": 68, "sheep bleating": 69, "people eating crisps": 70, "people sneezing": 71, "playing squash": 72, "footsteps on snow": 73, "people humming": 74, "tap dancing": 75, "snake rattling": 76, "elephant trumpeting": 77, "people booing": 78, "disc scratching": 79, "skidding": 80, "cupboard opening or closing": 81, "playing bagpipes": 82, "basketball bounce": 83, "chinchilla barking": 84, "parrot talking": 85, "woodpecker pecking tree": 86, "fire truck siren": 87, "slot machine": 88, "playing french horn": 89, "air conditioning noise": 90, "people finger snapping": 91, "eagle screaming": 92, "playing harmonica": 93, "playing tympani": 94, "zebra braying": 95, "hedge trimmer running": 96, "playing acoustic guitar": 97, "hair dryer drying": 98, "orchestra": 99, "playing darts": 100, "children shouting": 101, "people slurping": 102, "alligators, crocodiles hissing": 103, "mouse pattering": 104, "people marching": 105, "vehicle horn, car horn, honking": 106, "sea lion barking": 107, "people clapping": 108, "hail": 109, "fire crackling": 110, "bathroom ventilation fan running": 111, "opening or closing car doors": 112, "skiing": 113, "dog barking": 114, "race car, auto racing": 115, "subway, metro, underground": 116, "underwater bubbling": 117, "car passing by": 118, "playing tennis": 119, "warbler chirping": 120, "helicopter": 121, "driving motorcycle": 122, "train wheels squealing": 123, "baby laughter": 124, "driving snowmobile": 125, "bird squawking": 126, "cuckoo bird calling": 127, "people whistling": 128, "shot football": 129, "playing tuning fork": 130, "dog howling": 131, "playing violin, fiddle": 132, "people eating": 133, "baltimore oriole calling": 134, "playing timbales": 135, "door slamming": 136, "people shuffling": 137, "typing on typewriter": 138, "magpie calling": 139, "playing harp": 140, "playing hammond organ": 141, "people eating apple": 142, "mosquito buzzing": 143, "playing oboe": 144, "playing volleyball": 145, "using sewing machines": 146, "electric grinder grinding": 147, "cutting hair with electric trimmers": 148, "splashing water": 149, "people sobbing": 150, "female singing": 151, "wind noise": 152, "car engine knocking": 153, "black capped chickadee calling": 154, "people screaming": 155, "cat growling": 
156, "penguins braying": 157, "people coughing": 158, "metronome": 159, "train horning": 160, "goat bleating": 161, "playing tambourine": 162, "fox barking": 163, "airplane": 164, "firing cannon": 165, "thunder": 166, "smoke detector beeping": 167, "playing erhu": 168, "ice cracking": 169, "dog growling": 170, "playing saxophone": 171, "owl hooting": 172, "playing trumpet": 173, "sailing": 174, "waterfall burbling": 175, "machine gun shooting": 176, "baby babbling": 177, "playing synthesizer": 178, "donkey, ass braying": 179, "people cheering": 180, "playing shofar": 181, "playing hockey": 182, "playing banjo": 183, "cricket chirping": 184, "playing snare drum": 185, "ripping paper": 186, "child speech, kid speaking": 187, "crow cawing": 188, "sloshing water": 189, "playing zither": 190, "scuba diving": 191, "playing steelpan": 192, "goose honking": 193, "tapping guitar": 194, "spraying water": 195, "playing bass drum": 196, "printer printing": 197, "playing ukulele": 198, "ocean burbling": 199, "playing didgeridoo": 200, "sharpen knife": 201, "typing on computer keyboard": 202, "playing table tennis": 203, "rope skipping": 204, "playing marimba, xylophone": 205, "playing bugle": 206, "playing guiro": 207, "playing flute": 208, "tornado roaring": 209, "stream burbling": 210, "electric shaver, electric razor shaving": 211, "playing gong": 212, "eating with cutlery": 213, "playing piano": 214, "people giggling": 215, "chicken crowing": 216, "female speech, woman speaking": 217, "golf driving": 218, "frog croaking": 219, "people eating noodle": 220, "mynah bird singing": 221, "playing timpani": 222, "playing congas": 223, "dinosaurs bellowing": 224, "playing bass guitar": 225, "turkey gobbling": 226, "chipmunk chirping": 227, "chopping food": 228, "striking bowling": 229, "missile launch": 230, "squishing water": 231, "civil defense siren": 232, "blowtorch igniting": 233, "tractor digging": 234, "lighting firecrackers": 235, "playing theremin": 236, "train whistling": 237, "people nose blowing": 238, "car engine starting": 239, "lathe spinning": 240, "playing cello": 241, "motorboat, speedboat acceleration": 242, "playing vibraphone": 243, "playing washboard": 244, "playing cornet": 245, "pigeon, dove cooing": 246, "roller coaster running": 247, "opening or closing car electric windows": 248, "foghorn": 249, "coyote howling": 250, "hammering nails": 251, "toilet flushing": 252, "strike lighter": 253, "bird wings flapping": 254, "playing steel guitar, slide guitar": 255, "volcano explosion": 256, "people whispering": 257, "bowling impact": 258, "yodelling": 259, "firing muskets": 260, "raining": 261, "singing bowl": 262, "plastic bottle crushing": 263, "chimpanzee pant-hooting": 264, "playing electronic organ": 265, "chainsawing trees": 266, "dog baying": 267, "lawn mowing": 268, "people babbling": 269, "striking pool": 270, "eletric blender running": 271, "playing tabla": 272, "cap gun shooting": 273, "planing timber": 274, "air horn": 275, "sliding door": 276, "cell phone buzzing": 277, "sea waves": 278, "playing castanets": 279, "singing choir": 280, "people slapping": 281, "barn swallow calling": 282, "people hiccup": 283, "vacuum cleaner cleaning floors": 284, "playing lacrosse": 285, "bird chirping, tweeting": 286, "lip smacking": 287, "chopping wood": 288, "police car (siren)": 289, "running electric fan": 290, "cattle, bovinae cowbell": 291, "people gargling": 292, "opening or closing drawers": 293, "playing djembe": 294, "skateboarding": 295, "cat purring": 296, "rowboat, canoe, kayak 
rowing": 297, "engine accelerating, revving, vroom": 298, "playing glockenspiel": 299, "popping popcorn": 300, "car engine idling": 301, "alarm clock ringing": 302, "dog whimpering": 303, "playing accordion": 304, "playing cymbal": 305, "male speech, man speaking": 306, "rapping": 307, "people farting": 308} -------------------------------------------------------------------------------- /src/laion_clap/evaluate/eval_retrieval.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import glob 3 | import random 4 | import numpy as np 5 | import logging 6 | import wandb 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from laion_clap import create_model 10 | from laion_clap.training.logger import setup_logging 11 | from laion_clap.training.data import get_data 12 | from laion_clap.training.train import evaluate 13 | from laion_clap.utils import get_tar_path_from_dataset_name, dataset_split 14 | from laion_clap.training.params import parse_args 15 | 16 | 17 | def find_params_value(file, key): 18 | # find value of params in params_file 19 | with open(file, 'r') as f: 20 | for line in f: 21 | if key + ': ' in line: 22 | return line.split(': ')[1].strip() 23 | return None 24 | 25 | 26 | if __name__ == '__main__': 27 | # (yusong) repeated run might have different metric results. 28 | # This is because we randomly select crop 10s for each audio. 29 | args = parse_args() 30 | 31 | if os.path.isdir(args.pretrained): 32 | log_dir = os.path.dirname(args.pretrained) 33 | else: 34 | log_dir = os.path.dirname(os.path.dirname(args.pretrained)) 35 | 36 | args.log_level = logging.DEBUG if args.debug else logging.INFO 37 | log_path = os.path.join(log_dir, 'out.log') 38 | setup_logging(log_path, args.log_level) 39 | params_file = os.path.join(log_dir, 'params.txt') 40 | 41 | seed = 3407 42 | random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | np.random.seed(seed) 47 | 48 | cudnn.benchmark = True 49 | cudnn.deterministic = False 50 | pretrained = 'openai' 51 | amodel = find_params_value(params_file, 'amodel') 52 | tmodel = find_params_value(params_file, 'tmodel') 53 | 54 | if amodel is None or tmodel is None: 55 | raise ValueError('model type not found in params file') 56 | 57 | # set up dummy values for args 58 | args.parallel_eval = False 59 | args.rank = 0 60 | args.local_rank = 0 61 | args.world_size = 1 62 | args.val_frequency = 1 63 | args.epochs = 1 64 | args.precision = 'fp32' 65 | args.save_logs = True 66 | args.wandb = True 67 | args.class_index_dict = None 68 | 69 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 70 | args.device = device 71 | 72 | if args.remotedata: 73 | for dataset_name in args.datasetnames: 74 | for split in dataset_split[dataset_name]: 75 | if not os.path.exists(f"./json_files/{dataset_name}/{split}"): 76 | os.makedirs(f"./json_files/{dataset_name}/{split}") 77 | os.system( 78 | f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" 79 | ) 80 | 81 | if args.datasetinfos is None: 82 | args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] 83 | if args.dataset_type == "webdataset": 84 | args.train_data = get_tar_path_from_dataset_name( 85 | args.datasetnames, 86 | args.datasetinfos, 87 | islocal=not args.remotedata, 88 | proportion=args.dataset_proportion, 89 | dataset_path=args.datasetpath, 90 | ) 91 | args.val_data = 
get_tar_path_from_dataset_name( 92 | args.datasetnames, 93 | ["valid", "test", "eval"], 94 | islocal=not args.remotedata, 95 | proportion=1, 96 | dataset_path=args.datasetpath, 97 | ) 98 | model, model_cfg = create_model( 99 | amodel, 100 | tmodel, 101 | pretrained, 102 | precision='fp32', 103 | device=device, 104 | jit=False, 105 | force_quick_gelu=False, 106 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 107 | skip_params=False, 108 | enable_fusion=args.enable_fusion, 109 | fusion_type=args.fusion_type 110 | ) # a hack to get model_cfg 111 | 112 | data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data 113 | 114 | writer = None # if use tensorboard, initalize writer here 115 | 116 | if args.wandb: 117 | assert wandb is not None, "Please install wandb." 118 | 119 | # # find the line with "wandb_notes" and get the value 120 | # wandb_notes = find_params_value(params_file, 'wandb_notes') 121 | # if wandb_notes is None: 122 | # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.') 123 | # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}' 124 | # wandb_notes = wandb_notes + '-eval-retrieval' 125 | wandb_notes = args.wandb_notes 126 | 127 | logging.debug("Starting wandb.") 128 | args.train_sz = data["train"].dataloader.num_samples 129 | if args.val_data is not None: 130 | args.val_sz = data["val"].dataloader.num_samples 131 | # you will have to configure this for your project! 132 | if args.wandb_id is not None: 133 | wandb.init( 134 | project="clap", 135 | id=args.wandb_id, 136 | resume=True 137 | ) 138 | else: 139 | wandb.init( 140 | project="clap", 141 | notes=wandb_notes, 142 | name=wandb_notes, 143 | tags=[], 144 | config=vars(args), 145 | ) 146 | logging.debug("Finished loading wandb.") 147 | 148 | if os.path.isdir(args.pretrained): 149 | all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime) 150 | else: 151 | all_model_checkpoints = [args.pretrained] 152 | for model_path in all_model_checkpoints: 153 | args.checkpoint_path = os.path.dirname(model_path) 154 | model, model_cfg = create_model( 155 | amodel, 156 | tmodel, 157 | pretrained, 158 | precision='fp32', 159 | device=device, 160 | jit=False, 161 | force_quick_gelu=False, 162 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 163 | skip_params=False, 164 | enable_fusion=args.enable_fusion, 165 | fusion_type=args.fusion_type 166 | ) 167 | 168 | # load model 169 | checkpoint = torch.load(model_path, map_location=device) 170 | if "epoch" in checkpoint: 171 | # resuming a train checkpoint w/ epoch and optimizer state 172 | start_epoch = checkpoint["epoch"] 173 | sd = checkpoint["state_dict"] 174 | if next(iter(sd.items()))[0].startswith( 175 | "module" 176 | ): 177 | sd = {k[len("module."):]: v for k, v in sd.items()} 178 | model.load_state_dict(sd) 179 | logging.info( 180 | f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})" 181 | ) 182 | else: 183 | # loading a bare (model only) checkpoint for fine-tune or evaluation 184 | model.load_state_dict(checkpoint) 185 | start_epoch = 0 186 | 187 | model.to(device) 188 | model.eval() 189 | for param in model.parameters(): 190 | param.requires_grad = False 191 | 192 | evaluate(model, data, start_epoch, args, writer) 193 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/feature_fusion.py: 
--------------------------------------------------------------------------------
1 | '''
2 | Feature Fusion for Variable-Length Data Processing
3 | AFF/iAFF is adapted and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5 | '''
6 | 
7 | import torch
8 | import torch.nn as nn
9 | 
10 | 
11 | class DAF(nn.Module):
12 |     '''
13 |     DirectAddFuse: direct element-wise addition
14 |     '''
15 | 
16 |     def __init__(self):
17 |         super(DAF, self).__init__()
18 | 
19 |     def forward(self, x, residual):
20 |         return x + residual
21 | 
22 | 
23 | class iAFF(nn.Module):
24 |     '''
25 |     Multi-feature fusion: iAFF (iterative attentional feature fusion)
26 |     '''
27 | 
28 |     def __init__(self, channels=64, r=4, type='2D'):
29 |         super(iAFF, self).__init__()
30 |         inter_channels = int(channels // r)
31 | 
32 |         if type == '1D':
33 |             # local attention
34 |             self.local_att = nn.Sequential(
35 |                 nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36 |                 nn.BatchNorm1d(inter_channels),
37 |                 nn.ReLU(inplace=True),
38 |                 nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39 |                 nn.BatchNorm1d(channels),
40 |             )
41 | 
42 |             # global attention
43 |             self.global_att = nn.Sequential(
44 |                 nn.AdaptiveAvgPool1d(1),
45 |                 nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46 |                 nn.BatchNorm1d(inter_channels),
47 |                 nn.ReLU(inplace=True),
48 |                 nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49 |                 nn.BatchNorm1d(channels),
50 |             )
51 | 
52 |             # second local attention
53 |             self.local_att2 = nn.Sequential(
54 |                 nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55 |                 nn.BatchNorm1d(inter_channels),
56 |                 nn.ReLU(inplace=True),
57 |                 nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58 |                 nn.BatchNorm1d(channels),
59 |             )
60 |             # second global attention
61 |             self.global_att2 = nn.Sequential(
62 |                 nn.AdaptiveAvgPool1d(1),
63 |                 nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64 |                 nn.BatchNorm1d(inter_channels),
65 |                 nn.ReLU(inplace=True),
66 |                 nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67 |                 nn.BatchNorm1d(channels),
68 |             )
69 |         elif type == '2D':
70 |             # local attention
71 |             self.local_att = nn.Sequential(
72 |                 nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73 |                 nn.BatchNorm2d(inter_channels),
74 |                 nn.ReLU(inplace=True),
75 |                 nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76 |                 nn.BatchNorm2d(channels),
77 |             )
78 | 
79 |             # global attention
80 |             self.global_att = nn.Sequential(
81 |                 nn.AdaptiveAvgPool2d(1),
82 |                 nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83 |                 nn.BatchNorm2d(inter_channels),
84 |                 nn.ReLU(inplace=True),
85 |                 nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86 |                 nn.BatchNorm2d(channels),
87 |             )
88 | 
89 |             # second local attention
90 |             self.local_att2 = nn.Sequential(
91 |                 nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92 |                 nn.BatchNorm2d(inter_channels),
93 |                 nn.ReLU(inplace=True),
94 |                 nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95 |                 nn.BatchNorm2d(channels),
96 |             )
97 |             # second global attention
98 |             self.global_att2 = nn.Sequential(
99 |                 nn.AdaptiveAvgPool2d(1),
100 |                 nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101 |                 nn.BatchNorm2d(inter_channels),
102 |                 nn.ReLU(inplace=True),
103 |                 nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104 |                 nn.BatchNorm2d(channels),
105 |             )
106 |         else:
107 |             raise ValueError('the type is not supported')
108 | 
109 |         self.sigmoid = nn.Sigmoid()
110 | 
111 |     def forward(self, x, residual):
112 |         flag = False
113 |         xa = x + residual
114 |         if xa.size(0) == 1:
115 |             xa = torch.cat([xa, xa], dim=0)
116 |             flag = True
117 |         xl = self.local_att(xa)
118 |         xg = self.global_att(xa)
119 |         xlg = xl + xg
120 |         wei = self.sigmoid(xlg)
121 |         xi = x * wei + residual * (1 - wei)
122 | 
123 |         xl2 = self.local_att2(xi)
124 |         xg2 = self.global_att(xi)  # note: the second pass applies self.global_att; self.global_att2 is defined but unused
125 |         xlg2 = xl2 + xg2
126 |         wei2 = self.sigmoid(xlg2)
127 |         xo = x * wei2 + residual * (1 - wei2)
128 |         if flag:
129 |             xo = xo[0].unsqueeze(0)
130 |         return xo
131 | 
132 | 
133 | class AFF(nn.Module):
134 |     '''
135 |     Multi-feature fusion: AFF (attentional feature fusion)
136 |     '''
137 | 
138 |     def __init__(self, channels=64, r=4, type='2D'):
139 |         super(AFF, self).__init__()
140 |         inter_channels = int(channels // r)
141 | 
142 |         if type == '1D':
143 |             self.local_att = nn.Sequential(
144 |                 nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145 |                 nn.BatchNorm1d(inter_channels),
146 |                 nn.ReLU(inplace=True),
147 |                 nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148 |                 nn.BatchNorm1d(channels),
149 |             )
150 |             self.global_att = nn.Sequential(
151 |                 nn.AdaptiveAvgPool1d(1),
152 |                 nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153 |                 nn.BatchNorm1d(inter_channels),
154 |                 nn.ReLU(inplace=True),
155 |                 nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156 |                 nn.BatchNorm1d(channels),
157 |             )
158 |         elif type == '2D':
159 |             self.local_att = nn.Sequential(
160 |                 nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161 |                 nn.BatchNorm2d(inter_channels),
162 |                 nn.ReLU(inplace=True),
163 |                 nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164 |                 nn.BatchNorm2d(channels),
165 |             )
166 |             self.global_att = nn.Sequential(
167 |                 nn.AdaptiveAvgPool2d(1),
168 |                 nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169 |                 nn.BatchNorm2d(inter_channels),
170 |                 nn.ReLU(inplace=True),
171 |                 nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172 |                 nn.BatchNorm2d(channels),
173 |             )
174 |         else:
175 |             raise ValueError('the type is not supported.')
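# A minimal usage sketch for the fusion modules above (illustrative only; the
# batch size, channel count, and feature-map shape are assumptions, not values
# taken from this repository):
#
#     fusion = AFF(channels=64, r=4, type='2D')
#     x = torch.randn(8, 64, 32, 32)           # features from one branch (B, C, H, W)
#     residual = torch.randn(8, 64, 32, 32)    # features from the other branch, same shape
#     fused = fusion(x, residual)              # -> (8, 64, 32, 32), attention-weighted blend
#
# iAFF(channels=64, r=4, type='2D') is a drop-in alternative that applies the
# local/global attention weighting a second time; DAF simply adds the two inputs.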
176 | 177 | self.sigmoid = nn.Sigmoid() 178 | 179 | def forward(self, x, residual): 180 | flag = False 181 | xa = x + residual 182 | if xa.size(0) == 1: 183 | xa = torch.cat([xa,xa],dim=0) 184 | flag = True 185 | xl = self.local_att(xa) 186 | xg = self.global_att(xa) 187 | xlg = xl + xg 188 | wei = self.sigmoid(xlg) 189 | xo = 2 * x * wei + 2 * residual * (1 - wei) 190 | if flag: 191 | xo = xo[0].unsqueeze(0) 192 | return xo 193 | 194 | -------------------------------------------------------------------------------- /src/laion_clap/hook.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contrastive Language-Audio Pretraining Model from LAION 3 | -------------------------------------------------------- 4 | Paper: https://arxiv.org/abs/2211.06687 5 | Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui 6 | Support: LAION 7 | """ 8 | import os 9 | import torch 10 | import librosa 11 | from clap_module import create_model 12 | from training.data import get_audio_features 13 | from training.data import int16_to_float32, float32_to_int16 14 | 15 | from transformers import RobertaTokenizer 16 | import wget 17 | from clap_module.factory import load_state_dict 18 | 19 | 20 | class CLAP_Module(torch.nn.Module): 21 | def __init__(self, enable_fusion=False, device=None, amodel= 'HTSAT-tiny', tmodel='roberta') -> None: 22 | """Initialize CLAP Model 23 | 24 | Parameters 25 | ---------- 26 | enable_fusion: bool 27 | if true, it will create the fusion clap model, otherwise non-fusion clap model (default: false) 28 | device: str 29 | if None, it will automatically detect the device (gpu or cpu) 30 | amodel: str 31 | audio encoder architecture, default: HTSAT-tiny 32 | tmodel: str 33 | text encoder architecture, default: roberta 34 | """ 35 | super(CLAP_Module, self).__init__() 36 | if device is None: 37 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 38 | 39 | precision = 'fp32' 40 | 41 | if enable_fusion: 42 | fusion_type = 'aff_2d' 43 | model, model_cfg = create_model( 44 | amodel, 45 | tmodel, 46 | precision=precision, 47 | device=device, 48 | enable_fusion=enable_fusion, 49 | fusion_type=fusion_type 50 | ) 51 | else: 52 | model, model_cfg = create_model( 53 | amodel, 54 | tmodel, 55 | precision=precision, 56 | device=device, 57 | enable_fusion=enable_fusion 58 | ) 59 | self.enable_fusion = enable_fusion 60 | self.model = model 61 | self.model_cfg = model_cfg 62 | self.tokenize = RobertaTokenizer.from_pretrained('roberta-base') 63 | 64 | def tokenizer(self, text): 65 | result = self.tokenize( 66 | text, 67 | padding="max_length", 68 | truncation=True, 69 | max_length=77, 70 | return_tensors="pt", 71 | ) 72 | return {k: v.squeeze(0) for k, v in result.items()} 73 | 74 | def load_ckpt(self, ckpt = None, model_id = -1, verbose = True): 75 | """Load the pretrained checkpoint of CLAP model 76 | 77 | Parameters 78 | ---------- 79 | ckpt: str 80 | if ckpt is specified, the model will load this ckpt, otherwise the model will download the ckpt from zenodo. \n 81 | For fusion model, it will download the 630k+audioset fusion model (id=3). For non-fusion model, it will download the 630k+audioset model (id=1). 
82 |         model_id:
83 |             if model_id is specified, you can download the corresponding best ckpt:
84 |                 id = 0 --> 630k non-fusion ckpt \n
85 |                 id = 1 --> 630k+audioset non-fusion ckpt \n
86 |                 id = 2 --> 630k fusion ckpt \n
87 |                 id = 3 --> 630k+audioset fusion ckpt \n
88 |             Note that if your model is specified as a non-fusion model but you download a fusion model ckpt, you will face an error.
89 |         """
90 |         download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'
91 |         download_names = [
92 |             '630k-best.pt',
93 |             '630k-audioset-best.pt',
94 |             '630k-fusion-best.pt',
95 |             '630k-audioset-fusion-best.pt'
96 |         ]
97 |         if ckpt is not None:
98 |             print(f'Load the specified checkpoint {ckpt} from users.')
99 |         else:
100 |             print(f'Load our best checkpoint in the paper.')
101 |             if model_id == -1:
102 |                 model_id = 3 if self.enable_fusion else 1
103 |             package_dir = os.path.dirname(os.path.realpath(__file__))
104 |             weight_file_name = download_names[model_id]
105 |             ckpt = os.path.join(package_dir, weight_file_name)
106 |             if os.path.exists(ckpt):
107 |                 print(f'The checkpoint is already downloaded')
108 |             else:
109 |                 print('Downloading laion_clap weight files...')
110 |                 ckpt = wget.download(download_link + weight_file_name, os.path.dirname(ckpt))
111 |                 print('Download completed!')
112 |         print('Load Checkpoint...')
113 |         ckpt = load_state_dict(ckpt, skip_params=True)
114 |         self.model.load_state_dict(ckpt)
115 |         if verbose:
116 |             param_names = [n for n, p in self.model.named_parameters()]
117 |             for n in param_names:
118 |                 print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
119 | 
120 |     def get_audio_embedding_from_filelist(self, x, use_tensor=False):
121 |         """get audio embeddings from the audio file list
122 | 
123 |         Parameters
124 |         ----------
125 |         x: List[str] (N,):
126 |             an audio file list to extract features from; audio files can have different lengths (as we have the feature fusion mechanism)
127 |         use_tensor: boolean:
128 |             if True, it will return the torch tensor, preserving the gradient (default: False).
129 |         Returns
130 |         ----------
131 |         audio_embed : numpy.ndarray | torch.Tensor (N,D):
132 |             audio embeddings extracted from the audio files
133 |         """
134 |         self.model.eval()
135 |         audio_input = []
136 |         for f in x:
137 |             # load the waveform of the shape (T,), should resample to 48000
138 |             audio_waveform, _ = librosa.load(f, sr=48000)
139 |             # quantize
140 |             audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
141 |             audio_waveform = torch.from_numpy(audio_waveform).float()
142 |             temp_dict = {}
143 |             temp_dict = get_audio_features(
144 |                 temp_dict, audio_waveform, 480000,
145 |                 data_truncating='fusion' if self.enable_fusion else 'rand_trunc',
146 |                 data_filling='repeatpad',
147 |                 audio_cfg=self.model_cfg['audio_cfg'],
148 |                 require_grad=audio_waveform.requires_grad
149 |             )
150 |             audio_input.append(temp_dict)
151 |         audio_embed = self.model.get_audio_embedding(audio_input)
152 |         if not use_tensor:
153 |             audio_embed = audio_embed.detach().cpu().numpy()
154 |         return audio_embed
155 | 
156 | 
157 |     def get_audio_embedding_from_data(self, x, use_tensor=False):
158 |         """get audio embeddings from the audio data
159 | 
160 |         Parameters
161 |         ----------
162 |         x: np.ndarray | torch.Tensor (N,T):
163 |             audio data, must be mono audio tracks.
164 |         use_tensor: boolean:
165 |             if True, x should be the tensor input and the output will be the tensor, preserving the gradient (default: False).
166 | Note that if 'use tensor' is set to True, it will not do the quantize of the audio waveform (otherwise the gradient will not be preserved). 167 | Returns 168 | ---------- 169 | audio embed: numpy.darray | torch.Tensor (N,D): 170 | audio embeddings that extracted from audio files 171 | """ 172 | self.model.eval() 173 | audio_input = [] 174 | for audio_waveform in x: 175 | # quantize 176 | if not use_tensor: 177 | audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) 178 | audio_waveform = torch.from_numpy(audio_waveform).float() 179 | temp_dict = {} 180 | temp_dict = get_audio_features( 181 | temp_dict, audio_waveform, 480000, 182 | data_truncating='fusion' if self.enable_fusion else 'rand_trunc', 183 | data_filling='repeatpad', 184 | audio_cfg=self.model_cfg['audio_cfg'], 185 | require_grad=audio_waveform.requires_grad 186 | ) 187 | audio_input.append(temp_dict) 188 | audio_embed = self.model.get_audio_embedding(audio_input) 189 | if not use_tensor: 190 | audio_embed = audio_embed.detach().cpu().numpy() 191 | return audio_embed 192 | 193 | def get_text_embedding(self, x, tokenizer = None, use_tensor = False): 194 | """get text embeddings from texts 195 | 196 | Parameters 197 | ---------- 198 | x: List[str] (N,): 199 | text list 200 | tokenizer: func: 201 | the tokenizer function, if not provided (None), will use the default Roberta tokenizer. 202 | use_tensor: boolean: 203 | if True, the output will be the tesnor, preserving the gradient (default: False). 204 | Returns 205 | ---------- 206 | text_embed : numpy.darray | torch.Tensor (N,D): 207 | text embeddings that extracted from texts 208 | """ 209 | self.model.eval() 210 | if tokenizer is not None: 211 | text_input = tokenizer(x) 212 | else: 213 | text_input = self.tokenizer(x) 214 | text_embed = self.model.get_text_embedding(text_input) 215 | if not use_tensor: 216 | text_embed = text_embed.detach().cpu().numpy() 217 | return text_embed 218 | 219 | 220 | -------------------------------------------------------------------------------- /class_labels/audioset_class_labels_indices.json: -------------------------------------------------------------------------------- 1 | {"Speech": 0, "Male speech, man speaking": 1, "Female speech, woman speaking": 2, "Child speech, kid speaking": 3, "Conversation": 4, "Narration, monologue": 5, "Babbling": 6, "Speech synthesizer": 7, "Shout": 8, "Bellow": 9, "Whoop": 10, "Yell": 11, "Battle cry": 12, "Children shouting": 13, "Screaming": 14, "Whispering": 15, "Laughter": 16, "Baby laughter": 17, "Giggle": 18, "Snicker": 19, "Belly laugh": 20, "Chuckle, chortle": 21, "Crying, sobbing": 22, "Baby cry, infant cry": 23, "Whimper": 24, "Wail, moan": 25, "Sigh": 26, "Singing": 27, "Choir": 28, "Yodeling": 29, "Chant": 30, "Mantra": 31, "Male singing": 32, "Female singing": 33, "Child singing": 34, "Synthetic singing": 35, "Rapping": 36, "Humming": 37, "Groan": 38, "Grunt": 39, "Whistling": 40, "Breathing": 41, "Wheeze": 42, "Snoring": 43, "Gasp": 44, "Pant": 45, "Snort": 46, "Cough": 47, "Throat clearing": 48, "Sneeze": 49, "Sniff": 50, "Run": 51, "Shuffle": 52, "Walk, footsteps": 53, "Chewing, mastication": 54, "Biting": 55, "Gargling": 56, "Stomach rumble": 57, "Burping, eructation": 58, "Hiccup": 59, "Fart": 60, "Hands": 61, "Finger snapping": 62, "Clapping": 63, "Heart sounds, heartbeat": 64, "Heart murmur": 65, "Cheering": 66, "Applause": 67, "Chatter": 68, "Crowd": 69, "Hubbub, speech noise, speech babble": 70, "Children playing": 71, "Animal": 72, "Domestic animals, pets": 73, "Dog": 
74, "Bark": 75, "Yip": 76, "Howl": 77, "Bow-wow": 78, "Growling": 79, "Whimper (dog)": 80, "Cat": 81, "Purr": 82, "Meow": 83, "Hiss": 84, "Caterwaul": 85, "Livestock, farm animals, working animals": 86, "Horse": 87, "Clip-clop": 88, "Neigh, whinny": 89, "Cattle, bovinae": 90, "Moo": 91, "Cowbell": 92, "Pig": 93, "Oink": 94, "Goat": 95, "Bleat": 96, "Sheep": 97, "Fowl": 98, "Chicken, rooster": 99, "Cluck": 100, "Crowing, cock-a-doodle-doo": 101, "Turkey": 102, "Gobble": 103, "Duck": 104, "Quack": 105, "Goose": 106, "Honk": 107, "Wild animals": 108, "Roaring cats (lions, tigers)": 109, "Roar": 110, "Bird": 111, "Bird vocalization, bird call, bird song": 112, "Chirp, tweet": 113, "Squawk": 114, "Pigeon, dove": 115, "Coo": 116, "Crow": 117, "Caw": 118, "Owl": 119, "Hoot": 120, "Bird flight, flapping wings": 121, "Canidae, dogs, wolves": 122, "Rodents, rats, mice": 123, "Mouse": 124, "Patter": 125, "Insect": 126, "Cricket": 127, "Mosquito": 128, "Fly, housefly": 129, "Buzz": 130, "Bee, wasp, etc.": 131, "Frog": 132, "Croak": 133, "Snake": 134, "Rattle": 135, "Whale vocalization": 136, "Music": 137, "Musical instrument": 138, "Plucked string instrument": 139, "Guitar": 140, "Electric guitar": 141, "Bass guitar": 142, "Acoustic guitar": 143, "Steel guitar, slide guitar": 144, "Tapping (guitar technique)": 145, "Strum": 146, "Banjo": 147, "Sitar": 148, "Mandolin": 149, "Zither": 150, "Ukulele": 151, "Keyboard (musical)": 152, "Piano": 153, "Electric piano": 154, "Organ": 155, "Electronic organ": 156, "Hammond organ": 157, "Synthesizer": 158, "Sampler": 159, "Harpsichord": 160, "Percussion": 161, "Drum kit": 162, "Drum machine": 163, "Drum": 164, "Snare drum": 165, "Rimshot": 166, "Drum roll": 167, "Bass drum": 168, "Timpani": 169, "Tabla": 170, "Cymbal": 171, "Hi-hat": 172, "Wood block": 173, "Tambourine": 174, "Rattle (instrument)": 175, "Maraca": 176, "Gong": 177, "Tubular bells": 178, "Mallet percussion": 179, "Marimba, xylophone": 180, "Glockenspiel": 181, "Vibraphone": 182, "Steelpan": 183, "Orchestra": 184, "Brass instrument": 185, "French horn": 186, "Trumpet": 187, "Trombone": 188, "Bowed string instrument": 189, "String section": 190, "Violin, fiddle": 191, "Pizzicato": 192, "Cello": 193, "Double bass": 194, "Wind instrument, woodwind instrument": 195, "Flute": 196, "Saxophone": 197, "Clarinet": 198, "Harp": 199, "Bell": 200, "Church bell": 201, "Jingle bell": 202, "Bicycle bell": 203, "Tuning fork": 204, "Chime": 205, "Wind chime": 206, "Change ringing (campanology)": 207, "Harmonica": 208, "Accordion": 209, "Bagpipes": 210, "Didgeridoo": 211, "Shofar": 212, "Theremin": 213, "Singing bowl": 214, "Scratching (performance technique)": 215, "Pop music": 216, "Hip hop music": 217, "Beatboxing": 218, "Rock music": 219, "Heavy metal": 220, "Punk rock": 221, "Grunge": 222, "Progressive rock": 223, "Rock and roll": 224, "Psychedelic rock": 225, "Rhythm and blues": 226, "Soul music": 227, "Reggae": 228, "Country": 229, "Swing music": 230, "Bluegrass": 231, "Funk": 232, "Folk music": 233, "Middle Eastern music": 234, "Jazz": 235, "Disco": 236, "Classical music": 237, "Opera": 238, "Electronic music": 239, "House music": 240, "Techno": 241, "Dubstep": 242, "Drum and bass": 243, "Electronica": 244, "Electronic dance music": 245, "Ambient music": 246, "Trance music": 247, "Music of Latin America": 248, "Salsa music": 249, "Flamenco": 250, "Blues": 251, "Music for children": 252, "New-age music": 253, "Vocal music": 254, "A capella": 255, "Music of Africa": 256, "Afrobeat": 257, "Christian music": 
258, "Gospel music": 259, "Music of Asia": 260, "Carnatic music": 261, "Music of Bollywood": 262, "Ska": 263, "Traditional music": 264, "Independent music": 265, "Song": 266, "Background music": 267, "Theme music": 268, "Jingle (music)": 269, "Soundtrack music": 270, "Lullaby": 271, "Video game music": 272, "Christmas music": 273, "Dance music": 274, "Wedding music": 275, "Happy music": 276, "Funny music": 277, "Sad music": 278, "Tender music": 279, "Exciting music": 280, "Angry music": 281, "Scary music": 282, "Wind": 283, "Rustling leaves": 284, "Wind noise (microphone)": 285, "Thunderstorm": 286, "Thunder": 287, "Water": 288, "Rain": 289, "Raindrop": 290, "Rain on surface": 291, "Stream": 292, "Waterfall": 293, "Ocean": 294, "Waves, surf": 295, "Steam": 296, "Gurgling": 297, "Fire": 298, "Crackle": 299, "Vehicle": 300, "Boat, Water vehicle": 301, "Sailboat, sailing ship": 302, "Rowboat, canoe, kayak": 303, "Motorboat, speedboat": 304, "Ship": 305, "Motor vehicle (road)": 306, "Car": 307, "Vehicle horn, car horn, honking": 308, "Toot": 309, "Car alarm": 310, "Power windows, electric windows": 311, "Skidding": 312, "Tire squeal": 313, "Car passing by": 314, "Race car, auto racing": 315, "Truck": 316, "Air brake": 317, "Air horn, truck horn": 318, "Reversing beeps": 319, "Ice cream truck, ice cream van": 320, "Bus": 321, "Emergency vehicle": 322, "Police car (siren)": 323, "Ambulance (siren)": 324, "Fire engine, fire truck (siren)": 325, "Motorcycle": 326, "Traffic noise, roadway noise": 327, "Rail transport": 328, "Train": 329, "Train whistle": 330, "Train horn": 331, "Railroad car, train wagon": 332, "Train wheels squealing": 333, "Subway, metro, underground": 334, "Aircraft": 335, "Aircraft engine": 336, "Jet engine": 337, "Propeller, airscrew": 338, "Helicopter": 339, "Fixed-wing aircraft, airplane": 340, "Bicycle": 341, "Skateboard": 342, "Engine": 343, "Light engine (high frequency)": 344, "Dental drill, dentist's drill": 345, "Lawn mower": 346, "Chainsaw": 347, "Medium engine (mid frequency)": 348, "Heavy engine (low frequency)": 349, "Engine knocking": 350, "Engine starting": 351, "Idling": 352, "Accelerating, revving, vroom": 353, "Door": 354, "Doorbell": 355, "Ding-dong": 356, "Sliding door": 357, "Slam": 358, "Knock": 359, "Tap": 360, "Squeak": 361, "Cupboard open or close": 362, "Drawer open or close": 363, "Dishes, pots, and pans": 364, "Cutlery, silverware": 365, "Chopping (food)": 366, "Frying (food)": 367, "Microwave oven": 368, "Blender": 369, "Water tap, faucet": 370, "Sink (filling or washing)": 371, "Bathtub (filling or washing)": 372, "Hair dryer": 373, "Toilet flush": 374, "Toothbrush": 375, "Electric toothbrush": 376, "Vacuum cleaner": 377, "Zipper (clothing)": 378, "Keys jangling": 379, "Coin (dropping)": 380, "Scissors": 381, "Electric shaver, electric razor": 382, "Shuffling cards": 383, "Typing": 384, "Typewriter": 385, "Computer keyboard": 386, "Writing": 387, "Alarm": 388, "Telephone": 389, "Telephone bell ringing": 390, "Ringtone": 391, "Telephone dialing, DTMF": 392, "Dial tone": 393, "Busy signal": 394, "Alarm clock": 395, "Siren": 396, "Civil defense siren": 397, "Buzzer": 398, "Smoke detector, smoke alarm": 399, "Fire alarm": 400, "Foghorn": 401, "Whistle": 402, "Steam whistle": 403, "Mechanisms": 404, "Ratchet, pawl": 405, "Clock": 406, "Tick": 407, "Tick-tock": 408, "Gears": 409, "Pulleys": 410, "Sewing machine": 411, "Mechanical fan": 412, "Air conditioning": 413, "Cash register": 414, "Printer": 415, "Camera": 416, "Single-lens reflex camera": 417, 
"Tools": 418, "Hammer": 419, "Jackhammer": 420, "Sawing": 421, "Filing (rasp)": 422, "Sanding": 423, "Power tool": 424, "Drill": 425, "Explosion": 426, "Gunshot, gunfire": 427, "Machine gun": 428, "Fusillade": 429, "Artillery fire": 430, "Cap gun": 431, "Fireworks": 432, "Firecracker": 433, "Burst, pop": 434, "Eruption": 435, "Boom": 436, "Wood": 437, "Chop": 438, "Splinter": 439, "Crack": 440, "Glass": 441, "Chink, clink": 442, "Shatter": 443, "Liquid": 444, "Splash, splatter": 445, "Slosh": 446, "Squish": 447, "Drip": 448, "Pour": 449, "Trickle, dribble": 450, "Gush": 451, "Fill (with liquid)": 452, "Spray": 453, "Pump (liquid)": 454, "Stir": 455, "Boiling": 456, "Sonar": 457, "Arrow": 458, "Whoosh, swoosh, swish": 459, "Thump, thud": 460, "Thunk": 461, "Electronic tuner": 462, "Effects unit": 463, "Chorus effect": 464, "Basketball bounce": 465, "Bang": 466, "Slap, smack": 467, "Whack, thwack": 468, "Smash, crash": 469, "Breaking": 470, "Bouncing": 471, "Whip": 472, "Flap": 473, "Scratch": 474, "Scrape": 475, "Rub": 476, "Roll": 477, "Crushing": 478, "Crumpling, crinkling": 479, "Tearing": 480, "Beep, bleep": 481, "Ping": 482, "Ding": 483, "Clang": 484, "Squeal": 485, "Creak": 486, "Rustle": 487, "Whir": 488, "Clatter": 489, "Sizzle": 490, "Clicking": 491, "Clickety-clack": 492, "Rumble": 493, "Plop": 494, "Jingle, tinkle": 495, "Hum": 496, "Zing": 497, "Boing": 498, "Crunch": 499, "Silence": 500, "Sine wave": 501, "Harmonic": 502, "Chirp tone": 503, "Sound effect": 504, "Pulse": 505, "Inside, small room": 506, "Inside, large room or hall": 507, "Inside, public space": 508, "Outside, urban or manmade": 509, "Outside, rural or natural": 510, "Reverberation": 511, "Echo": 512, "Noise": 513, "Environmental noise": 514, "Static": 515, "Mains hum": 516, "Distortion": 517, "Sidetone": 518, "Cacophony": 519, "White noise": 520, "Pink noise": 521, "Throbbing": 522, "Vibration": 523, "Television": 524, "Radio": 525, "Field recording": 526} -------------------------------------------------------------------------------- /class_labels/audioset_fsd50k_class_labels_indices.json: -------------------------------------------------------------------------------- 1 | {"Speech": 0, "Male speech, man speaking": 1, "Female speech, woman speaking": 2, "Child speech, kid speaking": 3, "Conversation": 4, "Narration, monologue": 5, "Babbling": 6, "Speech synthesizer": 7, "Shout": 8, "Bellow": 9, "Whoop": 10, "Yell": 11, "Battle cry": 12, "Children shouting": 13, "Screaming": 14, "Whispering": 15, "Laughter": 16, "Baby laughter": 17, "Giggle": 18, "Snicker": 19, "Belly laugh": 20, "Chuckle, chortle": 21, "Crying, sobbing": 22, "Baby cry, infant cry": 23, "Whimper": 24, "Wail, moan": 25, "Sigh": 26, "Singing": 27, "Choir": 28, "Yodeling": 29, "Chant": 30, "Mantra": 31, "Male singing": 32, "Female singing": 33, "Child singing": 34, "Synthetic singing": 35, "Rapping": 36, "Humming": 37, "Groan": 38, "Grunt": 39, "Whistling": 40, "Breathing": 41, "Wheeze": 42, "Snoring": 43, "Gasp": 44, "Pant": 45, "Snort": 46, "Cough": 47, "Throat clearing": 48, "Sneeze": 49, "Sniff": 50, "Run": 51, "Shuffle": 52, "Walk, footsteps": 53, "Chewing, mastication": 54, "Biting": 55, "Gargling": 56, "Stomach rumble": 57, "Burping, eructation": 58, "Hiccup": 59, "Fart": 60, "Hands": 61, "Finger snapping": 62, "Clapping": 63, "Heart sounds, heartbeat": 64, "Heart murmur": 65, "Cheering": 66, "Applause": 67, "Chatter": 68, "Crowd": 69, "Hubbub, speech noise, speech babble": 70, "Children playing": 71, "Animal": 72, "Domestic animals, pets": 
73, "Dog": 74, "Bark": 75, "Yip": 76, "Howl": 77, "Bow-wow": 78, "Growling": 79, "Whimper (dog)": 80, "Cat": 81, "Purr": 82, "Meow": 83, "Hiss": 84, "Caterwaul": 85, "Livestock, farm animals, working animals": 86, "Horse": 87, "Clip-clop": 88, "Neigh, whinny": 89, "Cattle, bovinae": 90, "Moo": 91, "Cowbell": 92, "Pig": 93, "Oink": 94, "Goat": 95, "Bleat": 96, "Sheep": 97, "Fowl": 98, "Chicken, rooster": 99, "Cluck": 100, "Crowing, cock-a-doodle-doo": 101, "Turkey": 102, "Gobble": 103, "Duck": 104, "Quack": 105, "Goose": 106, "Honk": 107, "Wild animals": 108, "Roaring cats (lions, tigers)": 109, "Roar": 110, "Bird": 111, "Bird vocalization, bird call, bird song": 112, "Chirp, tweet": 113, "Squawk": 114, "Pigeon, dove": 115, "Coo": 116, "Crow": 117, "Caw": 118, "Owl": 119, "Hoot": 120, "Bird flight, flapping wings": 121, "Canidae, dogs, wolves": 122, "Rodents, rats, mice": 123, "Mouse": 124, "Patter": 125, "Insect": 126, "Cricket": 127, "Mosquito": 128, "Fly, housefly": 129, "Buzz": 130, "Bee, wasp, etc.": 131, "Frog": 132, "Croak": 133, "Snake": 134, "Rattle": 135, "Whale vocalization": 136, "Music": 137, "Musical instrument": 138, "Plucked string instrument": 139, "Guitar": 140, "Electric guitar": 141, "Bass guitar": 142, "Acoustic guitar": 143, "Steel guitar, slide guitar": 144, "Tapping (guitar technique)": 145, "Strum": 146, "Banjo": 147, "Sitar": 148, "Mandolin": 149, "Zither": 150, "Ukulele": 151, "Keyboard (musical)": 152, "Piano": 153, "Electric piano": 154, "Organ": 155, "Electronic organ": 156, "Hammond organ": 157, "Synthesizer": 158, "Sampler": 159, "Harpsichord": 160, "Percussion": 161, "Drum kit": 162, "Drum machine": 163, "Drum": 164, "Snare drum": 165, "Rimshot": 166, "Drum roll": 167, "Bass drum": 168, "Timpani": 169, "Tabla": 170, "Cymbal": 171, "Hi-hat": 172, "Wood block": 173, "Tambourine": 174, "Rattle (instrument)": 175, "Maraca": 176, "Gong": 177, "Tubular bells": 178, "Mallet percussion": 179, "Marimba, xylophone": 180, "Glockenspiel": 181, "Vibraphone": 182, "Steelpan": 183, "Orchestra": 184, "Brass instrument": 185, "French horn": 186, "Trumpet": 187, "Trombone": 188, "Bowed string instrument": 189, "String section": 190, "Violin, fiddle": 191, "Pizzicato": 192, "Cello": 193, "Double bass": 194, "Wind instrument, woodwind instrument": 195, "Flute": 196, "Saxophone": 197, "Clarinet": 198, "Harp": 199, "Bell": 200, "Church bell": 201, "Jingle bell": 202, "Bicycle bell": 203, "Tuning fork": 204, "Chime": 205, "Wind chime": 206, "Change ringing (campanology)": 207, "Harmonica": 208, "Accordion": 209, "Bagpipes": 210, "Didgeridoo": 211, "Shofar": 212, "Theremin": 213, "Singing bowl": 214, "Scratching (performance technique)": 215, "Pop music": 216, "Hip hop music": 217, "Beatboxing": 218, "Rock music": 219, "Heavy metal": 220, "Punk rock": 221, "Grunge": 222, "Progressive rock": 223, "Rock and roll": 224, "Psychedelic rock": 225, "Rhythm and blues": 226, "Soul music": 227, "Reggae": 228, "Country": 229, "Swing music": 230, "Bluegrass": 231, "Funk": 232, "Folk music": 233, "Middle Eastern music": 234, "Jazz": 235, "Disco": 236, "Classical music": 237, "Opera": 238, "Electronic music": 239, "House music": 240, "Techno": 241, "Dubstep": 242, "Drum and bass": 243, "Electronica": 244, "Electronic dance music": 245, "Ambient music": 246, "Trance music": 247, "Music of Latin America": 248, "Salsa music": 249, "Flamenco": 250, "Blues": 251, "Music for children": 252, "New-age music": 253, "Vocal music": 254, "A capella": 255, "Music of Africa": 256, "Afrobeat": 257, "Christian 
music": 258, "Gospel music": 259, "Music of Asia": 260, "Carnatic music": 261, "Music of Bollywood": 262, "Ska": 263, "Traditional music": 264, "Independent music": 265, "Song": 266, "Background music": 267, "Theme music": 268, "Jingle (music)": 269, "Soundtrack music": 270, "Lullaby": 271, "Video game music": 272, "Christmas music": 273, "Dance music": 274, "Wedding music": 275, "Happy music": 276, "Funny music": 277, "Sad music": 278, "Tender music": 279, "Exciting music": 280, "Angry music": 281, "Scary music": 282, "Wind": 283, "Rustling leaves": 284, "Wind noise (microphone)": 285, "Thunderstorm": 286, "Thunder": 287, "Water": 288, "Rain": 289, "Raindrop": 290, "Rain on surface": 291, "Stream": 292, "Waterfall": 293, "Ocean": 294, "Waves, surf": 295, "Steam": 296, "Gurgling": 297, "Fire": 298, "Crackle": 299, "Vehicle": 300, "Boat, Water vehicle": 301, "Sailboat, sailing ship": 302, "Rowboat, canoe, kayak": 303, "Motorboat, speedboat": 304, "Ship": 305, "Motor vehicle (road)": 306, "Car": 307, "Vehicle horn, car horn, honking": 308, "Toot": 309, "Car alarm": 310, "Power windows, electric windows": 311, "Skidding": 312, "Tire squeal": 313, "Car passing by": 314, "Race car, auto racing": 315, "Truck": 316, "Air brake": 317, "Air horn, truck horn": 318, "Reversing beeps": 319, "Ice cream truck, ice cream van": 320, "Bus": 321, "Emergency vehicle": 322, "Police car (siren)": 323, "Ambulance (siren)": 324, "Fire engine, fire truck (siren)": 325, "Motorcycle": 326, "Traffic noise, roadway noise": 327, "Rail transport": 328, "Train": 329, "Train whistle": 330, "Train horn": 331, "Railroad car, train wagon": 332, "Train wheels squealing": 333, "Subway, metro, underground": 334, "Aircraft": 335, "Aircraft engine": 336, "Jet engine": 337, "Propeller, airscrew": 338, "Helicopter": 339, "Fixed-wing aircraft, airplane": 340, "Bicycle": 341, "Skateboard": 342, "Engine": 343, "Light engine (high frequency)": 344, "Dental drill, dentist's drill": 345, "Lawn mower": 346, "Chainsaw": 347, "Medium engine (mid frequency)": 348, "Heavy engine (low frequency)": 349, "Engine knocking": 350, "Engine starting": 351, "Idling": 352, "Accelerating, revving, vroom": 353, "Door": 354, "Doorbell": 355, "Ding-dong": 356, "Sliding door": 357, "Slam": 358, "Knock": 359, "Tap": 360, "Squeak": 361, "Cupboard open or close": 362, "Drawer open or close": 363, "Dishes, pots, and pans": 364, "Cutlery, silverware": 365, "Chopping (food)": 366, "Frying (food)": 367, "Microwave oven": 368, "Blender": 369, "Water tap, faucet": 370, "Sink (filling or washing)": 371, "Bathtub (filling or washing)": 372, "Hair dryer": 373, "Toilet flush": 374, "Toothbrush": 375, "Electric toothbrush": 376, "Vacuum cleaner": 377, "Zipper (clothing)": 378, "Keys jangling": 379, "Coin (dropping)": 380, "Scissors": 381, "Electric shaver, electric razor": 382, "Shuffling cards": 383, "Typing": 384, "Typewriter": 385, "Computer keyboard": 386, "Writing": 387, "Alarm": 388, "Telephone": 389, "Telephone bell ringing": 390, "Ringtone": 391, "Telephone dialing, DTMF": 392, "Dial tone": 393, "Busy signal": 394, "Alarm clock": 395, "Siren": 396, "Civil defense siren": 397, "Buzzer": 398, "Smoke detector, smoke alarm": 399, "Fire alarm": 400, "Foghorn": 401, "Whistle": 402, "Steam whistle": 403, "Mechanisms": 404, "Ratchet, pawl": 405, "Clock": 406, "Tick": 407, "Tick-tock": 408, "Gears": 409, "Pulleys": 410, "Sewing machine": 411, "Mechanical fan": 412, "Air conditioning": 413, "Cash register": 414, "Printer": 415, "Camera": 416, "Single-lens reflex camera": 
417, "Tools": 418, "Hammer": 419, "Jackhammer": 420, "Sawing": 421, "Filing (rasp)": 422, "Sanding": 423, "Power tool": 424, "Drill": 425, "Explosion": 426, "Gunshot, gunfire": 427, "Machine gun": 428, "Fusillade": 429, "Artillery fire": 430, "Cap gun": 431, "Fireworks": 432, "Firecracker": 433, "Burst, pop": 434, "Eruption": 435, "Boom": 436, "Wood": 437, "Chop": 438, "Splinter": 439, "Crack": 440, "Glass": 441, "Chink, clink": 442, "Shatter": 443, "Liquid": 444, "Splash, splatter": 445, "Slosh": 446, "Squish": 447, "Drip": 448, "Pour": 449, "Trickle, dribble": 450, "Gush": 451, "Fill (with liquid)": 452, "Spray": 453, "Pump (liquid)": 454, "Stir": 455, "Boiling": 456, "Sonar": 457, "Arrow": 458, "Whoosh, swoosh, swish": 459, "Thump, thud": 460, "Thunk": 461, "Electronic tuner": 462, "Effects unit": 463, "Chorus effect": 464, "Basketball bounce": 465, "Bang": 466, "Slap, smack": 467, "Whack, thwack": 468, "Smash, crash": 469, "Breaking": 470, "Bouncing": 471, "Whip": 472, "Flap": 473, "Scratch": 474, "Scrape": 475, "Rub": 476, "Roll": 477, "Crushing": 478, "Crumpling, crinkling": 479, "Tearing": 480, "Beep, bleep": 481, "Ping": 482, "Ding": 483, "Clang": 484, "Squeal": 485, "Creak": 486, "Rustle": 487, "Whir": 488, "Clatter": 489, "Sizzle": 490, "Clicking": 491, "Clickety-clack": 492, "Rumble": 493, "Plop": 494, "Jingle, tinkle": 495, "Hum": 496, "Zing": 497, "Boing": 498, "Crunch": 499, "Silence": 500, "Sine wave": 501, "Harmonic": 502, "Chirp tone": 503, "Sound effect": 504, "Pulse": 505, "Inside, small room": 506, "Inside, large room or hall": 507, "Inside, public space": 508, "Outside, urban or manmade": 509, "Outside, rural or natural": 510, "Reverberation": 511, "Echo": 512, "Noise": 513, "Environmental noise": 514, "Static": 515, "Mains hum": 516, "Distortion": 517, "Sidetone": 518, "Cacophony": 519, "White noise": 520, "Pink noise": 521, "Throbbing": 522, "Vibration": 523, "Television": 524, "Radio": 525, "Field recording": 526, "Respiratory sounds": 527, "Human voice": 528, "Packing tape, duct tape": 529, "Domestic sounds, home sounds": 530, "Screech": 531, "Human group actions": 532, "Crash cymbal": 533, "Gull, seagull": 534} -------------------------------------------------------------------------------- /src/laion_clap/evaluate/eval_retrieval_main.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import glob 3 | import random 4 | import numpy as np 5 | import logging 6 | import wandb 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.backends.cudnn as cudnn 10 | from clap_module import create_model 11 | from clap_module import tokenize 12 | from training.logger import setup_logging 13 | from training.data import get_data 14 | from training.train import evaluate 15 | from clap_module.utils import get_tar_path_from_dataset_name, dataset_split 16 | from training.params import parse_args 17 | 18 | 19 | def find_params_value(file, key): 20 | # find value of params in params_file 21 | with open(file, 'r') as f: 22 | for line in f: 23 | if key + ': ' in line: 24 | return line.split(': ')[1].strip() 25 | return None 26 | 27 | 28 | def evaluate_zeroshot(model, data, start_epoch, args, writer): 29 | dataloader = data["val"].dataloader 30 | metrics = {} 31 | device = torch.device(args.device) 32 | model.eval() 33 | metrics.update({"epoch": start_epoch}) 34 | 35 | all_audio_features = [] 36 | all_class_labels = [] 37 | with torch.no_grad(): 38 | for i, batch in enumerate(dataloader): 39 | audios = batch # 
contains mel_spec, wavform, and longer list 40 | audio_features = model(audios, None, device) 41 | audio_features = F.normalize(audio_features, dim=-1) 42 | all_audio_features.append(audio_features.detach().cpu()) 43 | all_class_labels.append(torch.argmax(batch["class_label"], 1).long()) 44 | all_audio_features = torch.cat(all_audio_features, dim=0) 45 | all_class_labels = torch.cat(all_class_labels, dim=0) 46 | metrics["num_samples"] = all_audio_features.shape[0] 47 | 48 | # get text features 49 | all_texts = ["This is a sound of " + t for t in args.class_index_dict.keys()] 50 | # (yusong): a hack, can make it better 51 | if args.tmodel == "transformer": 52 | from clap_module.tokenizer import tokenize 53 | all_texts = tokenize(all_texts) 54 | else: 55 | from training.data import tokenizer 56 | all_texts = tokenizer(all_texts) 57 | all_text_features = model(None, all_texts, device) 58 | all_text_features = F.normalize(all_text_features, dim=-1).detach().cpu() 59 | 60 | # compute similarity 61 | logit_scale_a, logit_scale_t = model(None, None, device) 62 | logit_scale_a = logit_scale_a.cpu() 63 | 64 | logits_per_audio = (logit_scale_a * all_audio_features @ all_text_features.t()).detach().cpu() 65 | logits_per_text = logits_per_audio.t().detach().cpu() 66 | 67 | ground_truth = all_class_labels.view(-1, 1) 68 | logit = logits_per_audio 69 | 70 | ranking = torch.argsort(logit, descending=True) 71 | preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread 72 | preds = preds.detach().cpu().numpy() 73 | metrics[f"{args.datasetnames[0]}_mean_rank"] = preds.mean() + 1 74 | metrics[f"{args.datasetnames[0]}_median_rank"] = np.floor(np.median(preds)) + 1 75 | for k in [1, 5, 10]: 76 | metrics[f"{args.datasetnames[0]}_R@{k}"] = np.mean(preds < k) 77 | # map@10 78 | metrics[f"{args.datasetnames[0]}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) 79 | 80 | logging.info( 81 | f"Eval Epoch: {start_epoch} " 82 | + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 83 | ) 84 | 85 | if args.wandb: 86 | assert wandb is not None, "Please install wandb." 87 | for name, val in metrics.items(): 88 | wandb.log({f"val/{name}": val, "epoch": start_epoch}) 89 | 90 | 91 | if __name__ == '__main__': 92 | # (yusong) repeated run might have different metric results. 93 | # This is because we randomly select crop 10s for each audio. 
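# A worked illustration of the ranking metrics computed in evaluate_zeroshot()
# above (the numbers are invented for the example): if preds, the 0-indexed rank
# of the correct class for three audio clips, is [0, 3, 12], then
#     R@1  = mean(preds < 1)  = 1/3
#     R@5  = mean(preds < 5)  = 2/3
#     R@10 = mean(preds < 10) = 2/3
#     mAP@10 = mean([1/1, 1/4, 0]) ~= 0.417   (a hit ranked beyond 10 contributes 0)
#     mean_rank = mean(preds) + 1 = 6,  median_rank = floor(median(preds)) + 1 = 4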
94 | args = parse_args() 95 | 96 | if os.path.isdir(args.pretrained): 97 | log_dir = os.path.dirname(args.pretrained) 98 | else: 99 | log_dir = os.path.dirname(os.path.dirname(args.pretrained)) 100 | 101 | args.log_level = logging.DEBUG if args.debug else logging.INFO 102 | log_path = os.path.join(log_dir, 'out.log') 103 | setup_logging(log_path, args.log_level) 104 | params_file = os.path.join(log_dir, 'params.txt') 105 | 106 | seed = 3407 107 | random.seed(seed) 108 | torch.manual_seed(seed) 109 | torch.cuda.manual_seed(seed) 110 | torch.cuda.manual_seed_all(seed) 111 | np.random.seed(seed) 112 | 113 | cudnn.benchmark = True 114 | cudnn.deterministic = False 115 | pretrained = 'openai' 116 | amodel = find_params_value(params_file, 'amodel') 117 | tmodel = find_params_value(params_file, 'tmodel') 118 | 119 | if amodel is None or tmodel is None: 120 | raise ValueError('model type not found in params file') 121 | 122 | # set up dummy values for args 123 | args.parallel_eval = False 124 | args.rank = 0 125 | args.local_rank = 0 126 | args.world_size = 1 127 | args.val_frequency = 1 128 | args.epochs = 1 129 | args.precision = 'fp32' 130 | args.save_logs = True 131 | args.wandb = args.report_to == 'wandb' 132 | args.class_index_dict = None 133 | 134 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 135 | args.device = device 136 | 137 | if args.remotedata: 138 | for dataset_name in args.datasetnames: 139 | for split in dataset_split[dataset_name]: 140 | if not os.path.exists(f"./json_files/{dataset_name}/{split}"): 141 | os.makedirs(f"./json_files/{dataset_name}/{split}") 142 | os.system( 143 | f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" 144 | ) 145 | 146 | if args.datasetinfos is None: 147 | args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] 148 | if args.dataset_type == "webdataset": 149 | args.train_data = get_tar_path_from_dataset_name( 150 | args.datasetnames, 151 | args.datasetinfos, 152 | islocal=not args.remotedata, 153 | proportion=args.dataset_proportion, 154 | dataset_path=args.datasetpath, 155 | ) 156 | args.val_data = get_tar_path_from_dataset_name( 157 | args.datasetnames, 158 | ["valid", "test", "eval"], 159 | islocal=not args.remotedata, 160 | proportion=1, 161 | dataset_path=args.datasetpath, 162 | ) 163 | model, model_cfg = create_model( 164 | amodel, 165 | tmodel, 166 | pretrained, 167 | precision='fp32', 168 | device=device, 169 | jit=False, 170 | force_quick_gelu=False, 171 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 172 | skip_params=False, 173 | enable_fusion=args.enable_fusion, 174 | fusion_type=args.fusion_type 175 | ) # a hack to get model_cfg 176 | 177 | data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data 178 | 179 | writer = None # if use tensorboard, initalize writer here 180 | 181 | if args.wandb: 182 | assert wandb is not None, "Please install wandb." 
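# find_params_value(params_file, ...) above expects the training run's params.txt
# to hold one "key: value" pair per line, e.g. (hypothetical excerpt):
#     amodel: HTSAT-tiny
#     tmodel: roberta
# The value is returned as a plain string via line.split(': ')[1].strip(), so each
# key must be followed by a colon and a single space.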
183 | 184 | # # find the line with "wandb_notes" and get the value 185 | # wandb_notes = find_params_value(params_file, 'wandb_notes') 186 | # if wandb_notes is None: 187 | # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.') 188 | # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}' 189 | # wandb_notes = wandb_notes + '-eval-retrieval' 190 | wandb_notes = args.wandb_notes 191 | 192 | logging.debug("Starting wandb.") 193 | args.train_sz = data["train"].dataloader.num_samples 194 | if args.val_data is not None: 195 | args.val_sz = data["val"].dataloader.num_samples 196 | # you will have to configure this for your project! 197 | if args.wandb_id is not None: 198 | wandb.init( 199 | project="clap", 200 | id=args.wandb_id, 201 | resume=True 202 | ) 203 | else: 204 | wandb.init( 205 | project="clap", 206 | notes=wandb_notes, 207 | name=wandb_notes, 208 | tags=[], 209 | config=vars(args), 210 | ) 211 | logging.debug("Finished loading wandb.") 212 | 213 | if os.path.isdir(args.pretrained): 214 | all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime) 215 | else: 216 | all_model_checkpoints = [args.pretrained] 217 | for model_path in all_model_checkpoints: 218 | args.checkpoint_path = os.path.dirname(model_path) 219 | model, model_cfg = create_model( 220 | amodel, 221 | tmodel, 222 | pretrained, 223 | precision='fp32', 224 | device=device, 225 | jit=False, 226 | force_quick_gelu=False, 227 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 228 | skip_params=False, 229 | enable_fusion=args.enable_fusion, 230 | fusion_type=args.fusion_type 231 | ) 232 | 233 | # load model 234 | checkpoint = torch.load(model_path, map_location=device) 235 | if "epoch" in checkpoint: 236 | # resuming a train checkpoint w/ epoch and optimizer state 237 | start_epoch = checkpoint["epoch"] 238 | sd = checkpoint["state_dict"] 239 | if next(iter(sd.items()))[0].startswith( 240 | "module" 241 | ): 242 | sd = {k[len("module."):]: v for k, v in sd.items()} 243 | model.load_state_dict(sd) 244 | logging.info( 245 | f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})" 246 | ) 247 | else: 248 | # loading a bare (model only) checkpoint for fine-tune or evaluation 249 | model.load_state_dict(checkpoint) 250 | start_epoch = 0 251 | 252 | model.to(device) 253 | model.eval() 254 | for param in model.parameters(): 255 | param.requires_grad = False 256 | 257 | evaluate_zeroshot(model, data, start_epoch, args, writer) 258 | -------------------------------------------------------------------------------- /src/laion_clap/evaluate/eval_zeroshot_classification.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import glob 3 | import random 4 | import numpy as np 5 | import logging 6 | import wandb 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.backends.cudnn as cudnn 10 | from clap_module import create_model 11 | from clap_module import tokenize 12 | from training.logger import setup_logging 13 | from training.data import get_data 14 | from training.train import evaluate 15 | from clap_module.utils import get_tar_path_from_dataset_name, dataset_split 16 | from training.params import parse_args 17 | 18 | 19 | def find_params_value(file, key): 20 | # find value of params in params_file 21 | with open(file, 'r') as f: 22 | for line in f: 23 | if key + ': ' in line: 24 | return line.split(': ')[1].strip() 25 | return None 26 | 27 | 28 | def 
evaluate_zeroshot(model, data, start_epoch, args, writer): 29 | dataloader = data["val"].dataloader 30 | metrics = {} 31 | device = torch.device(args.device) 32 | model.eval() 33 | metrics.update({"epoch": start_epoch}) 34 | 35 | all_audio_features = [] 36 | all_class_labels = [] 37 | with torch.no_grad(): 38 | for i, batch in enumerate(dataloader): 39 | audios = batch # contains mel_spec, wavform, and longer list 40 | audio_features = model(audios, None, device) 41 | audio_features = F.normalize(audio_features, dim=-1) 42 | all_audio_features.append(audio_features.detach().cpu()) 43 | all_class_labels.append(torch.argmax(batch["class_label"], 1).long()) 44 | all_audio_features = torch.cat(all_audio_features, dim=0) 45 | all_class_labels = torch.cat(all_class_labels, dim=0) 46 | metrics["num_samples"] = all_audio_features.shape[0] 47 | 48 | # get text features 49 | if args.val_dataset_names == ['GTZAN']: 50 | all_texts = [f"This is a {t} song." for t in args.class_index_dict.keys()] 51 | else: 52 | all_texts = [f"This is a sound of {t}." for t in args.class_index_dict.keys()] 53 | logging.info(f'class label prompts: {all_texts}') 54 | # (yusong): a hack, can make it better 55 | if args.tmodel == "transformer": 56 | from clap_module.tokenizer import tokenize 57 | all_texts = tokenize(all_texts) 58 | else: 59 | from training.data import tokenizer 60 | all_texts = tokenizer(all_texts) 61 | all_text_features = model(None, all_texts, device) 62 | all_text_features = F.normalize(all_text_features, dim=-1).detach().cpu() 63 | 64 | # compute similarity 65 | logit_scale_a, logit_scale_t = model(None, None, device) 66 | logit_scale_a = logit_scale_a.cpu() 67 | 68 | logits_per_audio = (logit_scale_a * all_audio_features @ all_text_features.t()).detach().cpu() 69 | logits_per_text = logits_per_audio.t().detach().cpu() 70 | 71 | ground_truth = all_class_labels.view(-1, 1) 72 | logit = logits_per_audio 73 | 74 | ranking = torch.argsort(logit, descending=True) 75 | preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread 76 | preds = preds.detach().cpu().numpy() 77 | metrics[f"{args.datasetnames[0]}_mean_rank"] = preds.mean() + 1 78 | metrics[f"{args.datasetnames[0]}_median_rank"] = np.floor(np.median(preds)) + 1 79 | for k in [1, 5, 10]: 80 | metrics[f"{args.datasetnames[0]}_R@{k}"] = np.mean(preds < k) 81 | # map@10 82 | metrics[f"{args.datasetnames[0]}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) 83 | 84 | logging.info( 85 | f"Eval Epoch: {start_epoch} " 86 | + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 87 | ) 88 | 89 | if args.wandb: 90 | assert wandb is not None, "Please install wandb." 91 | for name, val in metrics.items(): 92 | wandb.log({f"val/{name}": val, "epoch": start_epoch}) 93 | 94 | 95 | if __name__ == '__main__': 96 | # (yusong) repeated run might have different metric results. 97 | # This is because we randomly select crop 10s for each audio. 
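    # A worked illustration (with made-up ranks) of the metrics computed in
    # evaluate_zeroshot above: preds holds the 0-indexed rank of the correct class
    # for each clip, so preds = [0, 2, 11] gives R@1 = 1/3, R@5 = R@10 = 2/3,
    # mAP@10 = (1/1 + 1/3 + 0) / 3 ≈ 0.44, mean_rank = mean(preds) + 1 ≈ 5.33,
    # and median_rank = floor(median(preds)) + 1 = 3.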
98 | args = parse_args() 99 | 100 | if os.path.isdir(args.pretrained): 101 | log_dir = os.path.dirname(args.pretrained) 102 | else: 103 | log_dir = os.path.dirname(os.path.dirname(args.pretrained)) 104 | 105 | args.log_level = logging.DEBUG if args.debug else logging.INFO 106 | log_path = os.path.join(log_dir, 'out.log') 107 | setup_logging(log_path, args.log_level) 108 | params_file = os.path.join(log_dir, 'params.txt') 109 | 110 | seed = 3407 111 | random.seed(seed) 112 | torch.manual_seed(seed) 113 | torch.cuda.manual_seed(seed) 114 | torch.cuda.manual_seed_all(seed) 115 | np.random.seed(seed) 116 | 117 | cudnn.benchmark = True 118 | cudnn.deterministic = False 119 | pretrained = 'openai' 120 | amodel = find_params_value(params_file, 'amodel') 121 | tmodel = find_params_value(params_file, 'tmodel') 122 | 123 | if amodel is None or tmodel is None: 124 | raise ValueError('model type not found in params file') 125 | 126 | # set up dummy values for args 127 | args.parallel_eval = False 128 | args.rank = 0 129 | args.local_rank = 0 130 | args.world_size = 1 131 | args.val_frequency = 1 132 | args.epochs = 1 133 | args.precision = 'fp32' 134 | args.save_logs = True 135 | args.wandb = args.report_to == 'wandb' 136 | args.class_index_dict = None 137 | 138 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 139 | args.device = device 140 | 141 | if args.remotedata: 142 | for dataset_name in args.datasetnames: 143 | for split in dataset_split[dataset_name]: 144 | if not os.path.exists(f"./json_files/{dataset_name}/{split}"): 145 | os.makedirs(f"./json_files/{dataset_name}/{split}") 146 | os.system( 147 | f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" 148 | ) 149 | 150 | if args.datasetinfos is None: 151 | args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] 152 | if args.dataset_type == "webdataset": 153 | args.train_data = get_tar_path_from_dataset_name( 154 | args.datasetnames, 155 | args.datasetinfos, 156 | islocal=not args.remotedata, 157 | proportion=args.dataset_proportion, 158 | dataset_path=args.datasetpath, 159 | ) 160 | args.val_data = get_tar_path_from_dataset_name( 161 | args.datasetnames, 162 | ["valid", "test", "eval"], 163 | islocal=not args.remotedata, 164 | proportion=1, 165 | dataset_path=args.datasetpath, 166 | ) 167 | model, model_cfg = create_model( 168 | amodel, 169 | tmodel, 170 | pretrained, 171 | precision='fp32', 172 | device=device, 173 | jit=False, 174 | force_quick_gelu=False, 175 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 176 | skip_params=False, 177 | enable_fusion=args.enable_fusion, 178 | fusion_type=args.fusion_type 179 | ) # a hack to get model_cfg 180 | 181 | data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data 182 | 183 | writer = None # if use tensorboard, initalize writer here 184 | 185 | if args.wandb: 186 | assert wandb is not None, "Please install wandb." 
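        # Note on the wandb setup below: if args.wandb_id is set, the existing run
        # is resumed in place; otherwise a new run is created in the "clap" project,
        # named and annotated with args.wandb_notes.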
187 | 188 | # # find the line with "wandb_notes" and get the value 189 | # wandb_notes = find_params_value(params_file, 'wandb_notes') 190 | # if wandb_notes is None: 191 | # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.') 192 | # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}' 193 | # wandb_notes = wandb_notes + '-eval-retrieval' 194 | wandb_notes = args.wandb_notes 195 | 196 | logging.debug("Starting wandb.") 197 | args.train_sz = data["train"].dataloader.num_samples 198 | if args.val_data is not None: 199 | args.val_sz = data["val"].dataloader.num_samples 200 | # you will have to configure this for your project! 201 | if args.wandb_id is not None: 202 | wandb.init( 203 | project="clap", 204 | id=args.wandb_id, 205 | resume=True 206 | ) 207 | else: 208 | wandb.init( 209 | project="clap", 210 | notes=wandb_notes, 211 | name=wandb_notes, 212 | tags=[], 213 | config=vars(args), 214 | ) 215 | logging.debug("Finished loading wandb.") 216 | 217 | if os.path.isdir(args.pretrained): 218 | all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime) 219 | else: 220 | all_model_checkpoints = [args.pretrained] 221 | for model_path in all_model_checkpoints: 222 | args.checkpoint_path = os.path.dirname(model_path) 223 | model, model_cfg = create_model( 224 | amodel, 225 | tmodel, 226 | pretrained, 227 | precision='fp32', 228 | device=device, 229 | jit=False, 230 | force_quick_gelu=False, 231 | openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), 232 | skip_params=False, 233 | enable_fusion=args.enable_fusion, 234 | fusion_type=args.fusion_type 235 | ) 236 | 237 | # load model 238 | checkpoint = torch.load(model_path, map_location=device) 239 | if "epoch" in checkpoint: 240 | # resuming a train checkpoint w/ epoch and optimizer state 241 | start_epoch = checkpoint["epoch"] 242 | sd = checkpoint["state_dict"] 243 | if next(iter(sd.items()))[0].startswith( 244 | "module" 245 | ): 246 | sd = {k[len("module."):]: v for k, v in sd.items()} 247 | model.load_state_dict(sd) 248 | logging.info( 249 | f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})" 250 | ) 251 | else: 252 | # loading a bare (model only) checkpoint for fine-tune or evaluation 253 | model.load_state_dict(checkpoint) 254 | start_epoch = 0 255 | 256 | model.to(device) 257 | model.eval() 258 | for param in model.parameters(): 259 | param.requires_grad = False 260 | 261 | evaluate_zeroshot(model, data, start_epoch, args, writer) 262 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from .model import CLAP, convert_weights_to_fp16 12 | from .openai import load_openai_model 13 | from .pretrained import get_pretrained_url, download_pretrained 14 | from .transform import image_transform 15 | 16 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 17 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 18 | 19 | 20 | def _natural_key(string_): 21 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 22 | 23 | 24 | def _rescan_model_configs(): 25 | global _MODEL_CONFIGS 26 | 27 | config_ext = (".json",) 28 | config_files = [] 29 | 
for config_path in _MODEL_CONFIG_PATHS: 30 | if config_path.is_file() and config_path.suffix in config_ext: 31 | config_files.append(config_path) 32 | elif config_path.is_dir(): 33 | for ext in config_ext: 34 | config_files.extend(config_path.glob(f"*{ext}")) 35 | 36 | for cf in config_files: 37 | with open(cf, "r") as f: 38 | model_cfg = json.load(f) 39 | if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): 40 | _MODEL_CONFIGS[cf.stem] = model_cfg 41 | 42 | _MODEL_CONFIGS = { 43 | k: v 44 | for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) 45 | } 46 | 47 | 48 | _rescan_model_configs() # initial populate of model config registry 49 | 50 | 51 | def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): 52 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 53 | if isinstance(checkpoint, dict) and "state_dict" in checkpoint: 54 | state_dict = checkpoint["state_dict"] 55 | else: 56 | state_dict = checkpoint 57 | if skip_params: 58 | if next(iter(state_dict.items()))[0].startswith("module"): 59 | state_dict = {k[7:]: v for k, v in state_dict.items()} 60 | # for k in state_dict: 61 | # if k.startswith('transformer'): 62 | # v = state_dict.pop(k) 63 | # state_dict['text_branch.' + k[12:]] = v 64 | return state_dict 65 | 66 | 67 | def create_model( 68 | amodel_name: str, 69 | tmodel_name: str, 70 | pretrained: str = "", 71 | precision: str = "fp32", 72 | device: torch.device = torch.device("cpu"), 73 | jit: bool = False, 74 | force_quick_gelu: bool = False, 75 | openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), 76 | skip_params=True, 77 | pretrained_audio: str = "", 78 | pretrained_text: str = "", 79 | enable_fusion: bool = False, 80 | fusion_type: str = 'None' 81 | # pretrained_image: bool = False, 82 | ): 83 | amodel_name = amodel_name.replace( 84 | "/", "-" 85 | ) # for callers using old naming with / in ViT names 86 | pretrained_orig = pretrained 87 | pretrained = pretrained.lower() 88 | if pretrained == "openai": 89 | if amodel_name in _MODEL_CONFIGS: 90 | logging.info(f"Loading {amodel_name} model config.") 91 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 92 | else: 93 | logging.error( 94 | f"Model config for {amodel_name} not found; available models {list_models()}." 95 | ) 96 | raise RuntimeError(f"Model config for {amodel_name} not found.") 97 | 98 | logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") 99 | # Hard Code in model name 100 | model_cfg["text_cfg"]["model_type"] = tmodel_name 101 | model = load_openai_model( 102 | "ViT-B-16", 103 | model_cfg, 104 | device=device, 105 | jit=jit, 106 | cache_dir=openai_model_cache_dir, 107 | enable_fusion=enable_fusion, 108 | fusion_type=fusion_type 109 | ) 110 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 111 | if precision == "amp" or precision == "fp32": 112 | model = model.float() 113 | else: 114 | if amodel_name in _MODEL_CONFIGS: 115 | logging.info(f"Loading {amodel_name} model config.") 116 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 117 | else: 118 | logging.error( 119 | f"Model config for {amodel_name} not found; available models {list_models()}." 
120 | ) 121 | raise RuntimeError(f"Model config for {amodel_name} not found.") 122 | 123 | if force_quick_gelu: 124 | # override for use of QuickGELU on non-OpenAI transformer models 125 | model_cfg["quick_gelu"] = True 126 | 127 | # if pretrained_image: 128 | # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): 129 | # # pretrained weight loading for timm models set via vision_cfg 130 | # model_cfg['vision_cfg']['timm_model_pretrained'] = True 131 | # else: 132 | # assert False, 'pretrained image towers currently only supported for timm models' 133 | model_cfg["text_cfg"]["model_type"] = tmodel_name 134 | model_cfg["enable_fusion"] = enable_fusion 135 | model_cfg["fusion_type"] = fusion_type 136 | model = CLAP(**model_cfg) 137 | 138 | if pretrained: 139 | checkpoint_path = "" 140 | url = get_pretrained_url(amodel_name, pretrained) 141 | if url: 142 | checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) 143 | elif os.path.exists(pretrained_orig): 144 | checkpoint_path = pretrained_orig 145 | if checkpoint_path: 146 | logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).") 147 | ckpt = load_state_dict(checkpoint_path, skip_params=True) 148 | model.load_state_dict(ckpt) 149 | param_names = [n for n, p in model.named_parameters()] 150 | for n in param_names: 151 | print(n, "\t", "Loaded" if n in ckpt else "Unloaded") 152 | else: 153 | logging.warning( 154 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 155 | ) 156 | raise RuntimeError( 157 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 158 | ) 159 | 160 | if pretrained_audio: 161 | if amodel_name.startswith('PANN'): 162 | if 'Cnn14_mAP' in pretrained_audio: # official checkpoint 163 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 164 | audio_ckpt = audio_ckpt['model'] 165 | keys = list(audio_ckpt.keys()) 166 | for key in keys: 167 | if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key: 168 | v = audio_ckpt.pop(key) 169 | audio_ckpt['audio_branch.' + key] = v 170 | elif os.path.basename(pretrained_audio).startswith('PANN'): # checkpoint trained via HTSAT codebase 171 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 172 | audio_ckpt = audio_ckpt['state_dict'] 173 | keys = list(audio_ckpt.keys()) 174 | for key in keys: 175 | if key.startswith('sed_model'): 176 | v = audio_ckpt.pop(key) 177 | audio_ckpt['audio_branch.' + key[10:]] = v 178 | elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase 179 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 180 | else: 181 | raise ValueError('Unknown audio checkpoint') 182 | elif amodel_name.startswith('HTSAT'): 183 | if 'HTSAT_AudioSet_Saved' in pretrained_audio: # official checkpoint 184 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 185 | audio_ckpt = audio_ckpt['state_dict'] 186 | keys = list(audio_ckpt.keys()) 187 | for key in keys: 188 | if key.startswith('sed_model') and ('spectrogram_extractor' not in key 189 | and 'logmel_extractor' not in key): 190 | v = audio_ckpt.pop(key) 191 | audio_ckpt['audio_branch.' 
+ key[10:]] = v 192 | elif os.path.basename(pretrained_audio).startswith('HTSAT'): # checkpoint trained via HTSAT codebase 193 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 194 | audio_ckpt = audio_ckpt['state_dict'] 195 | keys = list(audio_ckpt.keys()) 196 | for key in keys: 197 | if key.startswith('sed_model'): 198 | v = audio_ckpt.pop(key) 199 | audio_ckpt['audio_branch.' + key[10:]] = v 200 | elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase 201 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 202 | else: 203 | raise ValueError('Unknown audio checkpoint') 204 | else: 205 | raise f'this audio encoder pretrained checkpoint is not support' 206 | 207 | model.load_state_dict(audio_ckpt, strict=False) 208 | logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).") 209 | param_names = [n for n, p in model.named_parameters()] 210 | for n in param_names: 211 | print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") 212 | 213 | model.to(device=device) 214 | if precision == "fp16": 215 | assert device.type != "cpu" 216 | convert_weights_to_fp16(model) 217 | 218 | if jit: 219 | model = torch.jit.script(model) 220 | 221 | return model, model_cfg 222 | 223 | 224 | def create_model_and_transforms( 225 | model_name: str, 226 | pretrained: str = "", 227 | precision: str = "fp32", 228 | device: torch.device = torch.device("cpu"), 229 | jit: bool = False, 230 | force_quick_gelu: bool = False, 231 | # pretrained_image: bool = False, 232 | ): 233 | model = create_model( 234 | model_name, 235 | pretrained, 236 | precision, 237 | device, 238 | jit, 239 | force_quick_gelu=force_quick_gelu, 240 | # pretrained_image=pretrained_image 241 | ) 242 | preprocess_train = image_transform(model.visual.image_size, is_train=True) 243 | preprocess_val = image_transform(model.visual.image_size, is_train=False) 244 | return model, preprocess_train, preprocess_val 245 | 246 | 247 | def list_models(): 248 | """enumerate available model architectures based on config files""" 249 | return list(_MODEL_CONFIGS.keys()) 250 | 251 | 252 | def add_model_config(path): 253 | """add model config path or file and update registry""" 254 | if not isinstance(path, Path): 255 | path = Path(path) 256 | _MODEL_CONFIG_PATHS.append(path) 257 | _rescan_model_configs() 258 | -------------------------------------------------------------------------------- /src/laion_clap/training/lp_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import time 6 | from contextlib import suppress 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | try: 13 | import wandb 14 | except ImportError: 15 | wandb = None 16 | 17 | from clap_module import LPLoss, LPMetrics, lp_gather_features 18 | from clap_module.utils import do_mixup, get_mix_lambda 19 | from .distributed import is_master 20 | from .zero_shot import zero_shot_eval 21 | 22 | 23 | class AverageMeter(object): 24 | """Computes and stores the average and current value""" 25 | 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | 42 | def unwrap_model(model): 43 | if hasattr(model, "module"): 44 | return 
model.module 45 | else: 46 | return model 47 | 48 | 49 | def train_one_epoch( 50 | model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None, extra_suffix="" 51 | ): 52 | device = torch.device(args.device) 53 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 54 | model.train() 55 | loss = LPLoss(args.lp_loss) 56 | 57 | dataloader, sampler = data["train"].dataloader, data["train"].sampler 58 | if args.distributed and sampler is not None: 59 | sampler.set_epoch(epoch) 60 | num_batches_per_epoch = dataloader.num_batches 61 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 62 | 63 | # for toy dataset 64 | if args.dataset_type == "toy": 65 | dataloader.dataset.generate_queue() 66 | 67 | loss_m = AverageMeter() 68 | batch_time_m = AverageMeter() 69 | data_time_m = AverageMeter() 70 | end = time.time() 71 | 72 | for i, batch in enumerate(dataloader): 73 | step = num_batches_per_epoch * epoch + i 74 | 75 | if isinstance(scheduler, dict): 76 | for s in scheduler.values(): 77 | s(step) 78 | else: 79 | scheduler(step) 80 | 81 | audio = batch # contains mel_spec, wavform, and longer list 82 | class_label = batch['class_label'] 83 | # audio = audio.to(device=device, non_blocking=True) 84 | class_label = class_label.to(device=device, non_blocking=True) 85 | 86 | if args.mixup: 87 | # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 88 | mix_lambda = torch.from_numpy(get_mix_lambda(0.5, len(audio["waveform"]))).to(device) 89 | class_label = do_mixup(class_label, mix_lambda) 90 | else: 91 | mix_lambda = None 92 | 93 | data_time_m.update(time.time() - end) 94 | if isinstance(optimizer, dict): 95 | for o_ in optimizer.values(): 96 | o_.zero_grad() 97 | else: 98 | optimizer.zero_grad() 99 | 100 | with autocast(): 101 | pred = model(audio, mix_lambda=mix_lambda, device=device) 102 | total_loss = loss(pred, class_label) 103 | 104 | if isinstance(optimizer, dict): 105 | if scaler is not None: 106 | scaler.scale(total_loss).backward() 107 | for o_ in optimizer.values(): 108 | if args.horovod: 109 | o_.synchronize() 110 | scaler.unscale_(o_) 111 | with o_.skip_synchronize(): 112 | scaler.step(o_) 113 | else: 114 | scaler.step(o_) 115 | scaler.update() 116 | else: 117 | total_loss.backward() 118 | for o_ in optimizer.values(): 119 | o_.step() 120 | else: 121 | if scaler is not None: 122 | scaler.scale(total_loss).backward() 123 | if args.horovod: 124 | optimizer.synchronize() 125 | scaler.unscale_(optimizer) 126 | with optimizer.skip_synchronize(): 127 | scaler.step(optimizer) 128 | else: 129 | scaler.step(optimizer) 130 | scaler.update() 131 | else: 132 | total_loss.backward() 133 | optimizer.step() 134 | 135 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
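        # The logit scales are stored in log space, so exp(logit_scale) is the
        # temperature applied to the audio-text similarities; clamping the log
        # value to [0, ln(100)] keeps that multiplier within [1, 100].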
136 | with torch.no_grad(): 137 | unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) 138 | unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) 139 | 140 | batch_time_m.update(time.time() - end) 141 | end = time.time() 142 | batch_count = i + 1 143 | 144 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): 145 | if isinstance(audio, dict): 146 | batch_size = len(audio["waveform"]) 147 | else: 148 | batch_size = len(audio) 149 | num_samples = batch_count * batch_size * args.world_size 150 | samples_per_epoch = dataloader.num_samples 151 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 152 | 153 | # NOTE loss is coarsely sampled, just master node and per log update 154 | loss_m.update(total_loss.item(), batch_size) 155 | if isinstance(optimizer, dict): 156 | logging.info( 157 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 158 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " 159 | f"Data (t): {data_time_m.avg:.3f} " 160 | f"Batch (t): {batch_time_m.avg:.3f} " 161 | f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" 162 | ) 163 | log_data = { 164 | "loss": loss_m.val, 165 | "data_time": data_time_m.val, 166 | "batch_time": batch_time_m.val, 167 | "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], 168 | } 169 | else: 170 | logging.info( 171 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 172 | f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " 173 | f"Data (t): {data_time_m.avg:.3f} " 174 | f"Batch (t): {batch_time_m.avg:.3f} " 175 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 176 | ) 177 | 178 | # Save train loss / etc. Using non avg meter values as loggers have their own smoothing 179 | log_data = { 180 | "loss": loss_m.val, 181 | "data_time": data_time_m.val, 182 | "batch_time": batch_time_m.val, 183 | "lr": optimizer.param_groups[0]["lr"], 184 | } 185 | for name, val in log_data.items(): 186 | name = f"train{extra_suffix}/{name}" 187 | if tb_writer is not None: 188 | tb_writer.add_scalar(name, val, step) 189 | if args.wandb: 190 | assert wandb is not None, "Please install wandb." 
191 | wandb.log({name: val, "step": step}) 192 | 193 | # resetting batch / data time meters per log window 194 | batch_time_m.reset() 195 | data_time_m.reset() 196 | # end for 197 | 198 | def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): 199 | metrics = {} 200 | if not args.parallel_eval: 201 | if not is_master(args): 202 | return metrics 203 | device = torch.device(args.device) 204 | model.eval() 205 | 206 | # CHANGE 207 | # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) 208 | # metrics.update(zero_shot_metrics) 209 | if is_master(args): 210 | print('Evaluating...') 211 | metric_names = args.lp_metrics.split(',') 212 | eval_tool = LPMetrics(metric_names=metric_names) 213 | 214 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 215 | if "val" in data and ( 216 | args.val_frequency 217 | and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) 218 | ): 219 | if args.parallel_eval: 220 | dataloader, sampler = data["val"].dataloader, data["val"].sampler 221 | if args.distributed and sampler is not None: 222 | sampler.set_epoch(epoch) 223 | samples_per_val = dataloader.num_samples 224 | else: 225 | dataloader = data["val"].dataloader 226 | num_samples = 0 227 | samples_per_val = dataloader.num_samples 228 | 229 | eval_info = { 230 | 'pred': [], 231 | 'target': [] 232 | } 233 | with torch.no_grad(): 234 | for i, batch in enumerate(dataloader): 235 | audio = batch # contains mel_spec, wavform, and longer list 236 | class_label = batch['class_label'] 237 | 238 | # audio = audio.to(device=device, non_blocking=True) 239 | class_label = class_label.to(device=device, non_blocking=True) 240 | 241 | with autocast(): 242 | pred = model(audio, device=device) 243 | if args.parallel_eval: 244 | pred, class_label = lp_gather_features(pred, class_label, args.world_size, args.horovod) 245 | eval_info['pred'].append(pred) 246 | eval_info['target'].append(class_label) 247 | 248 | num_samples += class_label.shape[0] 249 | 250 | if (i % 100) == 0: # and i != 0: 251 | logging.info( 252 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" 253 | ) 254 | 255 | if is_master(args): 256 | eval_info['pred'] = torch.cat(eval_info['pred'], 0).cpu() 257 | eval_info['target'] = torch.cat(eval_info['target'], 0).cpu() 258 | metric_dict = eval_tool.evaluate_mertics(eval_info['pred'], eval_info['target']) 259 | metrics.update(metric_dict) 260 | if "epoch" not in metrics.keys(): 261 | metrics.update({"epoch": epoch}) 262 | 263 | if is_master(args): 264 | if not metrics: 265 | return metrics 266 | 267 | logging.info( 268 | f"Eval Epoch: {epoch} " 269 | + "\n".join( 270 | [ 271 | "\t".join([f"{m}: {round(metrics[m], 4):.4f}" ]) 272 | for m in metrics 273 | ] 274 | ) 275 | ) 276 | if args.save_logs: 277 | for name, val in metrics.items(): 278 | if tb_writer is not None: 279 | tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) 280 | 281 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: 282 | f.write(json.dumps(metrics)) 283 | f.write("\n") 284 | 285 | if args.wandb: 286 | assert wandb is not None, "Please install wandb." 
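            # Mirror the tensorboard scalars and the results.jsonl entries above by
            # logging each validation metric to wandb under the val{extra_suffix}/
            # prefix, keyed to the evaluation epoch.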
287 | for name, val in metrics.items(): 288 | wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) 289 | 290 | return metrics 291 | else: 292 | return metrics 293 | -------------------------------------------------------------------------------- /src/laion_clap/clap_module/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from torchvision.ops.misc import FrozenBatchNorm2d 5 | import logging 6 | import h5py 7 | from tqdm import tqdm 8 | import random 9 | import json 10 | import os 11 | import pathlib 12 | 13 | # TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. 14 | dataset_split = { 15 | "audiocaps": ["train", "valid", "test"], 16 | "audioset": ["balanced_train", "unbalanced_train", "eval"], 17 | "BBCSoundEffects": ["train", "test"], 18 | "Clotho": ["train", "test", "valid"], 19 | "free_to_use_sounds": ["train", "test"], 20 | "paramount_motion": ["train", "test"], 21 | "sonniss_game_effects": ["train", "test"], 22 | "wesoundeffects": ["train", "test"], 23 | "MACS": ["train", "test"], 24 | "freesound": ["train", "test"], 25 | "FSD50K": ["train", "test", "valid"], 26 | "fsd50k_class_label": ["train", "test", "valid"], 27 | "esc50": ["train", "test"], 28 | "ESC50_1": ["train", "test"], 29 | "ESC50_2": ["train", "test"], 30 | "ESC50_3": ["train", "test"], 31 | "ESC50_4": ["train", "test"], 32 | "ESC50_5": ["train", "test"], 33 | "audiostock": ["train", "test"], 34 | "freesound_no_overlap_noesc50": ["train", "test"], 35 | "epidemic_sound_effects": ["train", "test"], 36 | "VGGSound": ["train", "test"], 37 | "urbansound8k_class_label": ["train", "test"], 38 | "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], 39 | "audioset_t5_debiased": ["balanced_train", "unbalanced_train", "eval"], 40 | "epidemic_sound_effects_t5": ["train", "test"], 41 | "epidemic_sound_effects_t5_debiased": ["train", "test"], 42 | "WavText5K": ["train", "test"], 43 | "esc50_no_overlap": ["train", "test"], 44 | "usd8k_no_overlap": ["train", "test"], 45 | "fsd50k_200_class_label": ["train", "test", "valid"], 46 | "fma_full": ["train", "test"], 47 | "Genius": ["train", "test"], 48 | "Jamendo": ["train", "test"], 49 | "juno": ["train", "test"], 50 | "CMU_Arctic": ["train", "test"], 51 | "ravdess": ["train", "test"], 52 | "Europarl-st": ["train", "test"], 53 | "common_voice": ["train", "test"], 54 | "Jamendo_16bit": ["train", "test"], 55 | "genius_16bit_128": ["train", "test"], 56 | "juno_16bit": ["train", "test"], 57 | "fma_full_16bit_128": ["train", "test"], 58 | "GTZAN": ["train", "test"], 59 | } 60 | 61 | 62 | def freeze_batch_norm_2d(module, module_match={}, name=""): 63 | """ 64 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 65 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 66 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 67 | 68 | Args: 69 | module (torch.nn.Module): Any PyTorch module. 
70 | module_match (dict): Dictionary of full module names to freeze (all if empty) 71 | name (str): Full module name (prefix) 72 | 73 | Returns: 74 | torch.nn.Module: Resulting module 75 | 76 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 77 | """ 78 | res = module 79 | is_match = True 80 | if module_match: 81 | is_match = name in module_match 82 | if is_match and isinstance( 83 | module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) 84 | ): 85 | res = FrozenBatchNorm2d(module.num_features) 86 | res.num_features = module.num_features 87 | res.affine = module.affine 88 | if module.affine: 89 | res.weight.data = module.weight.data.clone().detach() 90 | res.bias.data = module.bias.data.clone().detach() 91 | res.running_mean.data = module.running_mean.data 92 | res.running_var.data = module.running_var.data 93 | res.eps = module.eps 94 | else: 95 | for child_name, child in module.named_children(): 96 | full_child_name = ".".join([name, child_name]) if name else child_name 97 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 98 | if new_child is not child: 99 | res.add_module(child_name, new_child) 100 | return res 101 | 102 | 103 | def exist(dataset_name, dataset_type): 104 | """ 105 | Check if dataset exists 106 | """ 107 | if dataset_type in dataset_split[dataset_name]: 108 | return True 109 | else: 110 | return False 111 | 112 | 113 | def get_tar_path_from_dataset_name( 114 | dataset_names, 115 | dataset_types, 116 | islocal, 117 | dataset_path, 118 | proportion=1, 119 | full_dataset=None 120 | ): 121 | """ 122 | Get tar path from dataset name and type 123 | """ 124 | output = [] 125 | for n in dataset_names: 126 | if full_dataset is not None and n in full_dataset: 127 | current_dataset_types = dataset_split[n] 128 | else: 129 | current_dataset_types = dataset_types 130 | for s in current_dataset_types: 131 | tmp = [] 132 | if islocal: 133 | sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" 134 | if not os.path.exists(sizefilepath_): 135 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 136 | else: 137 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 138 | if not os.path.exists(sizefilepath_): 139 | continue 140 | sizes = json.load(open(sizefilepath_, "r")) 141 | for k in sizes.keys(): 142 | if islocal: 143 | tmp.append(f"{dataset_path}/{n}/{s}/{k}") 144 | else: 145 | tmp.append( 146 | f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" 147 | ) 148 | if proportion != 1: 149 | tmp = random.sample(tmp, int(proportion * len(tmp))) 150 | output.append(tmp) 151 | return sum(output, []) 152 | 153 | 154 | def get_tar_path_from_txts(txt_path, islocal, proportion=1): 155 | """ 156 | Get tar path from txt path 157 | """ 158 | if isinstance(txt_path, (list, tuple)): 159 | return sum( 160 | [ 161 | get_tar_path_from_txts( 162 | txt_path[i], islocal=islocal, proportion=proportion 163 | ) 164 | for i in range(len(txt_path)) 165 | ], 166 | [], 167 | ) 168 | if isinstance(txt_path, str): 169 | with open(txt_path) as f: 170 | lines = f.readlines() 171 | if islocal: 172 | lines = [ 173 | lines[i] 174 | .split("\n")[0] 175 | .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") 176 | for i in range(len(lines)) 177 | ] 178 | else: 179 | lines = [ 180 | lines[i].split("\n")[0].replace(".tar", ".tar -") 181 | for i in range(len(lines)) 182 | ] 183 | if proportion != 1: 184 | print("Sampling tars with proportion of 
{}".format(proportion)) 185 | lines = random.sample(lines, int(proportion * len(lines))) 186 | return lines 187 | 188 | 189 | def get_mix_lambda(mixup_alpha, batch_size): 190 | mixup_lambdas = [ 191 | np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) 192 | ] 193 | return np.array(mixup_lambdas).astype(np.float32) 194 | 195 | 196 | def do_mixup(x, mixup_lambda): 197 | """ 198 | Args: 199 | x: (batch_size , ...) 200 | mixup_lambda: (batch_size,) 201 | Returns: 202 | out: (batch_size, ...) 203 | """ 204 | out = ( 205 | x.transpose(0, -1) * mixup_lambda 206 | + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) 207 | ).transpose(0, -1) 208 | return out 209 | 210 | 211 | def interpolate(x, ratio): 212 | """Interpolate data in time domain. This is used to compensate the 213 | resolution reduction in downsampling of a CNN. 214 | 215 | Args: 216 | x: (batch_size, time_steps, classes_num) 217 | ratio: int, ratio to interpolate 218 | Returns: 219 | upsampled: (batch_size, time_steps * ratio, classes_num) 220 | """ 221 | (batch_size, time_steps, classes_num) = x.shape 222 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 223 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 224 | return upsampled 225 | 226 | 227 | def pad_framewise_output(framewise_output, frames_num): 228 | """Pad framewise_output to the same length as input frames. The pad value 229 | is the same as the value of the last frame. 230 | Args: 231 | framewise_output: (batch_size, frames_num, classes_num) 232 | frames_num: int, number of frames to pad 233 | Outputs: 234 | output: (batch_size, frames_num, classes_num) 235 | """ 236 | pad = framewise_output[:, -1:, :].repeat( 237 | 1, frames_num - framewise_output.shape[1], 1 238 | ) 239 | """tensor for padding""" 240 | 241 | output = torch.cat((framewise_output, pad), dim=1) 242 | """(batch_size, frames_num, classes_num)""" 243 | 244 | 245 | def process_ipc(index_path, classes_num, filename): 246 | # load data 247 | logging.info("Load Data...............") 248 | ipc = [[] for _ in range(classes_num)] 249 | with h5py.File(index_path, "r") as f: 250 | for i in tqdm(range(len(f["target"]))): 251 | t_class = np.where(f["target"][i])[0] 252 | for t in t_class: 253 | ipc[t].append(i) 254 | print(ipc) 255 | np.save(filename, ipc) 256 | logging.info("Load Data Succeed...............") 257 | 258 | 259 | def save_to_dict(s, o_={}): 260 | sp = s.split(": ") 261 | o_.update({sp[0]: float(sp[1])}) 262 | return o_ 263 | 264 | 265 | def get_data_from_log(txt_path): 266 | """ 267 | Output dictionary from out.txt log file 268 | """ 269 | with open(txt_path) as f: 270 | lines = f.readlines() 271 | val_data = {} 272 | train_data = {} 273 | train_losses = [] 274 | train_losses_epoch = [] 275 | for i in range(len(lines)): 276 | if "| INFO |" in lines[i]: 277 | if "Eval Epoch" in lines[i]: 278 | if "val_loss" in lines[i]: 279 | # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) 280 | line = lines[i].split("Eval Epoch: ")[-1] 281 | num_epoch = int(line.split(" ")[0].split(" ")[0]) 282 | d = { 283 | line.split(" ")[0] 284 | .split(" ")[1] 285 | .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) 286 | } 287 | for i in range(1, len(line.split(" "))): 288 | d = save_to_dict(line.split(" ")[i], d) 289 | val_data[num_epoch] = d 290 | elif "Train Epoch" in lines[i]: 291 | num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) 292 | loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) 293 | train_losses.append(loss) 294 | 
train_losses_epoch.append(num_epoch) 295 | for i in range(len(train_losses)): 296 | train_data[i] = { 297 | "num_epoch": train_losses_epoch[i], 298 | "train_loss": train_losses[i], 299 | } 300 | return train_data, val_data 301 | 302 | 303 | def save_p(obj, filename): 304 | import pickle 305 | 306 | try: 307 | from deepdiff import DeepDiff 308 | except: 309 | os.system("pip install deepdiff") 310 | from deepdiff import DeepDiff 311 | with open(filename, "wb") as file: 312 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol 313 | with open(filename, "rb") as file: 314 | z = pickle.load(file) 315 | assert ( 316 | DeepDiff(obj, z, ignore_string_case=True) == {} 317 | ), "there is something wrong with the saving process" 318 | return 319 | 320 | 321 | def load_p(filename): 322 | import pickle 323 | 324 | with open(filename, "rb") as file: 325 | z = pickle.load(file) 326 | return z 327 | 328 | 329 | def save_json(data, name="data.json"): 330 | import json 331 | with open(name, 'w') as fp: 332 | json.dump(data, fp) 333 | return 334 | 335 | 336 | def load_json(name): 337 | import json 338 | with open(name, 'r') as fp: 339 | data = json.load(fp) 340 | return data 341 | 342 | 343 | from multiprocessing import Process, Manager 344 | from multiprocessing import Process, Value, Array 345 | from ctypes import c_wchar 346 | 347 | 348 | def load_class_label(path): 349 | # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing 350 | # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array 351 | out = None 352 | if path is not None: 353 | if pathlib.Path(path).suffix in [".pkl", ".pickle"]: 354 | out = load_p(path) 355 | elif pathlib.Path(path).suffix in [".json", ".txt"]: 356 | out = load_json(path) 357 | elif pathlib.Path(path).suffix in [".npy", ".npz"]: 358 | out = np.load(path) 359 | elif pathlib.Path(path).suffix in [".csv"]: 360 | import pandas as pd 361 | out = pd.read_csv(path) 362 | return out 363 | # if out is None: 364 | # return None 365 | # else: 366 | # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) 367 | # val = Array('i', out.values(), lock=False) 368 | # return (key, val) 369 | 370 | 371 | from torch import optim 372 | 373 | 374 | def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): 375 | if optimizer_name.lower() == "adamw": 376 | optimizer = optim.AdamW( 377 | params, lr=lr, betas=betas, eps=eps 378 | ) 379 | elif optimizer_name.lower() == "sgd": 380 | optimizer = optim.SGD( 381 | params, lr=lr, momentum=momentum 382 | ) 383 | elif optimizer_name.lower() == "adam": 384 | optimizer = optim.Adam( 385 | params, lr=lr, betas=betas, eps=eps 386 | ) 387 | else: 388 | raise ValueError("optimizer name is not correct") 389 | return optimizer 390 | --------------------------------------------------------------------------------
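A minimal usage sketch of the mixup helpers in clap_module/utils.py, mirroring how train_one_epoch in training/lp_train.py combines them; it assumes the package is importable as laion_clap and uses an illustrative batch of one-hot targets:

import torch
from laion_clap.clap_module.utils import get_mix_lambda, do_mixup  # import path assumed

class_label = torch.eye(4)  # stand-in batch of four one-hot class-label rows
# one Beta(0.5, 0.5) mixing coefficient per sample, matching the hard-coded alpha in lp_train.py
mix_lambda = torch.from_numpy(get_mix_lambda(0.5, class_label.shape[0]))
# each row is blended with the corresponding row of the batch in reverse order
mixed = do_mixup(class_label, mix_lambda)
print(mixed.shape)  # torch.Size([4, 4])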