├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── LICENSE ├── simple_tokenizer.py ├── README.md ├── model-card.md └── clip.py ├── configs ├── model │ ├── loss │ │ ├── ce.yaml │ │ ├── ce_lm.yaml │ │ ├── ce_cls.yaml │ │ ├── barlow.yaml │ │ ├── ce_val.yaml │ │ ├── bce.yaml │ │ ├── ce_va.yaml │ │ ├── barlow_ce.yaml │ │ └── imagine_and_classify.yaml │ ├── text │ │ ├── dummy.yaml │ │ ├── transformer.yaml │ │ ├── transformer_decoder.yaml │ │ └── transformer_val.yaml │ ├── image │ │ ├── vit.yaml │ │ ├── rn50_val.yaml │ │ └── vit_val.yaml │ └── audio │ │ ├── vit.yaml │ │ ├── deit.yaml │ │ ├── rn50_val.yaml │ │ └── vit_val.yaml ├── default.yaml ├── running │ ├── audio │ │ └── default.yaml │ ├── us8k.yaml │ ├── esc50.yaml │ ├── train.yaml │ ├── clotho.yaml │ ├── bimodal.yaml │ ├── siamese.yaml │ ├── audioset.yaml │ └── trimodal.yaml └── optimizer │ └── standard.yaml ├── cvap ├── data │ ├── audio │ │ ├── __init__.py │ │ └── transform.py │ ├── image │ │ ├── __init__.py │ │ └── transform.py │ ├── __init__.py │ ├── audioset_hub.py │ ├── image_text.py │ ├── audioset_clf.py │ ├── audio_text.py │ └── audiocaps.py ├── monitor │ ├── __init__.py │ └── siamese_va.py ├── module │ ├── decoder │ │ └── __init__.py │ ├── encoder │ │ ├── __init__.py │ │ ├── image_head.py │ │ ├── text_head.py │ │ └── audio_head.py │ ├── __init__.py │ ├── lars.py │ ├── vit.py │ ├── txt.py │ ├── transformer.py │ ├── deit.py │ └── resnet.py ├── model │ ├── __init__.py │ ├── helper.py │ ├── clvp.py │ ├── cvap.py │ ├── esc50_clf.py │ ├── clap.py │ ├── siamese_va.py │ └── audioset_clf.py └── util │ └── __init__.py ├── bash ├── run_docker.sh ├── run_bimodal_va.sh └── run_bimodal_at.sh ├── requirements.txt ├── .gitignore ├── Dockerfile ├── train.py └── README.md /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .model import * 3 | -------------------------------------------------------------------------------- /configs/model/loss/ce.yaml: -------------------------------------------------------------------------------- 1 | name: 'CELossHead' 2 | layers: [] 3 | scaling: True 4 | scale_max: null 5 | -------------------------------------------------------------------------------- /configs/model/loss/ce_lm.yaml: -------------------------------------------------------------------------------- 1 | name: 'LMLossHead' 2 | layers: [] 3 | scaling: True 4 | max_len_dec: 20 5 | -------------------------------------------------------------------------------- /configs/model/text/dummy.yaml: -------------------------------------------------------------------------------- 1 | name: 'DummyHead' 2 | freeze: True 3 | from_scratch: True 4 | ctx_len: null 5 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaoyanpeng/vipant/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /configs/model/loss/ce_cls.yaml: -------------------------------------------------------------------------------- 1 | name: 'ClassificationHead' 2 | embed_dim: ${model.image.embed_dim} 3 | layers: [] 4 | scaling: True 5 | -------------------------------------------------------------------------------- /configs/model/loss/barlow.yaml: -------------------------------------------------------------------------------- 1 | name: 'BarlowLossHead' 2 | 
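# Barlow Twins-style objective: `layers` presumably defines the projector MLP
# sizes, and `lambd_off` weights the off-diagonal terms of the cross-correlation
# loss (0.0051 is the value used in the Barlow Twins paper).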
embed_dim: ${model.image.embed_dim} 3 | lambd_off: 0.0051 4 | layers: [2048, 4096, 4096] 5 | -------------------------------------------------------------------------------- /configs/model/loss/ce_val.yaml: -------------------------------------------------------------------------------- 1 | name: 'VALCELossHead' 2 | layers: [] 3 | scaling: True 4 | scale_max: null 5 | va: True 6 | lv: False 7 | al: True 8 | -------------------------------------------------------------------------------- /cvap/data/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | _extract_kaldi_spectrogram, 3 | make_transform, 4 | FbankTransform, 5 | RandomCrop, 6 | ) 7 | -------------------------------------------------------------------------------- /configs/model/loss/bce.yaml: -------------------------------------------------------------------------------- 1 | name: 'BCELossHead' 2 | embed_dim: ${model.audio.embed_dim} 3 | width: ${model.audio.width} 4 | layers: [] 5 | scaling: True 6 | bias: False 7 | -------------------------------------------------------------------------------- /configs/model/image/vit.yaml: -------------------------------------------------------------------------------- 1 | name: 'ImageHead' 2 | freeze: True 3 | from_scratch: False 4 | embed_dim: 512 5 | resolution: 224 6 | patch_size: 32 7 | width: 768 8 | layers: 12 9 | -------------------------------------------------------------------------------- /cvap/data/image/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | make_clip_image_transform, 3 | AuthenticCLIPImageTransform, 4 | CLIPImageTransform, 5 | BarlowImageTransform, 6 | ) 7 | -------------------------------------------------------------------------------- /configs/model/text/transformer.yaml: -------------------------------------------------------------------------------- 1 | name: 'TextHead' 2 | freeze: True 3 | from_scratch: False 4 | from_text: True 5 | embed_dim: ${model.image.embed_dim} 6 | vocab_size: 49408 7 | ctx_len: 77 8 | layers: 12 9 | width: 512 10 | heads: 8 11 | -------------------------------------------------------------------------------- /configs/model/loss/ce_va.yaml: -------------------------------------------------------------------------------- 1 | name: 'VACELossHead' 2 | layers: [] 3 | scaling: True 4 | scale_max: null 5 | vp: True 6 | vp_w: 1. 7 | ap: False 8 | ap_w: 1. 9 | va: True 10 | va_w: 1. 11 | vv: True 12 | vv_w: 1. 13 | aa: False 14 | aa_w: 1. 
15 | -------------------------------------------------------------------------------- /configs/model/loss/barlow_ce.yaml: -------------------------------------------------------------------------------- 1 | name: 'BarlowCELossHead' 2 | lambd_barlow: 0.05 3 | ce: 4 | name: 'CELossHead' 5 | barlow: 6 | name: 'BarlowLossHead' 7 | embed_dim: ${model.image.embed_dim} 8 | lambd_off: 0.0051 9 | layers: [2048, 4096, 4096] 10 | -------------------------------------------------------------------------------- /configs/model/text/transformer_decoder.yaml: -------------------------------------------------------------------------------- 1 | name: 'SeqGenerationHead' 2 | freeze: True 3 | from_scratch: False 4 | embed_dim: ${model.image.embed_dim} 5 | vocab_size: 49408 6 | ctx_len: 77 7 | layers: 12 8 | width: 512 9 | heads: 8 10 | max_len_dec: 32 11 | mem_width: ${model.audio.width} 12 | bias: True 13 | -------------------------------------------------------------------------------- /cvap/monitor/__init__.py: -------------------------------------------------------------------------------- 1 | from .audioset_clf import Monitor as ASMonitor 2 | from .siamese_va import Monitor as VASMonitor # bimodal (V-A) siamese 3 | from .esc50_clf import Monitor as ESCMonitor 4 | from .cvalp import Monitor as VALMonitor 5 | from .clap import Monitor as LAMonitor 6 | from .cvap import Monitor as VAMonitor 7 | -------------------------------------------------------------------------------- /configs/model/audio/vit.yaml: -------------------------------------------------------------------------------- 1 | name: 'NaiveCLIPAudioHead' 2 | freeze: False 3 | from_scratch: False 4 | embed_dim: ${model.image.embed_dim} 5 | meme_path: '' 6 | meme_name: '' 7 | time_first: True 8 | in_channel: 1 9 | resolution: 10 | - ${running.max_audio_len} 11 | - ${running.num_mel_bins} 12 | patch_size: ${model.image.patch_size} 13 | width: 768 14 | stride: [16, 16] 15 | layers: 12 16 | -------------------------------------------------------------------------------- /cvap/module/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss_head import build_loss_head, LOSS_HEADS_REGISTRY 2 | from .loss_more import ( 3 | LMLossHead, BCELossHead, BCHingeLossHead, ImagineAndClassifyLossHead 4 | ) 5 | LOSS_HEADS_REGISTRY.register(LMLossHead) 6 | LOSS_HEADS_REGISTRY.register(BCELossHead) 7 | LOSS_HEADS_REGISTRY.register(BCHingeLossHead) 8 | LOSS_HEADS_REGISTRY.register(ImagineAndClassifyLossHead) 9 | -------------------------------------------------------------------------------- /configs/model/audio/deit.yaml: -------------------------------------------------------------------------------- 1 | name: 'NaiveDeiTAudioHead' 2 | freeze: False 3 | from_scratch: False 4 | embed_dim: ${model.image.embed_dim} 5 | meme_path: 'facebookresearch/deit:main' 6 | meme_name: 'deit_base_distilled_patch16_224' 7 | time_first: True 8 | in_channel: 1 9 | resolution: 10 | - ${running.max_audio_len} 11 | - ${running.num_mel_bins} 12 | patch_size: 16 13 | width: 768 14 | stride: [10, 10] 15 | layers: 12 16 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | alias_root: "/net/nfs2.mosaic/yann/model/cvap" 2 | model_root: "/net/nfs2.mosaic/yann/model/cvap" 3 | model_name: "test" 4 | model_file: "00000204.pth" 5 | blockprint: False 6 | monitor: VAMonitor 7 | worker: CVAP 8 | verbose: False 
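# This base config is composed with the config groups under configs/ at launch
# time; a typical invocation (mirroring bash/run_bimodal_va.sh) is:
#   python train.py +running=bimodal +model/image=vit_val +model/audio=vit_val \
#     +model/text=dummy +model/loss=ce +optimizer=standard +running/audio=default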
9 | seed: 1213 10 | eval: True 11 | rank: -1 12 | mode: "ddp" 13 | num_proc: 0 # always 0 in ddp mode 14 | num_gpus: 4 15 | port: 22829 16 | dist_url: "tcp://localhost:${port}" 17 | -------------------------------------------------------------------------------- /cvap/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .esc50 import build_xfold_dataloader_list 2 | from .audio_text import build_audio_text_dataloader 3 | from .image_text import build_image_text_dataloader 4 | from .image_audio import build_image_audio_dataloader 5 | 6 | from .audioset_clf import build_audioset_clf_dataloader 7 | 8 | from .audioset_hub import ( 9 | build_audioset_dataloader, 10 | build_audioset_label_map, 11 | build_filter_set, 12 | ) 13 | -------------------------------------------------------------------------------- /configs/model/loss/imagine_and_classify.yaml: -------------------------------------------------------------------------------- 1 | name: 'ImagineAndClassifyLossHead' 2 | lambd_ce: 1. 3 | layers: 4 | - ${model.image.embed_dim} 5 | bias: False 6 | ce: 7 | name: 'CELossHead' 8 | alive: True 9 | scaling: True 10 | scale_max: null 11 | bce: 12 | name: 'BCELossHead' 13 | alive: True 14 | embed_dim: ${model.audio.embed_dim} 15 | width: ${model.audio.width} 16 | layers: [] 17 | scaling: True 18 | bias: False 19 | -------------------------------------------------------------------------------- /cvap/module/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_head import build_image_head, IMAGE_HEADS_REGISTRY 2 | from .audio_head import build_audio_head, AUDIO_HEADS_REGISTRY 3 | from .text_head import build_text_head, TEXT_HEADS_REGISTRY 4 | # heads initialized from CLIP 5 | from .clip_head import ( 6 | CLIPImageHead, CLIPAudioHead, CLIPTextHead 7 | ) 8 | IMAGE_HEADS_REGISTRY.register(CLIPImageHead) 9 | AUDIO_HEADS_REGISTRY.register(CLIPAudioHead) 10 | TEXT_HEADS_REGISTRY.register(CLIPTextHead) 11 | -------------------------------------------------------------------------------- /configs/running/audio/default.yaml: -------------------------------------------------------------------------------- 1 | # kaldi 2 | max_len: 1000 # seq len 3 | norms: [] 4 | eval_norms: False 5 | normalized: False 6 | dither: 0.0 7 | tile_audio: False 8 | frame_shift: 10 9 | htk_compat: True 10 | use_energy: False 11 | window_type: hanning 12 | num_mel_bins: 128 13 | zero_mean_wf: True 14 | # transform of audios 15 | transform_audio: False 16 | audio_transforms: [] 17 | # transform of fbanks 18 | transform_fbank: True 19 | fbank_transforms: 20 | - [FrequencyMasking, [32]] 21 | - [TimeMasking, [200]] 22 | -------------------------------------------------------------------------------- /bash/run_docker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | 3 | export OMP_NUM_THREADS=32 4 | 5 | run_type=$1 6 | gpu_list=(${CUDA_VISIBLE_DEVICES//,/ }) 7 | 8 | port=$(expr $RANDOM + 1000) 9 | ngpu=${#gpu_list[@]} 10 | 11 | if [ $ngpu -eq 0 ]; then 12 | ngpu=$2 13 | fi 14 | extra=$3 15 | 16 | echo ${@:4} 17 | 18 | echo "GPUs: "$CUDA_VISIBLE_DEVICES "#"$ngpu "PORT: "$port 19 | 20 | #nohup python -m torch.utils.bottleneck train.py \ 21 | python train.py port=$port num_gpus=$ngpu \ 22 | +running=$run_type $extra "${@:4}" 23 | #> profile.new 2>&1 & 24 | -------------------------------------------------------------------------------- /configs/model/image/rn50_val.yaml: 
-------------------------------------------------------------------------------- 1 | name: 'CLIPImageHead' 2 | freeze: True #False # 3 | from_scratch: False #True # 4 | width: 64 5 | embed_dim: 1024 6 | resolution: 224 7 | ctx_len: ${model.text.ctx_len} 8 | encoder: 9 | name: ResNetBackbone 10 | layers: [3, 4, 6, 3] 11 | pre_encoder: 12 | name: ResNetPreEncoder 13 | in_channels: 3 14 | post_encoder: 15 | name: ResNetPostEncoder 16 | misc: 17 | name: CLIPMisc 18 | pre_encoder_addon: 19 | name: AddonEncoder 20 | post_encoder_addon: 21 | name: AddonEncoder 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy==6.0.3 2 | regex==2021.4.4 3 | tabulate==0.8.9 4 | timm==0.4.12 5 | omegaconf==2.1.0 6 | einops==0.3.0 7 | torchvision==0.9.1 8 | tqdm==4.62.3 9 | termcolor==1.1.0 10 | fvcore==0.1.5.post20210609 11 | torch==1.8.1 12 | torchaudio==0.8.1 13 | clip==0.2.0 14 | Pillow==8.4.0 15 | scikit_learn==1.0.1 16 | omegaconf==2.1.0 17 | hydra_core==1.1.0 18 | hydra==2.5 19 | pytest==6.2.4 20 | 21 | pydub 22 | soundfile 23 | youtube-dl 24 | scikit-learn 25 | scikit-video 26 | scikit-image 27 | google-cloud-storage 28 | numpy==1.19.5 29 | -------------------------------------------------------------------------------- /configs/model/text/transformer_val.yaml: -------------------------------------------------------------------------------- 1 | name: 'CLIPTextHead' 2 | freeze: True #False # 3 | from_scratch: False #True # 4 | from_text: True 5 | width: 512 6 | embed_dim: ${model.image.embed_dim} 7 | resolution: None 8 | ctx_len: 77 9 | encoder: 10 | name: TransformerBackbone 11 | layers: 12 12 | skip_attn_mask: False 13 | pre_encoder: 14 | name: GPTPreEncoder 15 | vocab_size: 49408 16 | post_encoder: 17 | name: GPTPostEncoder 18 | misc: 19 | name: CLIPMisc 20 | pre_encoder_addon: 21 | name: AddonEncoder 22 | post_encoder_addon: 23 | name: AddonEncoder 24 | -------------------------------------------------------------------------------- /configs/optimizer/standard.yaml: -------------------------------------------------------------------------------- 1 | use_lars: True 2 | name: Adam 3 | warmup: True 4 | warmup_steps: 1000 5 | warmup_epoch: 10 6 | lr: 5e-4 7 | weight_decay: 1e-6 8 | betas: [0.9, 0.999] 9 | max_norm: 0.5 10 | lr_weight: 0.2 11 | lr_bias: 0.0048 12 | batch_size: ${running.batch_size} 13 | epochs: ${running.epochs} 14 | steps: [] 15 | gamma: 0.5 16 | batch_sch: False # schedule lr per batch 17 | optimizer: [Adam, {lr: '${optimizer.lr}', betas: '${optimizer.betas}', weight_decay: '${optimizer.weight_decay}'}] 18 | scheduler: [MultiStepLR, {milestones: '${optimizer.steps}', gamma: '${optimizer.gamma}'}] 19 | -------------------------------------------------------------------------------- /configs/running/us8k.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: "/home/yanpengz/data/us8k/UrbanSound8K" 4 | prompt: "the sound of " #"label_map" 5 | data_name: 'UrbanSound8K' 6 | zero_shot: False 7 | eval_name: '' 8 | test_name: '' 9 | eval_samples: 5000 10 | test_samples: 5000 11 | peep_rate: 50 12 | save_rate: 1e9 13 | batch_size: 50 14 | epochs: 32 15 | save_epoch: True 16 | # vision backbone 17 | resolution: ${model.image.resolution} 18 | # audio backbone 19 | max_audio_len: ${running.audio.max_len} 20 | num_mel_bins: 
${running.audio.num_mel_bins} 21 | -------------------------------------------------------------------------------- /configs/model/image/vit_val.yaml: -------------------------------------------------------------------------------- 1 | name: 'CLIPImageHead' 2 | freeze: True #False # 3 | from_scratch: False #True # 4 | width: 768 5 | embed_dim: 512 6 | resolution: 224 7 | ctx_len: ${model.text.ctx_len} 8 | encoder: 9 | name: TransformerBackbone 10 | layers: 12 11 | skip_attn_mask: True 12 | pre_encoder: 13 | name: ViTPreEncoder 14 | patch_size: 32 15 | stride: ${model.image.pre_encoder.patch_size} 16 | in_channels: 3 17 | post_encoder: 18 | name: ViTPostEncoder 19 | misc: 20 | name: CLIPMisc 21 | pre_encoder_addon: 22 | name: AddonEncoder 23 | post_encoder_addon: 24 | name: AddonEncoder 25 | -------------------------------------------------------------------------------- /configs/model/audio/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | name: 'CLIPAudioHead' 2 | freeze: False #True # 3 | from_scratch: False #True # 4 | width: ${model.image.width} 5 | embed_dim: ${model.image.embed_dim} 6 | resolution: 7 | - ${running.max_audio_len} 8 | - ${running.num_mel_bins} 9 | ctx_len: ${model.text.ctx_len} 10 | encoder: 11 | name: ResNetBackbone 12 | layers: ${model.image.encoder.layers} 13 | pre_encoder: 14 | name: ResNetPreEncoder 15 | in_channels: 3 16 | post_encoder: 17 | name: ResNetPostEncoder 18 | misc: 19 | name: CLIPMisc 20 | pre_encoder_addon: 21 | name: AddonEncoder 22 | post_encoder_addon: 23 | name: AddonEncoder 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | 3 | .eggs/ 4 | .mypy_cache 5 | build/ 6 | dist/ 7 | 8 | 9 | # dev tools 10 | 11 | .envrc 12 | .python-version 13 | .idea 14 | 15 | 16 | # jupyter notebooks 17 | 18 | .ipynb_checkpoints 19 | 20 | 21 | # miscellaneous 22 | 23 | .cache/ 24 | doc/_build/ 25 | *.swp 26 | 27 | 28 | # python 29 | 30 | *.pyc 31 | *.pyo 32 | __pycache__ 33 | 34 | 35 | # testing and continuous integration 36 | 37 | .coverage 38 | .pytest_cache/ 39 | 40 | 41 | # hard-coded 42 | csv 43 | logs 44 | data 45 | scripts 46 | outputs 47 | *-Dockerfile 48 | requirements-* 49 | requirements_* 50 | configs/ai2-alexandria-fbf4c720d4a4.json 51 | -------------------------------------------------------------------------------- /configs/running/esc50.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: "/home/yanpengz/data/esc50" 4 | prompt: "the sound of " #"label_map" 5 | data_name: 'esc50' 6 | zero_shot: False 7 | eval_name: '' 8 | test_name: '' 9 | eval_samples: 5000 10 | test_samples: 5000 11 | peep_rate: 16 12 | save_rate: 1e9 13 | batch_size: 50 14 | epochs: 32 15 | save_epoch: True 16 | # vision backbone 17 | resolution: ${model.image.resolution} 18 | # audio backbone 19 | max_audio_len: ${running.audio.max_len} 20 | num_mel_bins: ${running.audio.num_mel_bins} 21 | excl_modules: # will be frozen 22 | amodules: [] 23 | -------------------------------------------------------------------------------- /configs/running/train.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: 
"/home/yanpengz/data/audioset" 4 | data_name: 'npz_unbalanced_train_segments' 5 | eval_name: 'npz_balanced_train_segments' #'eval' # 6 | test_name: 'eval' 7 | train_samples: 1. # fraction 8 | eval_samples: 5184 #5000 9 | test_samples: 5000 10 | peep_rate: 1 11 | save_rate: 300 12 | batch_size: 432 #108 #216 # 13 | epochs: 1000 14 | save_epoch: False 15 | frame_key: "frame" 16 | frame_emb: null 17 | embed_dim: ${model.image.embed_dim} 18 | # vision backbone 19 | resolution: ${model.image.resolution} 20 | # audio backbone 21 | max_audio_len: ${running.audio.max_len} 22 | num_mel_bins: ${running.audio.num_mel_bins} 23 | -------------------------------------------------------------------------------- /configs/model/audio/vit_val.yaml: -------------------------------------------------------------------------------- 1 | name: 'CLIPAudioHead' 2 | freeze: False #True # 3 | from_scratch: False #True # 4 | width: ${model.image.width} 5 | embed_dim: ${model.image.embed_dim} 6 | resolution: 7 | - ${running.max_audio_len} 8 | - ${running.num_mel_bins} 9 | ctx_len: ${model.text.ctx_len} 10 | encoder: 11 | name: TransformerBackbone 12 | layers: ${model.image.encoder.layers} 13 | skip_attn_mask: True 14 | pre_encoder: 15 | name: ViTPreEncoder 16 | patch_size: ${model.image.pre_encoder.patch_size} 17 | stride: [16, 16] 18 | in_channels: 3 19 | post_encoder: 20 | name: ViTPostEncoder 21 | misc: 22 | name: CLIPMisc 23 | pre_encoder_addon: 24 | name: AddonEncoder 25 | post_encoder_addon: 26 | name: AddonEncoder 27 | -------------------------------------------------------------------------------- /configs/running/clotho.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: "/home/yanpengz/data/clotho" 4 | prompt: "the sound of" 5 | data_name: 'clotho_captions_development' 6 | eval_name: 'clotho_captions_validation' 7 | test_name: 'clotho_captions_evaluation' 8 | eval_samples: 5000 9 | test_samples: 5000 10 | peep_rate: 16 11 | save_rate: 32 12 | batch_size: 50 13 | epochs: 32 14 | save_epoch: False 15 | np_rnd: False 16 | imagine: False 17 | retrieval: False 18 | mixup_rate: 0.0 19 | # vision backbone 20 | resolution: ${model.image.resolution} 21 | # audio backbone 22 | max_audio_len: ${running.audio.max_len} 23 | num_mel_bins: ${running.audio.num_mel_bins} 24 | 25 | # unused but to be compatible w/ audioset_cap.py 26 | frame_key: "frame" 27 | frame_emb: null 28 | clf: False 29 | weighted_sampling: False 30 | dataloader: al 31 | -------------------------------------------------------------------------------- /cvap/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .helper import * 2 | from .audioset_clf import ASClassifier 3 | from .esc50_clf import ESClassifier 4 | from .cvalp import CVALP 5 | from .clap import CLAP 6 | from .clvp import CLVP 7 | from .cvap import CVAP 8 | from .siamese_va import CVASP 9 | 10 | from fvcore.common.registry import Registry 11 | 12 | VAL_MODELS_REGISTRY = Registry("VAL_MODELS") 13 | VAL_MODELS_REGISTRY.__doc__ = """ 14 | Registry for vision-audio-language models. 
15 | """ 16 | 17 | VAL_MODELS_REGISTRY.register(ASClassifier) 18 | VAL_MODELS_REGISTRY.register(ESClassifier) 19 | VAL_MODELS_REGISTRY.register(CVALP) 20 | VAL_MODELS_REGISTRY.register(CVASP) 21 | VAL_MODELS_REGISTRY.register(CLAP) 22 | VAL_MODELS_REGISTRY.register(CLVP) 23 | VAL_MODELS_REGISTRY.register(CVAP) 24 | 25 | def build_main_model(cfg, echo, **kwargs): 26 | return VAL_MODELS_REGISTRY.get(cfg.worker)(cfg, echo) 27 | -------------------------------------------------------------------------------- /bash/run_bimodal_va.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | 3 | export OMP_NUM_THREADS=32 4 | export CUDA_VISIBLE_DEVICES=$2 5 | 6 | run_type=$1 7 | gpu_list=(${CUDA_VISIBLE_DEVICES//,/ }) 8 | 9 | port=$(expr $RANDOM + 1000) 10 | ngpu=${#gpu_list[@]} 11 | 12 | mode="dp" 13 | num_proc=2 14 | 15 | echo "GPUs: "$CUDA_VISIBLE_DEVICES "#"$ngpu "PORT: "$port 16 | 17 | # bash bash/run_bimodal_va.sh bimodal 3 18 | 19 | # train: via bimodal (vision-audio) pre-training 20 | eval="bimodal_va" 21 | model_name="test" 22 | mtask=" 23 | model_name=$model_name worker=CVALP port=$port num_gpus=$ngpu mode=$mode num_proc=$num_proc eval=False verbose=True 24 | +model/image=vit_val +model/audio=vit_val +model/text=dummy +model/loss=ce +optimizer=standard +running/audio=default 25 | model.audio.pre_encoder.in_channels=3 model.audio.pre_encoder.stride=[16,24] 26 | optimizer.warmup=False running.audio.norms=[-4.93839311,5.75751113] 27 | running.epochs=1 running.batch_size=2 running.peep_rate=50 running.save_rate=100 running.eval_samples=100 28 | " 29 | 30 | # config 31 | extra="$mtask " 32 | 33 | python train.py +running=$run_type $extra 34 | -------------------------------------------------------------------------------- /configs/running/bimodal.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: "/home/yanpengz/data/audioset" 4 | data_name: 'src_unbalanced_train_segments' 5 | eval_name: 'src_balanced_train_segments' #'eval' # 6 | test_name: '' 7 | train_samples: 1. # fraction 8 | eval_samples: 5184 #5000 9 | test_samples: 5000 10 | peep_rate: 1 11 | save_rate: 1e9 12 | batch_size: 432 #108 #216 # 13 | epochs: 1000 14 | save_epoch: True 15 | multi_view: False 16 | frame_key: "frame" 17 | frame_emb: null 18 | text_emb: null 19 | imagine: True 20 | embed_dim: ${model.image.embed_dim} 21 | # vision backbone 22 | resolution: ${model.image.resolution} 23 | # audio backbone 24 | max_audio_len: ${running.audio.max_len} 25 | num_mel_bins: ${running.audio.num_mel_bins} 26 | #off # on: true interpreted by YAML, so on/off cannot be used as the keys 27 | #share, or not share; the image head as the reference 28 | siamese: 29 | alive: False 30 | keep_hp: True # keep run-time hyperparameters 31 | amodules: [] #, "pre_encoder", "post_encoder", "misc"] 32 | lmodules: [] #"encoder"] 33 | -------------------------------------------------------------------------------- /configs/running/siamese.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: "/home/yanpengz/data/audioset" 4 | data_name: 'npz_unbalanced_train_segments' 5 | eval_name: 'npz_balanced_train_segments' #'eval' # 6 | test_name: 'eval' 7 | train_samples: 1. 
# fraction 8 | eval_samples: 5184 #5000 9 | test_samples: 5000 10 | peep_rate: 1 11 | save_rate: 300 12 | batch_size: 432 #108 #216 # 13 | epochs: 1000 14 | save_epoch: False 15 | multi_view: True 16 | frame_key: "frame" 17 | frame_emb: null 18 | imagine: True 19 | embed_dim: ${model.image.embed_dim} 20 | # vision backbone 21 | resolution: ${model.image.resolution} 22 | clip_tf: False 23 | # audio backbone 24 | max_audio_len: ${running.audio.max_len} 25 | num_mel_bins: ${running.audio.num_mel_bins} 26 | #off # on: true interpreted by YAML, so on/off cannot be used as the keys 27 | #share, or not share; the image head as the reference 28 | siamese: 29 | alive: False 30 | keep_hp: True # keep run-time hyperparameters 31 | amodules: [] #, "pre_encoder", "post_encoder", "misc"] 32 | lmodules: [] #"encoder"] 33 | -------------------------------------------------------------------------------- /clip/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /configs/running/audioset.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: "/home/yanpengz/data/audioset" 4 | prompt: "the sound of" 5 | cat_label: False 6 | nper_label: -1 7 | filter_set: null 8 | label_map: "ontology,eval_segments" 9 | data_name: 'npz_train_toy' #'src_train_toy' # 10 | eval_name: 'npz_eval_toy' #'src_eval_toy' # 11 | test_name: 'npz_eval_toy' #'src_eval_toy' # 12 | eval_samples: 250 13 | test_samples: 250 14 | peep_rate: 1 15 | save_rate: 3 16 | batch_size: 64 17 | epochs: 1000 18 | save_epoch: False 19 | frame_key: "frame" 20 | frame_emb: null 21 | text_emb: null 22 | embed_dim: ${model.image.embed_dim} 23 | force_npz: False # force to use npz, i.e., precomputed image and audio features 24 | zero_shot: False 25 | clf: True 26 | np_rnd: False 27 | imagine: False 28 | mixup_rate: 0.5 29 | weighted_sampling: False 30 | # vision backbone 31 | resolution: ${model.image.resolution} 32 | # audio backbone 33 | max_audio_len: ${running.audio.max_len} 34 | num_mel_bins: ${running.audio.num_mel_bins} 35 | excl_modules: # will be frozen 36 | vmodules: [] 37 | amodules: [] 38 | lmodules: [] 39 | -------------------------------------------------------------------------------- /cvap/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer # has to be the first because the following depend on it 2 | # better API 3 | from .val import ( 4 | build_encoder_module, ENCODER_MODULES_REGISTRY, 5 | interp_clip_vp_embedding, 6 | interp_conv_weight_channel, 7 | interp_conv_weight_spatial, 8 | ) 9 | # deprecated API 10 | from .deit import PatchEmbed, DistilledVisionTransformer 11 | from .vit import VisualTransformer 12 | from .txt import TextualTransformer 13 | from .resnet import ModifiedResNet 14 | # optimizer 15 | from .lars import * 16 | # encoder heads 17 | from .encoder import * 18 | from .decoder import * 19 | # dummy heads 20 | import torch 21 | class DummyHead(torch.nn.Module): 22 | def __init__(self, cfg, **kwargs): 23 | super().__init__() 24 | pass 25 | def from_pretrained(self, state_dict, cfg, *args, **kwargs): 26 | pass 27 | def copy_state_dict(self, state_dict): 28 | return {}, {} 29 | def replace_modules(self, **kwargs): 30 | return [] 31 | def forward(self, x, *args, **kwargs): 32 | return None 33 | IMAGE_HEADS_REGISTRY.register(DummyHead) 34 | AUDIO_HEADS_REGISTRY.register(DummyHead) 35 | TEXT_HEADS_REGISTRY.register(DummyHead) 36 | LOSS_HEADS_REGISTRY.register(DummyHead) 37 | -------------------------------------------------------------------------------- /configs/running/trimodal.yaml: -------------------------------------------------------------------------------- 1 | clip_model_root: "/net/nfs2.mosaic/yann/model/clip" 2 | clip_model_name: "ViT-B32" 3 | data_root: "/home/yanpengz/data/audioset" 4 | prompt: "the sound of" 5 | cat_label: False 6 | filter_set: null 7 | label_map: "ontology,eval_segments" 8 | data_name: '' # 9 | eval_name: '' # 10 | test_name: '' # 11 | train_samples: 1. 
# fraction 12 | eval_samples: 250 13 | test_samples: 250 14 | peep_rate: 1 15 | save_rate: 1e9 16 | batch_size: 64 17 | epochs: 1000 18 | save_epoch: True 19 | frame_key: "frame" 20 | frame_emb: null 21 | text_emb: null 22 | embed_dim: ${model.image.embed_dim} 23 | force_npz: False # force to use npz, i.e., precomputed image and audio features 24 | clf: False 25 | np_rnd: False 26 | imagine: True 27 | mixup_rate: 0.0 28 | weighted_sampling: False 29 | # vision backbone 30 | resolution: ${model.image.resolution} 31 | # audio backbone 32 | max_audio_len: ${running.audio.max_len} 33 | num_mel_bins: ${running.audio.num_mel_bins} 34 | #off # on: true interpreted by YAML, so on/off cannot be used as the keys 35 | #share, or not share; the image head as the reference 36 | siamese: 37 | alive: False 38 | keep_hp: True # keep run-time hyperparameters 39 | amodules: [] #, "pre_encoder", "post_encoder", "misc"] 40 | lmodules: [] #"encoder"] 41 | -------------------------------------------------------------------------------- /bash/run_bimodal_at.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/sh 2 | 3 | root=./ 4 | 5 | export OMP_NUM_THREADS=32 6 | export CUDA_VISIBLE_DEVICES=$2 7 | 8 | run_type=$1 9 | gpu_list=(${CUDA_VISIBLE_DEVICES//,/ }) 10 | 11 | port=$(expr $RANDOM + 1000) 12 | ngpu=${#gpu_list[@]} 13 | 14 | mode="dp" 15 | num_proc=8 16 | 17 | echo "GPUs: "$CUDA_VISIBLE_DEVICES "#"$ngpu "PORT: "$port 18 | 19 | # bash bash/run_bimodal_at.sh trimodal 3 20 | # train: finetune via the audio-text task (AudioCaps) 21 | eval="bimodal_at" 22 | model_file="bimodal_16x24_00071478.pth" # VA-pre-trained model 23 | model_file="notafile" 24 | model_name="test" 25 | mtask=" 26 | model_name=$model_name monitor=VALMonitor worker=CVALP port=$port num_gpus=$ngpu mode=$mode num_proc=$num_proc eval=False verbose=True 27 | +model/image=vit_val +model/audio=vit_val +model/text=transformer_val +model/loss=ce_val +optimizer=standard +running/audio=default 28 | model.audio.pre_encoder.in_channels=3 model.audio.pre_encoder.stride=[16,24] 29 | optimizer.warmup=False running.audio.norms=[-4.93839311,5.75751113] 30 | running.siamese.alive=True running.imagine=False model.loss.va=False 31 | running.batch_size=64 running.peep_rate=1 running.prompt=\"\" 32 | 33 | model_file=$model_file 34 | +running.rnd_cap=True 35 | 36 | running.data_name=audiocaps_train running.eval_name=audiocaps_val running.test_name=audiocaps_test 37 | running.eval_samples=250 running.test_samples=250 running.train_samples=0.1 38 | " 39 | 40 | # config 41 | extra="$mtask" 42 | 43 | python train.py +running=$run_type $extra 44 | -------------------------------------------------------------------------------- /cvap/util/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import random 4 | import numpy 5 | import torch 6 | import torch.distributed as dist 7 | 8 | def seed_all_rng(seed): 9 | random.seed(seed) 10 | numpy.random.seed(seed) 11 | torch.manual_seed(seed) 12 | 13 | def setup_logger(output_dir=None, name="cvap", rank=0, output=None): 14 | logger = logging.getLogger(name) 15 | logger.setLevel(logging.INFO) 16 | logger.propagate = False 17 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 18 | if rank == 0: 19 | console = logging.StreamHandler() 20 | console.setLevel(logging.INFO) 21 | console.setFormatter(formatter) 22 | logger.addHandler(console) 23 | if os.path.exists(output_dir): 24 | 
logger.info(f'Warning: the folder {output_dir} exists.') 25 | else: 26 | logger.info(f'Creating {output_dir}') 27 | if rank == 0: 28 | os.makedirs(output_dir) 29 | if torch.distributed.is_initialized(): 30 | dist.barrier() # output dir should have been ready 31 | if output is not None: 32 | filename = os.path.join(output_dir, f'train_{rank}.out') 33 | handler = logging.FileHandler(filename, 'w') 34 | handler.setLevel(logging.INFO) 35 | handler.setFormatter(formatter) 36 | logger.addHandler(handler) 37 | return logger 38 | 39 | def numel(model: torch.nn.Module, trainable: bool = False): 40 | parameters = list(model.parameters()) 41 | if trainable: 42 | parameters = [p for p in parameters if p.requires_grad] 43 | unique = {p.data_ptr(): p for p in parameters}.values() 44 | return sum(p.numel() for p in unique) 45 | 46 | def detect_nan(x): 47 | return torch.isnan(x).any(), torch.isinf(x).any() 48 | 49 | class AverageMeter(object): 50 | def __init__(self): 51 | self.reset() 52 | 53 | def reset(self): 54 | self.sum = self.count = 0 55 | 56 | def __call__(self, val, n=1): 57 | self.count += n 58 | self.sum += val * n 59 | 60 | @property 61 | def average(self): 62 | return self.sum / self.count 63 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.2.2-base-ubuntu20.04 2 | 3 | ARG DEBIAN_FRONTEND="noninteractive" 4 | ENV TZ="America/Los_Angeles" 5 | 6 | # Install base tools. 7 | RUN apt-get update && apt-get install -y --no-install-recommends \ 8 | language-pack-en \ 9 | build-essential \ 10 | apt-utils \ 11 | ffmpeg \ 12 | unzip \ 13 | curl \ 14 | wget \ 15 | make \ 16 | sudo \ 17 | vim \ 18 | git && \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | # Install conda 22 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.3-Linux-x86_64.sh \ 23 | && echo "935d72deb16e42739d69644977290395561b7a6db059b316958d97939e9bdf3d Miniconda3-py38_4.10.3-Linux-x86_64.sh" \ 24 | | sha256sum --check \ 25 | && bash Miniconda3-py38_4.10.3-Linux-x86_64.sh -b -p /opt/miniconda3 \ 26 | && chgrp -R users /opt/miniconda3 \ 27 | && chmod -R 750 /opt/miniconda3 \ 28 | && rm Miniconda3-py38_4.10.3-Linux-x86_64.sh 29 | 30 | ENV PATH=/opt/miniconda3/bin:/opt/miniconda3/condabin:$PATH 31 | ENV LD_LIBRARY_PATH=/usr/local/cuda/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH 32 | 33 | # Install java 34 | RUN conda install -c conda-forge openjdk 35 | 36 | # Install AWS CLI 37 | RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" \ 38 | && unzip awscliv2.zip \ 39 | && ./aws/install \ 40 | && rm awscliv2.zip 41 | 42 | # Install Google Cloud CLI 43 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main" \ 44 | | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ 45 | && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \ 46 | | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - \ 47 | && apt-get update -y --no-install-recommends && apt-get install google-cloud-sdk -y --no-install-recommends 48 | 49 | WORKDIR /vipant 50 | ENV PYTHONPATH=/vipant:$PYTHONPATH 51 | 52 | COPY requirements.txt /vipant 53 | RUN pip install --no-cache-dir --upgrade pip setuptools 54 | RUN pip install --no-cache-dir -r /vipant/requirements.txt 55 | 56 | COPY configs /vipant/configs 57 | COPY bash /vipant/bash 58 | COPY clip /vipant/clip 59 | COPY cvap /vipant/cvap 60 | COPY train.py /vipant 
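# Illustrative build-and-run sketch (image tag and GPU flags are not from the repo):
#   docker build -t vipant .
#   docker run --gpus all --rm -it vipant "bash bash/run_docker.sh bimodal 1"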
61 | 62 | RUN ls -la /vipant/* 63 | 64 | # https://stackoverflow.com/a/62313159 65 | ENTRYPOINT [ "/bin/bash", "-l", "-c" ] 66 | CMD ["ls", "./"] 67 | -------------------------------------------------------------------------------- /cvap/module/lars.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | __all__ = ["exclude_bias_or_norm", "adjust_learning_rate", "LARS"] 5 | 6 | def exclude_bias_or_norm(p): 7 | return p.ndim < 2 8 | 9 | def adjust_learning_rate(cfg, optimizer, dataloader, step): 10 | max_steps = cfg.epochs * len(dataloader) 11 | warmup_steps = int(cfg.warmup_epoch * len(dataloader)) 12 | base_lr = cfg.batch_size / 256 13 | if step < warmup_steps: 14 | lr = base_lr * step / warmup_steps 15 | else: 16 | step -= warmup_steps 17 | max_steps -= warmup_steps 18 | q = 0.5 * (1 + math.cos(math.pi * step / max_steps)) 19 | end_lr = base_lr * 0.001 20 | lr = base_lr * q + end_lr * (1 - q) 21 | optimizer.param_groups[0]['lr'] = lr * cfg.lr_weight 22 | optimizer.param_groups[1]['lr'] = lr * cfg.lr_bias 23 | 24 | class LARS(torch.optim.Optimizer): 25 | def __init__( 26 | self, params, lr, 27 | weight_decay=0, 28 | momentum=0.9, 29 | eta=0.001, 30 | weight_decay_filter=None, 31 | lars_adaptation_filter=None 32 | ): 33 | defaults = dict( 34 | lr=lr, 35 | weight_decay=weight_decay, 36 | momentum=momentum, 37 | eta=eta, 38 | weight_decay_filter=weight_decay_filter, 39 | lars_adaptation_filter=lars_adaptation_filter 40 | ) 41 | super().__init__(params, defaults) 42 | 43 | @torch.no_grad() 44 | def step(self): 45 | for g in self.param_groups: 46 | for p in g['params']: 47 | dp = p.grad 48 | if dp is None: 49 | continue 50 | 51 | if g['weight_decay_filter'] is None or not g['weight_decay_filter'](p): 52 | dp = dp.add(p, alpha=g['weight_decay']) 53 | 54 | if g['lars_adaptation_filter'] is None or not g['lars_adaptation_filter'](p): 55 | param_norm = torch.norm(p) 56 | update_norm = torch.norm(dp) 57 | one = torch.ones_like(param_norm) 58 | q = torch.where( 59 | param_norm > 0., torch.where( 60 | update_norm > 0, 61 | (g['eta'] * param_norm / update_norm), one 62 | ), one 63 | ) 64 | dp = dp.mul(q) 65 | 66 | param_state = self.state[p] 67 | if 'mu' not in param_state: 68 | param_state['mu'] = torch.zeros_like(p) 69 | mu = param_state['mu'] 70 | mu.mul_(g['momentum']).add_(dp) 71 | 72 | p.add_(mu, alpha=-g['lr']) 73 | 74 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from omegaconf import DictConfig, OmegaConf 3 | import logging 4 | import torch 5 | import hydra 6 | 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | 10 | from torch.nn import DataParallel 11 | from torch.nn.parallel import DistributedDataParallel 12 | 13 | from cvap.util import seed_all_rng, setup_logger 14 | from cvap.monitor import * 15 | 16 | 17 | def _distributed_worker(local_rank, main_func, cfg, ddp): 18 | assert torch.cuda.is_available(), "CUDA is not available" 19 | global_rank = 0 + local_rank 20 | try: 21 | dist.init_process_group( 22 | backend="NCCL", 23 | init_method=cfg.dist_url, 24 | world_size=cfg.num_gpus, 25 | rank=global_rank, 26 | ) 27 | except Exception as e: 28 | logger = logging.getLogger(__name__) 29 | logger.error("Process group URL: {}".format(cfg.dist_url)) 30 | raise e 31 | dist.barrier() 32 | torch.cuda.set_device(local_rank) 33 | pg = 
dist.new_group(range(cfg.num_gpus))
34 |     device = torch.device('cuda', local_rank)
35 |     main_func(cfg, local_rank, ddp, pg, device, DDPMonitor)
36 | 
37 | 
38 | def main(cfg, rank, ddp, pg, device, manager):
39 |     cfg.rank = rank
40 |     seed_all_rng(cfg.seed) # + rank)
41 | 
42 |     output_dir = f"{cfg.alias_root}/{cfg.model_name}"
43 |     logger = setup_logger(
44 |         output_dir=output_dir, rank=rank, output=output_dir,
45 |     )
46 | 
47 |     if cfg.verbose or not cfg.eval:
48 |         cfg_str = OmegaConf.to_yaml(cfg)
49 |         logger.info(f"\n\n{cfg_str}")
50 |     if cfg.blockprint:
51 |         # https://stackoverflow.com/a/8391735
52 |         sys.stdout = open(os.devnull, 'w')
53 | 
54 |     ngpu = torch.cuda.device_count()
55 |     logger.info("World size: {}; rank: {}".format(ngpu, rank))
56 | 
57 |     torch.backends.cudnn.benchmark=True
58 | 
59 |     if isinstance(manager, str):
60 |         monitor = eval(manager)(cfg, logger.info, device)
61 |     else:
62 |         monitor = manager(cfg, logger.info, device)
63 |     monitor.learn()
64 | 
65 | 
66 | @hydra.main(config_path="configs", config_name="default")
67 | def train(cfg: DictConfig) -> None:
68 |     if cfg.mode == "dp":
69 |         cfg.rank = 0
70 |         torch.cuda.set_device(0)
71 |         main(cfg, 0, False, False, torch.device('cuda', 0), cfg.monitor)
72 |     elif cfg.mode == "ddp":
73 |         try:
74 |             mp.spawn(
75 |                 _distributed_worker,
76 |                 nprocs = cfg.num_gpus,
77 |                 args = (main, cfg, False),
78 |                 daemon = False,
79 |             )
80 |         except KeyboardInterrupt as e:
81 |             dist.destroy_process_group()
82 |     else:
83 |         cfg.rank = 0
84 |         torch.cuda.set_device(0)
85 |         main(cfg, 0, False, False, torch.device('cuda', 0), None)
86 | 
87 | if __name__ == "__main__":
88 |     train()
89 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VIP-ANT: VIsually-Pivoted Audio and(N) Text
2 | 
3 | Code for the paper *[Connecting the Dots between Audio and Text without Parallel Data through Visual Knowledge Transfer](https://arxiv.org/abs/2112.08995)* @ NAACL 2022.
4 | 
5 | ![VIP-ANT pivots audio and text via visual imagination.](https://drive.google.com/uc?id=13PnfOt4U6f86et-ebHs-nTWwMj0W6Ycv)
6 | 
7 | ## Data
8 | 
9 | [AudioSet](https://research.google.com/audioset/) can be downloaded and preprocessed via this [tool](https://github.com/zhaoyanpeng/audioset-dl).
10 | 
11 | ### AudioSet Data
12 | 
13 | See [AudioSet](https://github.com/zhaoyanpeng/audioset-dl/blob/beta/note/audioset.md). It elaborates on our customized index files for pre-training on AudioSet.
14 | 
15 | ### Curated Audio-Text Data
16 | 
17 | See [AudioTxt](https://github.com/zhaoyanpeng/audioset-dl/blob/beta/note/audiotxt.md). It elaborates on our curation methods and customized index files for audio-text fine-tuning.
18 | 
19 | ## Vision-Audio (VA) Pre-training
20 | 
21 | Check out the running script `bash/run_bimodal_va.sh`.
22 | 
23 | ## Audio-Text (AT) Fine-tuning
24 | 
25 | Check out the running script `bash/run_bimodal_at.sh`. Fine-tuning starts with a VA pre-trained [audio encoder](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FFQTZK9YBPRDQHHR6157AGBR/00071478.pth).
26 | 
27 | ## Pre-trained Models
28 | 
29 | We provide a checkpoint that performs best for each task.
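Each checkpoint is a plain `torch.save` dictionary holding the training config (`"cfg"`) and the module state dicts (`"model"`), as read by `load_checkpoint` in `cvap/model/helper.py`. A minimal inspection sketch (the local file name is illustrative):

```python
import torch

ckpt = torch.load("00051623.pth", map_location="cpu")
cfg, modules = ckpt["cfg"], ckpt["model"]
# len(modules) is 4 (image/audio/text/loss heads) or 2 (audio/loss heads),
# matching the two branches handled by load_checkpoint.
print(len(modules), [len(sd) for sd in modules])
```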
30 | 
31 | ### Pre-trained Models for Audio-Text Retrieval
32 | 
33 | | Model | AudioCaps | Clotho (18s) | Clotho (10s) |
34 | |:-:|-:|-:|-:|
35 | | VIP-ANT | [00051623](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FFQTZK9YBPRDQHHR6157AGBR/00051623.pth) | [00043681](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FFQTZK9YBPRDQHHR6157AGBR/00043681.pth) | [00043681](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FFQTZK9YBPRDQHHR6157AGBR/00043681.pth) |
36 | | +AT w/ GC | [00006210](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FM2Y9HJ896B2G6NKRXTEVXZ7/00006210.pth) | [00006900](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FM2Y9HJ896B2G6NKRXTEVXZ7/00006900.pth) | [00004140](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FM2Y9HJ896B2G6NKRXTEVXZ7/00004140.pth) |
37 | 
38 | ### Pre-trained Models for Zero-shot Audio Classification
39 | 
40 | | Model | ESC50 (w/ prompt) | US8K (w/ prompt) |
41 | |:-:|-:|-:|
42 | | VIP-ANT | [00083391](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FFQTZK9YBPRDQHHR6157AGBR/00083391.pth) | [00079420](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FFQTZK9YBPRDQHHR6157AGBR/00079420.pth) |
43 | | +AT w/ GC | [00004140](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FM2Y9HJ896B2G6NKRXTEVXZ7/00004140.pth) | [00004140](https://storage.googleapis.com/ai2-mosaic-public/projects/vipant/model/01FM2Y9HJ896B2G6NKRXTEVXZ7/00004140.pth) |
44 | 
45 | ## Dependencies
46 | 
47 | `Dockerfile` defines minimum dependencies of the repo.
48 | 
49 | ## License
50 | MIT
51 | 
--------------------------------------------------------------------------------
/cvap/module/encoder/image_head.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 | from fvcore.common.registry import Registry
4 | from omegaconf.listconfig import ListConfig
5 | 
6 | import copy
7 | import threading
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 | 
13 | from timm.models.vision_transformer import _cfg
14 | 
15 | from .. import ModifiedResNet, VisualTransformer, DistilledVisionTransformer
16 | 
17 | IMAGE_HEADS_REGISTRY = Registry("IMAGE_HEADS")
18 | IMAGE_HEADS_REGISTRY.__doc__ = """
19 | Registry for image encoders.
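Usage sketch: `build_image_head(cfg)` below instantiates the head class
registered under `cfg.name` (e.g. 'ImageHead' from configs/model/image/vit.yaml).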
20 | """ 21 | 22 | def build_image_head(cfg, **kwargs): 23 | return IMAGE_HEADS_REGISTRY.get(cfg.name)(cfg, **kwargs) 24 | 25 | @IMAGE_HEADS_REGISTRY.register() 26 | class ImageHead(nn.Module): 27 | def __init__(self, cfg, **kwargs): 28 | super().__init__() 29 | if isinstance(cfg.layers, (tuple, list, ListConfig)): 30 | heads = cfg.width * 32 // 64 31 | self.encoder = ModifiedResNet( 32 | input_resolution=cfg.resolution, 33 | output_dim=cfg.embed_dim, 34 | layers=cfg.layers, 35 | width=cfg.width, 36 | heads=heads, 37 | ) 38 | else: 39 | heads = cfg.width // 64 40 | self.encoder = VisualTransformer( 41 | input_resolution=cfg.resolution, 42 | output_dim=cfg.embed_dim, 43 | patch_size=cfg.patch_size, 44 | layers=cfg.layers, 45 | width=cfg.width, 46 | heads=heads, 47 | ) 48 | 49 | def copy_state_dict(self, state_dict): 50 | self.encoder.load_state_dict(state_dict) 51 | 52 | def forward(self, images, *args, **kwargs): 53 | z = self.encoder(images) 54 | if kwargs.get("normalized", False): 55 | z = z / z.norm(dim=-1, keepdim=True) 56 | #print(f"{threading.current_thread().ident} image --{kwargs.get('normalized', False)}") 57 | return z 58 | 59 | @IMAGE_HEADS_REGISTRY.register() 60 | class DeiTImageHead(nn.Module): 61 | def __init__(self, cfg, **kwargs): 62 | super().__init__() 63 | heads = cfg.width // 64 64 | self.encoder = DistilledVisionTransformer( 65 | img_size=cfg.resolution, 66 | patch_size=cfg.patch_size, 67 | representation_size=cfg.embed_dim, 68 | embed_dim=cfg.width, 69 | depth=cfg.layers, 70 | num_heads=heads, 71 | mlp_ratio=4, 72 | qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 74 | **kwargs 75 | ) 76 | 77 | def copy_state_dict(self, state_dict): 78 | self.encoder.load_state_dict(state_dict) 79 | 80 | def forward(self, images, *args, **kwargs): 81 | cls_z, distilled_z = self.encoder.forward_features(images) 82 | z = (cls_z + distilled_z) / 2 83 | if kwargs.get("normalized", False): 84 | z = z / z.norm(dim=-1, keepdim=True) 85 | #print(f"{threading.current_thread().ident} image --{kwargs.get('normalized', False)}") 86 | return z 87 | -------------------------------------------------------------------------------- /cvap/model/helper.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | import torch 4 | from collections import OrderedDict 5 | 6 | from clip import load 7 | 8 | __all__ = ["load_checkpoint", "load_clip", "load_meme", "extract_model_file"] 9 | 10 | def load_checkpoint(cfg, echo): 11 | model_file = f"{cfg.model_root}/{cfg.model_name}/{cfg.model_file}" 12 | try: 13 | checkpoint = torch.load(model_file, map_location="cpu") 14 | echo(f"Loading from {model_file}") 15 | except Exception as e: 16 | echo(f"Failed to load the checkpoint `{model_file}`") 17 | return (None,) * 5 18 | local_cfg = checkpoint["cfg"] 19 | local_str = OmegaConf.to_yaml(local_cfg) 20 | if cfg.verbose: 21 | echo(f"Old configs:\n\n{local_str}") 22 | nmodule = len(checkpoint["model"]) 23 | if nmodule == 2: 24 | audio_head_sd, loss_head_sd = checkpoint["model"] 25 | return local_cfg, None, audio_head_sd, None, loss_head_sd 26 | elif nmodule == 4: 27 | image_head_sd, audio_head_sd, text_head_sd, loss_head_sd = checkpoint["model"] 28 | return local_cfg, image_head_sd, audio_head_sd, text_head_sd, loss_head_sd 29 | else: 30 | raise ValueError(f"I don't know how to parse the checkpoint: # module is {nmodule}.") 31 | 32 | def load_clip(local_cfg, cfg, echo): 33 | try: # try image / text backbone 34 | rcfg = 
cfg.running 35 | model, _ = load( 36 | rcfg.clip_model_name, rcfg.clip_model_root, device="cpu", jit=False 37 | ) 38 | image_head_sd = model.visual.state_dict() if local_cfg is None else None 39 | text_head_sd = OrderedDict() 40 | for k, v in model.state_dict().items(): 41 | if k.startswith("visual") or k == "logit_scale": 42 | continue 43 | #k = re.sub("^transformer\.", "encoder.", k) 44 | text_head_sd[k] = v 45 | from_scratch = False 46 | except Exception as e: 47 | echo(f"Will learn from scratch because: {e}") 48 | model = image_head_sd = text_head_sd = None 49 | from_scratch = True 50 | return from_scratch, image_head_sd, text_head_sd, model 51 | 52 | def load_meme(cfg, echo): 53 | try: # try image / text backbone 54 | acfg = cfg.model.audio 55 | model = torch.hub.load(acfg.meme_path, acfg.meme_name, pretrained=True) 56 | image_head_sd = model.state_dict() 57 | with_meme = True 58 | except Exception as e: 59 | meme_name = getattr(acfg, "meme_name", None) 60 | echo(f"Failed to load the meme `{meme_name}` because: {e}") 61 | image_head_sd = None 62 | with_meme = False 63 | return with_meme, image_head_sd 64 | 65 | def extract_model_file(cfg, echo): 66 | model_files = list() 67 | log_file = f"{cfg.model_root}/{cfg.model_name}/{cfg.model_file}" 68 | try: 69 | pattern = '.?Saving the checkpoint to.*\/([0-9]+\.pth)' 70 | with open(log_file, "r") as fr: 71 | for line in fr: 72 | ret = re.search(pattern, line) 73 | if ret is not None: 74 | model_files.append(ret.groups(0)[0]) 75 | except Exception as e: 76 | echo(f"Failed to extract model files from `{log_file}`") 77 | return model_files 78 | -------------------------------------------------------------------------------- /cvap/module/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from clip import LayerNorm 6 | from . 
import Transformer 7 | 8 | class VisualTransformer(nn.Module): 9 | def __init__( 10 | self, 11 | input_resolution: int, 12 | patch_size: int, 13 | width: int, 14 | layers: int, 15 | heads: int, 16 | output_dim: int, 17 | in_channels=3, 18 | stride=None 19 | ): 20 | super().__init__() 21 | self.input_resolution = input_resolution 22 | self.output_dim = output_dim 23 | stride = stride or patch_size 24 | if isinstance(stride, int): 25 | stride = [stride] * 2 26 | if isinstance(patch_size, int): 27 | patch_size = [patch_size] * 2 28 | stride = list(stride) 29 | patch_size = list(patch_size) 30 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=stride, bias=False) 31 | 32 | scale = width ** -0.5 33 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 34 | if isinstance(input_resolution, int): 35 | positions = (input_resolution // patch_size[0]) ** 2 + 1 36 | else: 37 | row_stride, col_stride = stride[:2] 38 | nrow = (input_resolution[0] - patch_size[0]) // row_stride + 1 39 | ncol = (input_resolution[1] - patch_size[1]) // col_stride + 1 40 | positions = nrow * ncol + 1 41 | self.position_resolution = (nrow, ncol) 42 | self.positional_embedding = nn.Parameter(scale * torch.randn(positions, width)) 43 | self.ln_pre = LayerNorm(width) 44 | 45 | self.transformer = Transformer(width, layers, heads) 46 | 47 | self.ln_post = LayerNorm(width) 48 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 49 | 50 | @property 51 | def dtype(self): 52 | return self.conv1.weight.dtype 53 | 54 | def forward(self, x: torch.Tensor, require_feature: bool=False): 55 | x = x.type(self.dtype) 56 | if x.shape[1] != self.conv1.weight.shape[1]: # interpolate weight 57 | conv1_weight = self.conv1.weight.mean(1, keepdim=True) 58 | x = F.conv2d( 59 | x, conv1_weight, bias=self.conv1.bias, stride=self.conv1.stride 60 | ) 61 | else: 62 | x = self.conv1(x) # shape = [*, width, grid, grid] 63 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 64 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 65 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 66 | x = x + self.positional_embedding.to(x.dtype) 67 | x = self.ln_pre(x) 68 | 69 | x = x.permute(1, 0, 2) # NLD -> LND 70 | x = self.transformer(x) 71 | x = x.permute(1, 0, 2) # LND -> NLD 72 | 73 | x = x_feature = self.ln_post(x) 74 | 75 | if self.proj is not None: 76 | x = x[:, 0, :] @ self.proj 77 | 78 | if require_feature: 79 | return x, x_feature[:, 1:] 80 | 81 | return x 82 | 83 | -------------------------------------------------------------------------------- /cvap/module/txt.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | from fvcore.common.registry import Registry 4 | 5 | import copy 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from clip import LayerNorm 12 | from . 
import Transformer 13 | 14 | class TextualTransformer(nn.Module): 15 | def __init__( 16 | self, 17 | width: int, 18 | layers: int, 19 | heads: int, 20 | ctx_len: int, 21 | vocab_size: int, 22 | output_dim: int, 23 | require_inter_attn: bool = False, 24 | ): 25 | super().__init__() 26 | self.ctx_len = ctx_len 27 | self.transformer = Transformer( 28 | width=width, 29 | layers=layers, 30 | heads=heads, 31 | attn_mask=self.build_attention_mask(), 32 | require_inter_attn=require_inter_attn, 33 | ) 34 | 35 | self.vocab_size = vocab_size 36 | self.token_embedding = nn.Embedding(vocab_size, width) 37 | self.positional_embedding = nn.Parameter(torch.empty(ctx_len, width)) 38 | self.ln_final = LayerNorm(width) 39 | 40 | self.text_projection = nn.Parameter(torch.empty(width, output_dim)) 41 | 42 | self.initialize_parameters() 43 | 44 | def initialize_parameters(self): 45 | nn.init.normal_(self.token_embedding.weight, std=0.02) 46 | nn.init.normal_(self.positional_embedding, std=0.01) 47 | 48 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 49 | attn_std = self.transformer.width ** -0.5 50 | fc_std = (2 * self.transformer.width) ** -0.5 51 | for block in self.transformer.resblocks: 52 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 53 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 54 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 55 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 56 | 57 | if self.text_projection is not None: 58 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 59 | 60 | def build_attention_mask(self): 61 | # lazily create causal attention mask, with full attention between the vision tokens 62 | # pytorch uses additive attention mask; fill with -inf 63 | mask = torch.empty(self.ctx_len, self.ctx_len) 64 | mask.fill_(float("-inf")) 65 | mask.triu_(1) # zero out the lower diagonal 66 | return mask 67 | 68 | @property 69 | def dtype(self): 70 | return self.token_embedding.weight.dtype 71 | 72 | def forward(self, text, positional_embedding=None, memory=None, require_feature=False): 73 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 74 | 75 | positional_embedding = positional_embedding or self.positional_embedding 76 | positional_embedding = positional_embedding[:x.shape[1]] 77 | x = x + positional_embedding.type(self.dtype) 78 | x = x.permute(1, 0, 2) # NLD -> LND 79 | x = self.transformer(x, memory) 80 | x = x.permute(1, 0, 2) # LND -> NLD 81 | x = x_feature = self.ln_final(x).type(self.dtype) 82 | 83 | # x.shape = [batch_size, n_ctx, transformer.width] 84 | # take features from the eot embedding (eot_token is the highest number in each sequence) 85 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 86 | 87 | if require_feature: 88 | return x, x_feature 89 | 90 | return x 91 | 92 | -------------------------------------------------------------------------------- /cvap/module/transformer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from clip import LayerNorm, QuickGELU 10 | 11 | class ResidualAttentionBlock(nn.Module): 12 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 13 | super().__init__() 14 | 15 | self.attn = nn.MultiheadAttention(d_model, n_head) 16 | self.ln_1 = 
LayerNorm(d_model) 17 | self.mlp = nn.Sequential(OrderedDict([ 18 | ("c_fc", nn.Linear(d_model, d_model * 4)), 19 | ("gelu", QuickGELU()), 20 | ("c_proj", nn.Linear(d_model * 4, d_model)) 21 | ])) 22 | self.ln_2 = LayerNorm(d_model) 23 | self.attn_mask = attn_mask 24 | 25 | def attention(self, x: torch.Tensor): 26 | if self.attn_mask is not None: 27 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) 28 | attn_mask = self.attn_mask[:x.shape[0], :x.shape[0]] 29 | else: 30 | attn_mask = None 31 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 32 | 33 | def forward(self, x: torch.Tensor): 34 | x = x + self.attention(self.ln_1(x)) 35 | x = x + self.mlp(self.ln_2(x)) 36 | return x 37 | 38 | class GeneralResidualAttentionBlock(nn.Module): 39 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, require_inter_attn = False): 40 | super().__init__() 41 | 42 | self.attn = nn.MultiheadAttention(d_model, n_head) 43 | self.ln_1 = LayerNorm(d_model) 44 | self.mlp = nn.Sequential(OrderedDict([ 45 | ("c_fc", nn.Linear(d_model, d_model * 4)), 46 | ("gelu", QuickGELU()), 47 | ("c_proj", nn.Linear(d_model * 4, d_model)) 48 | ])) 49 | self.ln_2 = LayerNorm(d_model) 50 | self.attn_mask = attn_mask 51 | 52 | self.require_inter_attn = require_inter_attn 53 | if self.require_inter_attn: 54 | self.attn_inter_ln = LayerNorm(d_model) 55 | self.attn_inter = nn.MultiheadAttention(d_model, n_head) 56 | 57 | def attention(self, x: torch.Tensor): 58 | if self.attn_mask is not None: 59 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) 60 | attn_mask = self.attn_mask[:x.shape[0], :x.shape[0]] 61 | else: 62 | attn_mask = None 63 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 64 | 65 | def forward(self, x): 66 | if isinstance(x, tuple): 67 | x, memory = x 68 | else: 69 | memory = None 70 | x = x + self.attention(self.ln_1(x)) 71 | if self.require_inter_attn: 72 | x = self.attn_inter_ln(x) 73 | x = x + self.attn_inter(x, memory, memory, need_weights=False)[0] 74 | x = x + self.mlp(self.ln_2(x)) 75 | return x, memory 76 | 77 | class Transformer(nn.Module): 78 | def __init__( 79 | self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, require_inter_attn: bool = False 80 | ): 81 | super().__init__() 82 | self.width = width 83 | self.layers = layers 84 | self.resblocks = nn.Sequential(*[ 85 | #ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers) 86 | GeneralResidualAttentionBlock(width, heads, attn_mask, require_inter_attn) for _ in range(layers) 87 | ]) 88 | 89 | def forward(self, x: torch.Tensor, memory: torch.Tensor=None): 90 | #return self.resblocks(x) 91 | return self.resblocks((x, memory))[0] 92 | -------------------------------------------------------------------------------- /cvap/model/clvp.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | import torch 4 | from torch import nn 5 | 6 | import torch.distributed as dist 7 | from torch.nn.parallel import data_parallel 8 | from collections import defaultdict, OrderedDict 9 | 10 | from ..module import ( 11 | build_image_head, build_audio_head, build_text_head, build_loss_head 12 | ) 13 | from . 
import ( 14 | load_checkpoint, load_clip, load_meme 15 | ) 16 | 17 | from clip import load 18 | 19 | class CLVP(nn.Module): 20 | def __init__(self, cfg, echo): 21 | super(CLVP, self).__init__() 22 | self.cfg = cfg 23 | self.echo = echo 24 | 25 | def forward(self, images, text, *args, **kwargs): 26 | if kwargs.get("retrieval", False): # if it is a retrieval task 27 | return self.forward_retrieval(images, text, *args, **kwargs) 28 | else: 29 | raise ValueError("Only support retrieval.") 30 | 31 | def forward_retrieval(self, images, text, *args, **kwargs): 32 | device_ids = kwargs.get("device_ids", [0]) 33 | # how to asynchronize the two `data_parallel` 34 | kwargs = {"normalized": self.loss_head.normalized, "names": kwargs.get("names", None)} 35 | image_features = audio_features = text_features = None 36 | dummy_image = list(images.shape[1:]) == [1, 1, 1] 37 | if images is not None and self.image_head is not None and not dummy_image: 38 | image_features = data_parallel( 39 | self.image_head, images, device_ids=device_ids, module_kwargs=kwargs 40 | ) 41 | elif images is not None: # pre-computed unnormalized features 42 | if self.loss_head.normalized and not dummy_image: 43 | images = images / images.norm(dim=-1, keepdim=True) 44 | image_features = images # dummy images will be ignored 45 | text_features = data_parallel( 46 | self.text_head, text, device_ids=device_ids, module_kwargs=kwargs 47 | ) 48 | loss = self.loss_head(image_features, text_features, **kwargs) 49 | return loss 50 | 51 | def collect_audio_state_dict(self): 52 | return (dict(),) * 2 53 | 54 | def collect_state_dict(self): 55 | return (dict(),) * 3 56 | 57 | def report(self, gold_file=None): 58 | if not dist.is_initialized() or dist.get_rank() == 0: 59 | return self.loss_head.report(gold_file=gold_file) 60 | else: 61 | return "" 62 | 63 | def build(self, **kwargs): 64 | tunable_params = dict() 65 | if self.cfg.eval: 66 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 67 | from_scratch, image_head_sd, text_head_sd, _ = load_clip(None, self.cfg, self.echo) 68 | 69 | # image_head's parameters as the reference 70 | self.image_head = build_image_head(self.cfg.model.image) 71 | if not from_scratch and not self.cfg.model.image.from_scratch: 72 | n_o, o_n = self.image_head.copy_state_dict(image_head_sd) 73 | msg = f" except {n_o}" if len(n_o) > 0 else "" 74 | self.echo(f"Initialize image encoder from `image_head`{msg}.") 75 | if self.cfg.running.frame_emb is not None or not self.cfg.running.imagine: 76 | self.image_head = None 77 | self.echo("Destory image encoder.") 78 | 79 | self.text_head = build_text_head(self.cfg.model.text) # 80 | n_o, o_n = self.text_head.copy_state_dict(text_head_sd) 81 | msg = f" except {n_o}" if len(n_o) > 0 else "" 82 | self.echo(f"Initialize text encoder from `text_head`{msg}.") 83 | 84 | self.loss_head = build_loss_head(self.cfg.model.loss, **kwargs) 85 | if loss_head_sd is not None: 86 | self.loss_head.copy_state_dict(loss_head_sd) # 87 | 88 | self.cuda(self.cfg.rank) 89 | else: 90 | raise ValueError("Not implemented yet.") 91 | -------------------------------------------------------------------------------- /cvap/module/deit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | import timm 7 | from timm.models.layers import to_2tuple, trunc_normal_ 8 | from timm.models.vision_transformer import VisionTransformer 9 | 
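# PatchEmbed below supports rectangular inputs and a stride that may differ from
# the patch size (overlapping patches); the grid follows the usual convolution
# arithmetic, nrow = (H - P_h) // S_h + 1, and likewise for columns. A worked
# example with illustrative numbers (not taken from any config in this repo):
# a 1024x128 spectrogram with 16x16 patches and a 10x10 stride gives
# nrow = (1024 - 16) // 10 + 1 = 101, ncol = (128 - 16) // 10 + 1 = 12,
# i.e. 1212 patches plus the cls and dist tokens.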
10 | class PatchEmbed(nn.Module): 11 | """ 2D Image to Patch Embedding 12 | """ 13 | def __init__( 14 | self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, stride=None 15 | ): 16 | super().__init__() 17 | if isinstance(patch_size, dict): # hack 18 | patch_size, stride = patch_size["patch_size"], patch_size["stride"] 19 | img_size = list(to_2tuple(img_size)) 20 | patch_size = list(to_2tuple(patch_size)) 21 | self.img_size = img_size 22 | self.patch_size = patch_size 23 | 24 | stride = stride or patch_size 25 | if isinstance(stride, int): 26 | stride = [stride] * 2 27 | stride = list(stride) 28 | 29 | row_stride, col_stride = stride[:2] 30 | nrow = (img_size[0] - patch_size[0]) // row_stride + 1 31 | ncol = (img_size[1] - patch_size[1]) // col_stride + 1 32 | 33 | self.grid_size = (nrow, ncol) 34 | self.num_patches = self.grid_size[0] * self.grid_size[1] 35 | self.flatten = flatten 36 | 37 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) 38 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 39 | 40 | def forward(self, x): 41 | B, C, H, W = x.shape 42 | assert H == self.img_size[0] and W == self.img_size[1], \ 43 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 44 | if x.shape[1] != self.proj.weight.shape[1]: # interpolate weight 45 | conv1_weight = self.proj.weight.mean(1, keepdim=True) 46 | x = F.conv2d( 47 | x, conv1_weight, bias=self.proj.bias, stride=self.proj.stride 48 | ) 49 | else: 50 | x = self.proj(x) 51 | if self.flatten: 52 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 53 | x = self.norm(x) 54 | return x 55 | 56 | class DistilledVisionTransformer(VisionTransformer): 57 | def __init__(self, *args, output_dim=None, **kwargs): 58 | super().__init__(*args, **kwargs) 59 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 60 | num_patches = self.patch_embed.num_patches 61 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 62 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 63 | 64 | trunc_normal_(self.dist_token, std=.02) 65 | trunc_normal_(self.pos_embed, std=.02) 66 | self.head_dist.apply(self._init_weights) 67 | 68 | scale = self.embed_dim ** -0.5 69 | self.proj = nn.Parameter(scale * torch.randn(self.embed_dim, output_dim)) if output_dim is not None else None 70 | 71 | def forward_features(self, x): 72 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 73 | # with slight modifications to add the dist_token 74 | B = x.shape[0] 75 | x = self.patch_embed(x) 76 | 77 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 78 | dist_token = self.dist_token.expand(B, -1, -1) 79 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 80 | 81 | x = x + self.pos_embed 82 | x = self.pos_drop(x) 83 | 84 | for blk in self.blocks: 85 | x = blk(x) 86 | 87 | x = self.norm(x) 88 | # non-linear operator because of Tanh activation function, it might be desired because we want to use 89 | # classification head as the head for contrastive learning 90 | # still, we only want a simple projection layer 91 | if self.proj is not None: 92 | x = x[:, :2] @ self.proj 93 | else: 94 | x = self.pre_logits(x) 95 | return x[:, 0], x[:, 1] 96 | 97 | def forward(self, x): 98 | x, x_dist = self.forward_features(x) 99 | x = self.head(x) 100 | x_dist = self.head_dist(x_dist) 101 | if 
self.training: 102 | return x, x_dist 103 | else: 104 | # during inference, return the average of both classifier predictions 105 | return (x + x_dist) / 2 106 | 107 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = 
bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /cvap/module/encoder/text_head.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | from fvcore.common.registry import Registry 4 | 5 | import copy 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from clip import LayerNorm 12 | from .. import TextualTransformer 13 | 14 | TEXT_HEADS_REGISTRY = Registry("TEXT_HEADS") 15 | TEXT_HEADS_REGISTRY.__doc__ = """ 16 | Registry for text encoders. 
17 | """ 18 | 19 | def build_text_head(cfg, **kwargs): 20 | return TEXT_HEADS_REGISTRY.get(cfg.name)(cfg, **kwargs) 21 | 22 | @TEXT_HEADS_REGISTRY.register() 23 | class TextHead(nn.Module): 24 | def __init__(self, cfg, **kwargs): 25 | super().__init__() 26 | self.encoder = TextualTransformer( 27 | width=cfg.width, 28 | layers=cfg.layers, 29 | heads=cfg.heads, 30 | ctx_len=cfg.ctx_len, 31 | vocab_size=cfg.vocab_size, 32 | output_dim=cfg.embed_dim, 33 | ) 34 | 35 | def copy_state_dict(self, state_dict): 36 | self.encoder.load_state_dict(state_dict) 37 | return {}, {} 38 | 39 | def forward(self, text, *args, **kwargs): 40 | positional_embedding = kwargs.get("positional_embedding", None) 41 | z = self.encoder(text, positional_embedding=positional_embedding) 42 | if kwargs.get("normalized", False): 43 | z = z / z.norm(dim=-1, keepdim=True) 44 | #print(f"{threading.current_thread().ident} image --{kwargs.get('normalized', False)}") 45 | return z 46 | 47 | @TEXT_HEADS_REGISTRY.register() 48 | class SeqGenerationHead(nn.Module): 49 | def __init__(self, cfg, **kwargs): 50 | super().__init__() 51 | self.encoder = TextualTransformer( 52 | width=cfg.width, 53 | layers=cfg.layers, 54 | heads=cfg.heads, 55 | ctx_len=cfg.ctx_len, 56 | vocab_size=cfg.vocab_size, 57 | output_dim=cfg.embed_dim, 58 | require_inter_attn=True, 59 | ) 60 | width = cfg.width 61 | scale = width ** -0.5 62 | self.mem_ln = LayerNorm(width) 63 | self.to_txt = nn.Parameter(scale * torch.randn(cfg.mem_width, cfg.width)) 64 | self.predictor = nn.Linear(width, self.encoder.vocab_size, bias=cfg.bias) 65 | self.max_len_dec = cfg.max_len_dec 66 | 67 | def copy_state_dict(self, state_dict): 68 | excluded = [] 69 | new_dict = self.encoder.state_dict() 70 | old_dict = {k: v for k, v in state_dict.items() if k not in excluded} 71 | new_keys = set(new_dict.keys()) 72 | old_keys = set(old_dict.keys()) 73 | new_dict.update(old_dict) 74 | self.encoder.load_state_dict(new_dict) 75 | n_o = new_keys - old_keys 76 | o_n = old_keys - new_keys 77 | #print(f"{n_o}\n{o_n}") 78 | return n_o, o_n 79 | 80 | def infer(self, x, positional_embedding, memory): 81 | beg_len = 0 82 | max_len = self.max_len_dec - beg_len 83 | logits = list() 84 | indice = torch.arange(0, x.shape[0], 5, device=x.device) 85 | x = x[indice] 86 | if beg_len > 0: # gold prefix and fake logits 87 | all_ctx = x[:, :beg_len + 1] 88 | logit = torch.zeros(( 89 | all_ctx.size(0), beg_len, self.encoder.vocab_size 90 | ), device=x.device) 91 | logit = logit.scatter(2, all_ctx[:, 1:].unsqueeze(-1), 10) 92 | logits.append(logit) 93 | else: # the start symbol 94 | all_ctx = x[:, :1] 95 | 96 | for istep in range(beg_len, max_len): 97 | _, features = self.encoder( 98 | all_ctx, positional_embedding=positional_embedding, memory=memory, require_feature=True 99 | ) 100 | logit = self.predictor(features[:, -1:]) 101 | logits.append(logit) 102 | 103 | new_ctx = logit.argmax(dim=-1) 104 | all_ctx = torch.cat((all_ctx, new_ctx), 1) 105 | 106 | logits = torch.cat(logits, dim=1) 107 | return x, logits, all_ctx 108 | 109 | def forward(self, text, audio, time_first, *args, **kwargs): 110 | # layer-normed audio: (N, nrow, ncol, D) 111 | audio = audio @ self.to_txt # project to the textual space 112 | audio = audio.mean(2) if time_first else audio.mean(1) 113 | audio = self.mem_ln(audio).permute(1, 0, 2) # NLD -> LND 114 | # text conditional on audio 115 | positional_embedding = kwargs.get("positional_embedding", None) 116 | 117 | if not self.training: 118 | return self.infer(text, positional_embedding, audio) 
119 | 120 | z, features = self.encoder( 121 | text, positional_embedding=positional_embedding, memory=audio, require_feature=True 122 | ) 123 | logits = self.predictor(features) # compute cross-entropy loss 124 | logits = logits[:, :-1] 125 | 126 | if kwargs.get("normalized", False): 127 | z = z / z.norm(dim=-1, keepdim=True) 128 | #print(f"{threading.current_thread().ident} image --{kwargs.get('normalized', False)}") 129 | return z, logits, None 130 | -------------------------------------------------------------------------------- /cvap/module/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from clip import Bottleneck 6 | 7 | class AttentionPool2d(nn.Module): 8 | def __init__(self, positions: int, embed_dim: int, num_heads: int, output_dim: int = None): 9 | super().__init__() 10 | self.positional_embedding = nn.Parameter(torch.randn(positions, embed_dim) / embed_dim ** 0.5) 11 | self.k_proj = nn.Linear(embed_dim, embed_dim) 12 | self.q_proj = nn.Linear(embed_dim, embed_dim) 13 | self.v_proj = nn.Linear(embed_dim, embed_dim) 14 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 15 | self.num_heads = num_heads 16 | 17 | def forward(self, x): 18 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 19 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 20 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 21 | x, _ = F.multi_head_attention_forward( 22 | query=x, key=x, value=x, 23 | embed_dim_to_check=x.shape[-1], 24 | num_heads=self.num_heads, 25 | q_proj_weight=self.q_proj.weight, 26 | k_proj_weight=self.k_proj.weight, 27 | v_proj_weight=self.v_proj.weight, 28 | in_proj_weight=None, 29 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 30 | bias_k=None, 31 | bias_v=None, 32 | add_zero_attn=False, 33 | dropout_p=0, 34 | out_proj_weight=self.c_proj.weight, 35 | out_proj_bias=self.c_proj.bias, 36 | use_separate_proj_weight=True, 37 | training=self.training, 38 | need_weights=False 39 | ) 40 | 41 | return x[0] 42 | 43 | class ModifiedResNet(nn.Module): 44 | """ 45 | A ResNet class that is similar to torchvision's but contains the following changes: 46 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
47 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 48 | - The final pooling layer is a QKV attention instead of an average pool 49 | """ 50 | 51 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64, in_channels=3): 52 | super().__init__() 53 | self.output_dim = output_dim 54 | self.input_resolution = input_resolution 55 | 56 | # the 3-layer stem 57 | self.conv1 = nn.Conv2d(in_channels, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 58 | self.bn1 = nn.BatchNorm2d(width // 2) 59 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(width // 2) 61 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(width) 63 | self.avgpool = nn.AvgPool2d(2) 64 | self.relu = nn.ReLU(inplace=True) 65 | 66 | # residual layers 67 | self._inplanes = width # this is a *mutable* variable used during construction 68 | self.layer1 = self._make_layer(width, layers[0]) 69 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 70 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 71 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 72 | 73 | if isinstance(input_resolution, int): 74 | positions = (input_resolution // 32) ** 2 + 1 75 | else: 76 | nrow = (input_resolution[0] - 0) // 32 77 | ncol = (input_resolution[1] - 0) // 32 78 | positions = nrow * ncol + 1 79 | self.position_resolution = (nrow, ncol) 80 | 81 | embed_dim = width * 32 # the ResNet feature dimension 82 | self.attnpool = AttentionPool2d(positions, embed_dim, heads, output_dim) 83 | self.initialize_parameters() 84 | 85 | def _make_layer(self, planes, blocks, stride=1): 86 | layers = [Bottleneck(self._inplanes, planes, stride)] 87 | 88 | self._inplanes = planes * Bottleneck.expansion 89 | for _ in range(1, blocks): 90 | layers.append(Bottleneck(self._inplanes, planes)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def initialize_parameters(self): 95 | std = self.attnpool.c_proj.in_features ** -0.5 96 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 97 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 98 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 99 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 100 | 101 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 102 | for name, param in resnet_block.named_parameters(): 103 | if name.endswith("bn3.weight"): 104 | nn.init.zeros_(param) 105 | 106 | @property 107 | def dtype(self): 108 | return self.conv1.weight.dtype 109 | 110 | def forward(self, x): 111 | def stem(x): 112 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 113 | x = self.relu(bn(conv(x))) 114 | x = self.avgpool(x) 115 | return x 116 | 117 | x = x.type(self.dtype) 118 | x = stem(x) 119 | x = self.layer1(x) 120 | x = self.layer2(x) 121 | x = self.layer3(x) 122 | x = self.layer4(x) 123 | x = self.attnpool(x) 124 | 125 | return x 126 | 127 | -------------------------------------------------------------------------------- /cvap/model/cvap.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | import torch 4 | from torch import nn 5 | 6 | import torch.distributed as dist 7 | from torch.nn.parallel import data_parallel 8 | 9 | from ..module import ( 10 | build_image_head, build_audio_head, build_text_head, 
build_loss_head 11 | ) 12 | from . import ( 13 | load_checkpoint, load_clip, load_meme 14 | ) 15 | 16 | from clip import load 17 | 18 | class CVAP(nn.Module): 19 | def __init__(self, cfg, echo): 20 | super(CVAP, self).__init__() 21 | self.cfg = cfg 22 | self.echo = echo 23 | 24 | def forward(self, images, audios, *args, **kwargs): 25 | device_ids = kwargs.get("device_ids", [0]) 26 | # how to asynchronize the two `data_parallel` 27 | kwargs = {"normalized": self.loss_head.normalized, "names": kwargs.get("names", None)} 28 | if self.image_head is not None: 29 | image_features = data_parallel( 30 | self.image_head, images, device_ids=device_ids, module_kwargs=kwargs 31 | ) 32 | else: # pre-computed unnormalized features 33 | if self.loss_head.normalized: 34 | images = images / images.norm(dim=-1, keepdim=True) 35 | image_features = images 36 | audio_features = data_parallel( 37 | self.audio_head, audios, device_ids=device_ids, module_kwargs=kwargs 38 | ) 39 | loss = self.loss_head(image_features, audio_features, **kwargs) 40 | return loss 41 | 42 | def collect_audio_state_dict(self): 43 | return ( 44 | self.audio_head.state_dict(), 45 | self.loss_head.state_dict(), 46 | ) 47 | 48 | def collect_state_dict(self): 49 | return ( 50 | self.image_head.state_dict(), 51 | self.audio_head.state_dict(), 52 | self.loss_head.state_dict(), 53 | ) 54 | 55 | def report(self, gold_file=None): 56 | if not dist.is_initialized() or dist.get_rank() == 0: 57 | return self.loss_head.report(gold_file=gold_file) 58 | else: 59 | return "" 60 | 61 | def build(self): 62 | tunable_params = dict() 63 | if self.cfg.eval: 64 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 65 | from_scratch, image_head_sd, _, _ = load_clip(None, self.cfg, self.echo) 66 | 67 | self.image_head = build_image_head(self.cfg.model.image) 68 | self.image_head.copy_state_dict(image_head_sd) 69 | 70 | self.audio_head = build_audio_head(local_cfg.model.audio) 71 | self.audio_head.load_state_dict(audio_head_sd) 72 | 73 | self.loss_head = build_loss_head(local_cfg.model.loss) 74 | self.loss_head.load_state_dict(loss_head_sd) 75 | self.cuda(self.cfg.rank) 76 | else: 77 | # try pre-trained model! 78 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 79 | # try clip! TODO do we always have to load CLIP? 80 | from_scratch, image_head_sd, _, model = load_clip(local_cfg, self.cfg, self.echo) 81 | # try meme! 
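# The "meme" is a pretrained vision backbone (e.g. a DeiT model) fetched through
# torch.hub: load_meme() calls torch.hub.load(acfg.meme_path, acfg.meme_name,
# pretrained=True) and hands back its state dict. The blocks below initialize the
# audio encoder with this priority: local checkpoint (audio_head_sd), then the
# meme backbone, then CLIP's image encoder, and finally random initialization.
# A hypothetical config (values are illustrative, not taken from the yaml files):
#   model.audio.meme_path: facebookresearch/deit:main
#   model.audio.meme_name: deit_base_distilled_patch16_384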
82 | with_meme, meme_image_head_sd = load_meme(self.cfg, self.echo) 83 | 84 | self.image_head = build_image_head(self.cfg.model.image) 85 | if not from_scratch and not self.cfg.model.image.from_scratch: 86 | self.image_head.copy_state_dict(image_head_sd) 87 | self.echo("Initialize image encoder from `image_head`.") 88 | if self.cfg.running.frame_emb is not None: 89 | self.image_head = None 90 | self.echo("Destory image encoder.") 91 | 92 | self.audio_head = build_audio_head(self.cfg.model.audio) 93 | if not self.cfg.model.audio.from_scratch: 94 | if local_cfg is not None: 95 | # TODO better to use `from_pretrained()` 96 | self.audio_head.load_state_dict(audio_head_sd) 97 | self.echo("Initialize audio encoder from `audio_head`.") 98 | elif not from_scratch: 99 | if with_meme: # higher priority 100 | msg = " `meme_image_head`" 101 | n_o, o_n = self.audio_head.copy_state_dict(meme_image_head_sd) 102 | else: 103 | msg = " `image_head`" 104 | n_o, o_n = self.audio_head.copy_state_dict(image_head_sd) 105 | msg += f" except {n_o}" if len(n_o) > 0 else "" 106 | self.echo(f"Initialize audio encoder from{msg}.") 107 | else: 108 | self.echo("Have to learn from scratch.") 109 | 110 | self.loss_head = build_loss_head(self.cfg.model.loss) 111 | if not from_scratch and not self.cfg.model.audio.from_scratch: 112 | extra_sd = {"logit_scale": model.logit_scale} 113 | self.loss_head.copy_state_dict(extra_sd) 114 | 115 | tunable_params = { 116 | f"audio_head.{k}": v for k, v in self.audio_head.named_parameters() 117 | } 118 | tunable_params.update({ 119 | f"loss_head.{k}": v for k, v in self.loss_head.named_parameters() 120 | }) 121 | if not self.cfg.model.image.freeze and self.image_head is not None: 122 | tunable_params.update({ 123 | f"image_head.{k}": v for k, v in self.image_head.named_parameters() 124 | }) 125 | elif self.image_head is not None: 126 | self.echo("Freeze image encoder.") 127 | self.cuda(self.cfg.rank) 128 | return tunable_params 129 | 130 | -------------------------------------------------------------------------------- /cvap/model/esc50_clf.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | import warnings 4 | from typing import Union, List 5 | from collections import defaultdict, OrderedDict 6 | 7 | import copy 8 | import time 9 | import torch 10 | import numpy as np 11 | from torch import nn 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | import torch.distributed as dist 16 | from torch.nn.parallel import data_parallel 17 | from torch.nn.parallel import DistributedDataParallel 18 | 19 | from clip import load 20 | 21 | from ..module import ( 22 | build_image_head, build_audio_head, build_text_head, build_loss_head 23 | ) 24 | from . 
import ( 25 | load_checkpoint, load_clip, load_meme 26 | ) 27 | 28 | 29 | class ESClassifier(nn.Module): 30 | def __init__(self, cfg, echo): 31 | super(ESClassifier, self).__init__() 32 | self.cfg = cfg 33 | self.echo = echo 34 | 35 | def forward(self, audios, labels, *args, **kwargs): 36 | device_ids = kwargs.get("device_ids", [0]) 37 | # how to asynchronize the two `data_parallel` 38 | kwargs = {"normalized": self.loss_head.normalized, "names": kwargs.get("names", None)} 39 | audio_features = data_parallel( 40 | self.audio_head, audios, device_ids=device_ids, module_kwargs=kwargs 41 | ) 42 | loss = self.loss_head(audio_features, labels, **kwargs) 43 | return loss 44 | 45 | def encode_text(self, text, *args, device_ids=[0], **kwargs): 46 | text_features = data_parallel( 47 | self.text_head, text, device_ids=device_ids, module_kwargs=kwargs 48 | ) 49 | return text_features 50 | 51 | def collect_audio_state_dict(self): 52 | return ( 53 | self.audio_head.state_dict(), 54 | self.loss_head.state_dict(), 55 | ) 56 | 57 | def report(self, gold_file=None, **kwargs): 58 | if not dist.is_initialized() or dist.get_rank() == 0: 59 | return self.loss_head.report(gold_file=gold_file, **kwargs) 60 | else: 61 | return "" 62 | 63 | def build(self, **kwargs): 64 | tunable_params = dict() 65 | if self.cfg.eval: 66 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 67 | from_scratch, image_head_sd, text_head_sd, _ = load_clip(None, self.cfg, self.echo) 68 | 69 | self.audio_head = build_audio_head(self.cfg.model.audio) 70 | if audio_head_sd is not None: 71 | n_o, o_n = self.audio_head.from_pretrained(audio_head_sd, local_cfg) 72 | msg = f" except {n_o}" if len(n_o) > 0 else "" 73 | self.echo(f"Initialize audio encoder from `audio_head`{msg}.") 74 | else: 75 | self.audio_head.copy_state_dict(image_head_sd) 76 | self.echo("Initialize audio encoder from `image_head`.") 77 | 78 | self.text_head = build_text_head(self.cfg.model.text) # 79 | n_o, o_n = self.text_head.copy_state_dict(text_head_sd) 80 | msg = f" except {n_o}" if len(n_o) > 0 else "" 81 | self.echo(f"Initialize text encoder from `text_head`{msg}.") 82 | 83 | self.loss_head = build_loss_head(self.cfg.model.loss, **kwargs) 84 | if loss_head_sd is not None: 85 | self.loss_head.copy_state_dict(loss_head_sd) # 86 | 87 | self.cuda(self.cfg.rank) 88 | else: 89 | # try pre-trained model! 90 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 91 | # try clip! TODO do we always have to load CLIP? 92 | from_scratch, image_head_sd, _, model = load_clip(None, self.cfg, self.echo) 93 | # try meme! 
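# Two loaders appear below: from_pretrained(audio_head_sd, local_cfg) restores a
# checkpointed audio head (it also receives the checkpoint's config, presumably
# to reconcile shapes), whereas copy_state_dict() copies the keys that match from
# a CLIP or meme image backbone and returns the mismatched key sets (n_o, o_n).
# The priority mirrors CVAP.build: local checkpoint, then meme backbone, then
# CLIP image weights, then training from scratch.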
94 | with_meme, meme_image_head_sd = load_meme(self.cfg, self.echo) 95 | 96 | self.audio_head = build_audio_head(self.cfg.model.audio) 97 | if not self.cfg.model.audio.from_scratch: 98 | if local_cfg is not None: 99 | # TODO better to use `from_pretrained()` 100 | self.audio_head.from_pretrained(audio_head_sd, local_cfg) 101 | self.echo("Initialize audio encoder from `audio_head`.") 102 | elif not from_scratch: 103 | if with_meme: # higher priority 104 | msg = " `meme_image_head`" 105 | n_o, o_n = self.audio_head.copy_state_dict(meme_image_head_sd) 106 | else: 107 | msg = " `image_head`" 108 | n_o, o_n = self.audio_head.copy_state_dict(image_head_sd) 109 | msg += f" except {n_o}" if len(n_o) > 0 else "" 110 | self.echo(f"Initialize audio encoder from{msg}.") 111 | else: 112 | self.echo("Have to learn from scratch.") 113 | 114 | self.loss_head = build_loss_head(self.cfg.model.loss, **kwargs) 115 | tunable_params = { 116 | f"loss_head.{k}": v for k, v in self.loss_head.named_parameters() 117 | } 118 | if not self.cfg.model.audio.freeze: 119 | excl_modules = set(self.cfg.running.excl_modules.amodules) 120 | pattern = "|".join([f"^{m}\." for m in excl_modules]) 121 | tunable_params.update({ 122 | f"audio_head.{k}": v for k, v in self.audio_head.named_parameters() 123 | if pattern == "" or not re.match(pattern, k)}) # filter out excluded parameters 124 | self.echo(f"Tune audio encoder (excl. {excl_modules}).") 125 | else: 126 | self.echo("Freeze audio encoder.") 127 | self.cuda(self.cfg.rank) 128 | return tunable_params 129 | 130 | -------------------------------------------------------------------------------- /cvap/data/audioset_hub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import torch 5 | import itertools 6 | import torchaudio 7 | import numpy as np 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from itertools import cycle, islice, chain 11 | from einops import rearrange, repeat 12 | from collections import defaultdict 13 | 14 | import multiprocessing as mp 15 | import torch.utils.data as data 16 | import torch.nn.functional as F 17 | 18 | from .audio import ( 19 | make_transform, _extract_kaldi_spectrogram 20 | ) 21 | from .audiocaps import AudioCapDatasetSrc 22 | from .audioset_cls import AudiosetNpz, AudiosetSrc 23 | from .audioset_clf import AudiosetDatasetNpz, ImageAudioCollator 24 | from clip import tokenize 25 | 26 | ### 27 | # this file ocntains the very first implementations of AudioSet data loader. 28 | # we now have clf-focused loader in audioset_clf.py and audioset_ast.py 29 | # and contrastive-focused loader in image_audio.py and this file. 
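# build_filter_set() below accepts a "name,topk" spec and dispatches on the file
# name: a .csv lists one sample id per line, a json whose name ends in "k" maps
# labels to id lists, and anything else is read as jsonl where the key plus its
# top-k ranked neighbours are kept; it returns None when no valid filter is given.
# A hypothetical spec (the config key and file name are illustrative):
#   running.filter_set: "balanced_samples.csv,0"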
30 | ### 31 | 32 | def build_filter_set(data_root, filter_set): 33 | try: # filters can be None 34 | name, topk = filter_set.split(",") 35 | filter_file = f"{data_root}/{name}" 36 | if filter_file[-3:] == "csv": 37 | samples = set() 38 | with open(filter_file, "r") as fr: 39 | for line in fr: 40 | line = line.strip() 41 | samples.add(line) 42 | elif filter_file[-1] == "k": 43 | samples_per_label = json.load(open(filter_file, "r")) 44 | samples = set() 45 | for k, v in samples_per_label.items(): 46 | samples.update(v) 47 | else: 48 | topk = int(topk) 49 | samples = set() 50 | with open(filter_file, "r") as fr: 51 | for line in fr: 52 | line = json.loads(line) 53 | k, v = list(line.items())[0] 54 | new_samples = set([name for name, _ in v[:topk]] + [k]) 55 | samples.update(new_samples) 56 | except Exception as e: 57 | samples = None 58 | return samples 59 | 60 | def collect_ytid(csv_root, csv_list): 61 | ids = defaultdict(list) 62 | nrow = 0 63 | for fname in csv_list: 64 | ifile = f"{csv_root}/{fname}.csv" 65 | with open(ifile, "r") as fr: 66 | for _ in range(3): 67 | next(fr) 68 | for irow, row in enumerate(fr): 69 | row = row.split(", ") 70 | ids[row[0].strip()].append( 71 | (row[1].strip(), row[2].strip(), row[3].strip('" \n').split(",")) 72 | ) 73 | nrow += 1 74 | return list(ids.keys()), ids 75 | 76 | def build_audioset_label_map(data_root, label_map="ontology,eval_segments", prompt=""): 77 | file_list = label_map.split(",") 78 | ontology, label_files = file_list[0], file_list[1:] 79 | label_path = f"{data_root}/{ontology}.json" 80 | label_real = f"{data_root}/{label_files[0]}.csv" 81 | assert os.path.isfile(label_path) and os.path.isfile(label_real), ( 82 | "please specify a valid `ontology` file (ontology) and `eval` file (eval_segments)." 83 | ) 84 | category_list = list() 85 | ontology = json.load(open(label_path, "r")) 86 | prompt = "" if prompt.strip() == "" else prompt.strip() + " " 87 | for item in ontology: 88 | category = item["id"] 89 | category_list.append( 90 | (category, prompt + item["name"].lower()) 91 | ) 92 | text_list = [item[1] for item in category_list] 93 | label_int = tokenize(text_list, as_list=True) 94 | category_list = [item + (label_int[i],) for i, item in enumerate(category_list)] 95 | #label_map = {category_list[i][0]: (i, category_list[i][1], label_int[i]) for i in range(len(category_list))} 96 | 97 | _, ytid_dict = collect_ytid(data_root, label_files) 98 | 99 | label_set = set(itertools.chain.from_iterable( 100 | v[0][2] for _, v in ytid_dict.items() 101 | )) 102 | category_list = [item for item in category_list if item[0] in label_set] 103 | label_map = {category_list[i][0]: (i,) + category_list[i][1:] for i in range(len(category_list))} 104 | #print(text_list, len(label_set)) 105 | #print(label_map, len(label_map)) 106 | return label_map 107 | 108 | def build_audioset_dataloader(cfg, data_name, label_map, shuffle=True, train=True, external_text=None, filters=None): 109 | ddp_mode = torch.distributed.is_initialized() 110 | rcfg = cfg.running 111 | if data_name.startswith("src"): 112 | if not rcfg.force_npz: 113 | dataset = AudiosetSrc(rcfg, data_name, train, label_map, False, external_text=external_text, filter_set=filters) 114 | else: 115 | dataset = AudiosetNpz(rcfg, data_name, train, label_map, False, external_text=external_text) 116 | elif data_name.startswith("npz"): 117 | dataset = AudiosetDatasetNpz(rcfg, data_name, train, label_map, False) 118 | elif data_name.startswith("audiocaps"): # audio captioning 119 | dataset = AudioCapDatasetSrc(rcfg, 
data_name, train, label_map) 120 | else: 121 | raise ValueError(f"unrecognized data file `{data_name}`.") 122 | if ddp_mode: 123 | assert cfg.optimizer.batch_size % cfg.num_gpus == 0 124 | sampler = torch.utils.data.distributed.DistributedSampler( 125 | dataset, shuffle=shuffle 126 | ) 127 | per_device_batch_size = cfg.optimizer.batch_size // cfg.num_gpus 128 | else: 129 | sampler = ( 130 | torch.utils.data.RandomSampler(dataset) if shuffle else 131 | torch.utils.data.SequentialSampler(dataset) 132 | ) 133 | per_device_batch_size = cfg.optimizer.batch_size 134 | dataloader = torch.utils.data.DataLoader( 135 | dataset, 136 | batch_size=per_device_batch_size, 137 | collate_fn=ImageAudioCollator(), 138 | num_workers=(0 if ddp_mode else cfg.num_proc), 139 | pin_memory=True, 140 | sampler=sampler, 141 | drop_last=(True if ddp_mode else False), 142 | ) 143 | return sampler, dataloader 144 | -------------------------------------------------------------------------------- /cvap/data/image_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | import json 5 | import torch 6 | import warnings 7 | import itertools 8 | import torchaudio 9 | import numpy as np 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | from PIL import Image as PILImage 13 | from itertools import cycle, islice, chain 14 | from einops import rearrange, repeat 15 | 16 | import multiprocessing as mp 17 | import torch.utils.data as data 18 | import torch.nn.functional as F 19 | 20 | from .audio import ( 21 | make_transform, _extract_kaldi_spectrogram 22 | ) 23 | from .image import make_clip_image_transform as make_image_transform 24 | from .audio_text import ( 25 | build_dataloader, build_audiocaps_data_list, AudioTextDatasetSrc 26 | ) 27 | from clip import tokenize 28 | 29 | class ImageTextDatasetSrc(AudioTextDatasetSrc): 30 | """ `__getitem__' loads raw file from disk. 
31 | """ 32 | def __init__(self, cfg, data_list, train): 33 | super().__init__(cfg, data_list, train) 34 | self.frame_key = cfg.frame_key 35 | self.transform_image = make_image_transform(cfg.resolution) 36 | 37 | def _image2embed(self, fname): 38 | try: 39 | image = np.load(fname)["v"] 40 | except Exception as e: 41 | image = np.random.rand(self.cfg.embed_dim).astype("float32") 42 | warnings.warn(f"use random image instead because `{e}` {fname}.") 43 | return image 44 | 45 | def _image2numpy(self, fname): 46 | if fname is not None: 47 | try: 48 | if fname.endswith(".npz"): 49 | images = np.load(fname) 50 | images = [images[key] for key in images.files if len(images[key]) != 0] 51 | idx = np.random.choice(len(images), 1)[0] if self.train else int(np.ceil(len(images) / 2)) - 1 52 | image = images[idx] 53 | else: 54 | image = PILImage.open(fname) 55 | image = self.transform_image(image).cpu().numpy() 56 | except Exception as e: 57 | h = w = self.cfg.resolution 58 | image = PILImage.fromarray( 59 | (np.random.rand(h, w, 3) * 256).astype(np.uint8) 60 | ) 61 | warnings.warn(f"use random image instead because `{e}` {fname}.") 62 | image = self.transform_image(image).cpu().numpy() 63 | else: 64 | image = np.array([[[1]]]) 65 | return image 66 | 67 | def __getitem__(self, index): 68 | akey = self.aclip_key 69 | fkey = self.frame_key 70 | name = self.dataset[index]["id"] 71 | sub_dir = self.dataset[index]["dir"] 72 | label_str = self.dataset[index]["label_str"] 73 | label_int = self.dataset[index]["label_int_bpe"] 74 | aclip = self.dataset[index][akey][0] 75 | frame = images = self.dataset[index][fkey] 76 | 77 | sub_dir = "" if len(sub_dir) == 0 else f"{sub_dir}/" 78 | aclip = aclip if aclip == name else f"{akey}/{name}.{aclip}" 79 | aclip_file = f"{self.cfg.data_root}/{sub_dir}{aclip}" 80 | 81 | # image 82 | frame_emb_file = None 83 | if isinstance(frame, str): 84 | frame_file = f"{self.cfg.data_root}/{sub_dir}{fkey}/{name}.{frame}" 85 | else: 86 | idx = np.random.choice(len(images), 1)[0] if self.train else int(np.ceil(len(images) / 2)) - 1 87 | frame_file = f"{self.cfg.data_root}/{sub_dir}{fkey}/{name}.{images[idx]}" 88 | if self.cfg.frame_emb is not None: 89 | frame_emb_file = f"{self.cfg.data_root}/{self.cfg.frame_emb}/{name}.{images[idx].rsplit('.', 1)[0]}.npz" 90 | # higher priority for pre-computed frame embeddings 91 | image = self._image2embed(frame_emb_file) if frame_emb_file is not None else self._image2numpy(frame_file) 92 | 93 | # audio 94 | audio = self._audio2numpy_cst(aclip_file) 95 | 96 | if not self.cfg.audio.eval_norms and len(self.audio_norms) == 2: 97 | mean, std = self.audio_norms 98 | audio = (audio - mean) / std 99 | 100 | #if self.train and self.transform_fbank is not None: 101 | if not self.cfg.audio.eval_norms and self.train and self.transform_fbank is not None: 102 | audio = self.transform_fbank(audio) 103 | 104 | if self.train: 105 | idx = np.random.choice(len(label_int), 1)[0] 106 | text = label_int[idx] 107 | else: 108 | text = label_int 109 | 110 | audio = audio[None] 111 | image = image[None] 112 | item = {"image": image, "audio": audio, "text": text, "name": name} 113 | return item 114 | 115 | def __len__(self): 116 | return self.length 117 | 118 | class ImageTextCollator: 119 | def __init__(self, device=torch.device("cpu")): 120 | # RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned 121 | # when pin_memory is true, the collator has to return CPU tensors 122 | self.device = device 123 | 124 | def __call__(self, records): 125 | 
union = { 126 | k: [record.get(k) for record in records] for k in set().union(*records) 127 | } 128 | name = union["name"] 129 | text_list = union["text"] 130 | if isinstance(text_list[0][0], int): # train 131 | pass 132 | """ https://stackoverflow.com/a/43149308 133 | lengths = [len(x) for x in text_list] 134 | max_len = max(lengths) 135 | text = np.zeros((len(text_list), max_len), int) 136 | mask = np.arange(max_len) < np.array(lengths)[:, None] 137 | text[mask] = np.concatenate(text_list) 138 | """ 139 | elif isinstance(text_list[0][0], list): # test 140 | text_list = list(itertools.chain.from_iterable(text_list)) 141 | #name = list(itertools.chain.from_iterable(name)) 142 | else: 143 | raise ValueError(f"unrecognized `{type(text_list[0][0])}`") 144 | # https://stackoverflow.com/a/38619333 145 | text = np.array(list(itertools.zip_longest(*text_list, fillvalue=0))).T 146 | return ( 147 | np.concatenate(union["image"], axis=0), 148 | text, 149 | name, 150 | ) 151 | 152 | def build_dataloader_audiocaps(cfg, data_name, shuffle=True, train=True): 153 | name_list = data_name.split(",") 154 | dataset = list() 155 | for name in name_list: 156 | subset = build_audiocaps_data_list(cfg.running, name) 157 | dataset.extend(subset) 158 | return build_dataloader(cfg, dataset, ImageTextDatasetSrc, shuffle=shuffle, train=train, collator_cls=ImageTextCollator) 159 | 160 | def build_image_text_dataloader(cfg, data_name, *args, shuffle=True, train=True, **kwargs): 161 | if data_name.startswith("audiocaps"): # can only do w/ AudioCaps 162 | return build_dataloader_audiocaps( 163 | cfg, data_name, shuffle=shuffle, train=train 164 | ) 165 | else: 166 | raise ValueError(f"unrecognized dataset `{data_name}`.") 167 | 168 | -------------------------------------------------------------------------------- /cvap/model/clap.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | import torch 4 | from torch import nn 5 | 6 | import torch.distributed as dist 7 | from torch.nn.parallel import data_parallel 8 | from collections import defaultdict, OrderedDict 9 | 10 | from ..module import ( 11 | build_image_head, build_audio_head, build_text_head, build_loss_head 12 | ) 13 | from . 
import ( 14 | load_checkpoint, load_clip, load_meme 15 | ) 16 | 17 | from clip import load 18 | 19 | class CLAP(nn.Module): 20 | def __init__(self, cfg, echo): 21 | super(CLAP, self).__init__() 22 | self.cfg = cfg 23 | self.echo = echo 24 | 25 | def forward(self, audios, text, *args, **kwargs): 26 | if kwargs.get("retrieval", False): # if it is a retrieval task 27 | return self.forward_retrieval(audios, text, *args, **kwargs) 28 | device_ids = kwargs.get("device_ids", [0]) 29 | # how to asynchronize the two `data_parallel` 30 | kwargs = {"normalized": False, "require_feature": True, "names": kwargs.get("names", None)} 31 | _, audio_features = data_parallel( 32 | self.audio_head, audios, device_ids=device_ids, module_kwargs=kwargs 33 | ) 34 | time_first = True #self.audio_head.time_first 35 | text_input = (text, audio_features, time_first) 36 | _, logits, predictions = data_parallel( 37 | self.text_head, text_input, device_ids=device_ids, module_kwargs=kwargs 38 | ) 39 | loss = self.loss_head(logits, text[:, 1:], predictions, **kwargs) 40 | return loss 41 | 42 | def forward_retrieval(self, audios, text, *args, **kwargs): 43 | device_ids = kwargs.get("device_ids", [0]) 44 | # how to asynchronize the two `data_parallel` 45 | kwargs = {"normalized": self.loss_head.normalized, "names": kwargs.get("names", None)} 46 | audio_features = data_parallel( 47 | self.audio_head, audios, device_ids=device_ids, module_kwargs=kwargs 48 | ) 49 | text_features = data_parallel( 50 | self.text_head, text, device_ids=device_ids, module_kwargs=kwargs 51 | ) 52 | loss = self.loss_head(audio_features, text_features, **kwargs) 53 | return loss 54 | 55 | def encode_text(self, text, *args, device_ids=[0], **kwargs): 56 | text_features = data_parallel( 57 | self.text_head, text, device_ids=device_ids, module_kwargs=kwargs 58 | ) 59 | return text_features 60 | 61 | def collect_audio_state_dict(self): 62 | return ( 63 | self.audio_head.state_dict(), 64 | self.loss_head.state_dict(), 65 | ) 66 | 67 | def collect_state_dict(self): 68 | return ( 69 | self.audio_head.state_dict(), 70 | self.text_head.state_dict(), 71 | self.loss_head.state_dict(), 72 | ) 73 | 74 | def report(self, gold_file=None): 75 | if not dist.is_initialized() or dist.get_rank() == 0: 76 | return self.loss_head.report(gold_file=gold_file) 77 | else: 78 | return "" 79 | 80 | def build(self, **kwargs): 81 | tunable_params = dict() 82 | if self.cfg.eval: 83 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 84 | from_scratch, image_head_sd, text_head_sd, _ = load_clip(None, self.cfg, self.echo) 85 | 86 | self.audio_head = build_audio_head(self.cfg.model.audio) 87 | if audio_head_sd is not None: 88 | n_o, o_n = self.audio_head.from_pretrained(audio_head_sd, local_cfg) 89 | msg = f" except {n_o}" if len(n_o) > 0 else "" 90 | self.echo(f"Initialize audio encoder from `audio_head`{msg}.") 91 | else: 92 | self.audio_head.copy_state_dict(image_head_sd) 93 | self.echo("Initialize audio encoder from `image_head`.") 94 | 95 | self.text_head = build_text_head(self.cfg.model.text) # 96 | n_o, o_n = self.text_head.copy_state_dict(text_head_sd) 97 | msg = f" except {n_o}" if len(n_o) > 0 else "" 98 | self.echo(f"Initialize text encoder from `text_head`{msg}.") 99 | 100 | self.loss_head = build_loss_head(self.cfg.model.loss, **kwargs) 101 | if loss_head_sd is not None: 102 | self.loss_head.copy_state_dict(loss_head_sd) # 103 | 104 | self.cuda(self.cfg.rank) 105 | else: 106 | # try pre-trained model! 
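# load_checkpoint() yields a 5-tuple; only the local config, the audio-head state
# dict and the loss-head state dict are used here, and the two underscore slots
# are ignored. When no local checkpoint exists, local_cfg comes back None and the
# branches below fall back to the CLIP weights from load_clip(), which returns
# (from_scratch, image_head_sd, text_head_sd, model), or failing that to random
# initialization.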
107 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 108 | # try clip! TODO do we always have to load CLIP? 109 | from_scratch, image_head_sd, text_head_sd, model = load_clip(local_cfg, self.cfg, self.echo) 110 | # try meme! 111 | with_meme, meme_image_head_sd = load_meme(self.cfg, self.echo) 112 | 113 | #cfg = local_cfg if local_cfg is not None else self.cfg 114 | self.audio_head = build_audio_head(self.cfg.model.audio, **kwargs) 115 | if not self.cfg.model.audio.from_scratch: 116 | if local_cfg is not None: 117 | self.audio_head.from_pretrained(audio_head_sd, local_cfg) 118 | self.echo("Initialize audio encoder from `audio_head`.") 119 | elif not from_scratch: 120 | if with_meme: # higher priority 121 | msg = " `meme_image_head`" 122 | n_o, o_n = self.audio_head.copy_state_dict(meme_image_head_sd) 123 | else: 124 | msg = " `image_head`" 125 | n_o, o_n = self.audio_head.copy_state_dict(image_head_sd) 126 | msg += f" except {n_o}" if len(n_o) > 0 else "" 127 | self.echo(f"Initialize audio encoder from{msg}.") 128 | else: 129 | self.echo("Have to learn from scratch.") 130 | 131 | self.text_head = build_text_head(self.cfg.model.text) 132 | if not from_scratch: 133 | self.text_head.copy_state_dict(text_head_sd) 134 | self.echo("Initialize text encoder from `text_head`.") 135 | 136 | self.loss_head = build_loss_head(self.cfg.model.loss) 137 | if not from_scratch and not self.cfg.model.audio.from_scratch: 138 | extra_sd = {"logit_scale": model.logit_scale} 139 | self.loss_head.copy_state_dict(extra_sd) 140 | 141 | tunable_params.update({ 142 | f"loss_head.{k}": v for k, v in self.loss_head.named_parameters() 143 | }) 144 | if not self.cfg.model.audio.freeze: 145 | tunable_params.update({ 146 | f"audio_head.{k}": v for k, v in self.audio_head.named_parameters() 147 | }) 148 | elif self.audio_head is not None: 149 | self.echo("Freeze audio encoder.") 150 | if not self.cfg.model.text.freeze: 151 | tunable_params.update({ 152 | f"text_head.{k}": v for k, v in self.text_head.named_parameters() 153 | }) 154 | elif self.text_head is not None: 155 | self.echo("Freeze text encoder.") 156 | self.cuda(self.cfg.rank) 157 | return tunable_params 158 | 159 | -------------------------------------------------------------------------------- /cvap/model/siamese_va.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | import warnings 4 | from typing import Union, List 5 | from collections import defaultdict, OrderedDict 6 | 7 | import copy 8 | import time 9 | import torch 10 | import numpy as np 11 | from torch import nn 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | import torch.distributed as dist 16 | from torch.nn.parallel import data_parallel 17 | from torch.nn.parallel import DistributedDataParallel 18 | 19 | from clip import load 20 | 21 | from .cvalp import CVALP 22 | from ..module import ( 23 | build_image_head, build_audio_head, build_text_head, build_loss_head 24 | ) 25 | from . 
import ( 26 | load_checkpoint, load_clip, load_meme 27 | ) 28 | 29 | class CVASP(CVALP): 30 | def __init__(self, cfg, echo): 31 | super(CVASP, self).__init__(cfg, echo) 32 | 33 | def forward( 34 | self, images, images_v1, audios_v1, 35 | text_v1=None, images_v2=None, audios_v2=None, *args, **kwargs 36 | ): 37 | device_ids = kwargs.get("device_ids", [0]) 38 | # how to asynchronize the two `data_parallel` 39 | kwargs = {"normalized": self.loss_head.normalized, "names": kwargs.get("names", None)} 40 | image_features = image_features_v1 = image_features_v2 = audio_features_v1 = text_features = None 41 | if images is not None: # pre-computed unnormalized features 42 | dummy_image = list(images.shape[1:]) == [1, 1, 1] 43 | if self.loss_head.normalized and not dummy_image: 44 | images = images / images.norm(dim=-1, keepdim=True) 45 | image_features = images # dummy images will be ignored 46 | if images_v1 is not None and self.image_head is not None: 47 | image_features_v1 = data_parallel( 48 | self.image_head, images_v1, device_ids=device_ids, module_kwargs=kwargs 49 | ) 50 | if images_v2 is not None and self.image_head is not None: 51 | image_features_v2 = data_parallel( 52 | self.image_head, images_v2, device_ids=device_ids, module_kwargs=kwargs 53 | ) 54 | if audios_v1 is not None and self.audio_head is not None: 55 | audio_features_v1 = data_parallel( 56 | self.audio_head, audios_v1, device_ids=device_ids, module_kwargs=kwargs 57 | ) 58 | if text_v1 is not None and self.text_head is not None: 59 | text_features = data_parallel( 60 | self.text_head, text_v1, device_ids=device_ids, module_kwargs=kwargs 61 | ) 62 | loss = self.loss_head( 63 | image_features, image_features_v1, audio_features_v1, 64 | images_v2=image_features_v2, audios_v2=image_features_v2, **kwargs 65 | ) 66 | return loss 67 | 68 | def _build_siamese_backbone(self, **kwargs): 69 | # try pre-trained model! 
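        # The image encoder serves as the parameter reference for the siamese backbone:
        # modules named in `running.siamese.amodules` / `lmodules` are shared (by reference)
        # with the audio and text heads via `replace_modules`. Shared parameters stay tunable
        # on the image_head side even when the image encoder is otherwise frozen, and are
        # filtered out of the audio/text heads' own tunable parameters.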
70 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 71 | from_scratch, image_head_sd, text_head_sd, _ = load_clip(None, self.cfg, self.echo) 72 | 73 | # image_head's parameters as the reference 74 | self.image_head = build_image_head(self.cfg.model.image) 75 | if not from_scratch and not self.cfg.model.image.from_scratch: 76 | n_o, o_n = self.image_head.copy_state_dict(image_head_sd) 77 | msg = f" except {n_o}" if len(n_o) > 0 else "" 78 | self.echo(f"Initialize image encoder from `image_head`{msg}.") 79 | if False and (self.cfg.running.frame_emb is not None or not self.cfg.running.imagine): 80 | self.image_head = None 81 | self.echo("Destory image encoder.") 82 | scfg = self.cfg.running.siamese 83 | 84 | # shared modules with audio_head 85 | amodules = set(scfg.amodules) 86 | kwargs = { 87 | "shared_modules": amodules, "reference": self.image_head, "keep_hp": scfg.keep_hp 88 | } 89 | self.audio_head = build_audio_head(self.cfg.model.audio, **kwargs) 90 | if not self.cfg.model.audio.from_scratch: 91 | if local_cfg is not None: 92 | n_o, o_n = self.audio_head.from_pretrained(audio_head_sd, local_cfg) 93 | msg = f" except {n_o}" if len(n_o) > 0 else "" 94 | self.echo(f"Initialize audio encoder from `audio_head`{msg}.") 95 | elif not from_scratch: 96 | n_o, o_n = self.audio_head.copy_state_dict(image_head_sd) 97 | msg = f" except {n_o}" if len(n_o) > 0 else "" 98 | self.echo(f"Initialize audio encoder from `image_head`{msg}.") 99 | else: 100 | self.echo("Have to learn from scratch.") 101 | ref_modules = self.audio_head.replace_modules(**kwargs) 102 | self.echo(f"A: audio_head.modules referring to image_head.modules: {ref_modules}.") 103 | 104 | # shared modules with text_head 105 | lmodules = set(scfg.lmodules) 106 | kwargs.update({"shared_modules": lmodules}) 107 | self.text_head = build_text_head(self.cfg.model.text, **kwargs) 108 | if not from_scratch and not self.cfg.model.text.from_scratch: 109 | if self.cfg.model.text.from_text: 110 | n_o, o_n = self.text_head.copy_state_dict(text_head_sd) 111 | msg = f" except {n_o}" if len(n_o) > 0 else "" 112 | self.echo(f"Initialize text encoder from `text_head`{msg}.") 113 | else: 114 | n_o, o_n = self.text_head.copy_state_dict(image_head_sd) 115 | msg = f" except {n_o}" if len(n_o) > 0 else "" 116 | self.echo(f"Initialize text encoder from `image_head`{msg}.") 117 | ref_modules = self.text_head.replace_modules(**kwargs) 118 | self.echo(f"T: text_head.modules referring to image_head.modules: {ref_modules}.") 119 | if len(self.text_head.state_dict()) == 0: 120 | self.text_head = None 121 | self.echo("Destory text encoder.") 122 | 123 | self.loss_head = build_loss_head(self.cfg.model.loss) 124 | 125 | tunable_params = { 126 | f"loss_head.{k}": v for k, v in self.loss_head.named_parameters() 127 | } 128 | if not self.cfg.model.image.freeze and self.image_head is not None: 129 | tunable_params.update({ 130 | f"image_head.{k}": v for k, v in self.image_head.named_parameters() 131 | }) 132 | elif self.image_head is not None: 133 | shared_modules = amodules | lmodules 134 | pattern = "|".join([f"^{m}\." for m in shared_modules]) 135 | tunable_params.update({ 136 | f"image_head.{k}": v for k, v in self.image_head.named_parameters() 137 | if pattern != "" and re.match(pattern, k)}) # shared parameters must be tunable 138 | self.echo(f"Freeze image encoder (excl. shared modules: {shared_modules}).") 139 | if not self.cfg.model.audio.freeze: 140 | pattern = "|".join([f"^{m}\." 
for m in amodules]) 141 | tunable_params.update({ 142 | f"audio_head.{k}": v for k, v in self.audio_head.named_parameters() 143 | if pattern == "" or not re.match(pattern, k)}) # filter out shared parameters 144 | else: 145 | self.echo("Freeze audio encoder.") 146 | if not self.cfg.model.text.freeze and self.text_head is not None: 147 | pattern = "|".join([f"^{m}\." for m in lmodules]) 148 | tunable_params.update({ 149 | f"text_head.{k}": v for k, v in self.text_head.named_parameters() 150 | if pattern == "" or not re.match(pattern, k)}) # filter out shared parameters 151 | elif self.text_head is not None: 152 | self.echo("Freeze text encoder.") 153 | return tunable_params 154 | -------------------------------------------------------------------------------- /clip/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | [[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb) 4 | 5 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision. 6 | 7 | 8 | 9 | ## Approach 10 | 11 | ![CLIP](CLIP.png) 12 | 13 | 14 | 15 | ## Usage 16 | 17 | First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick: 18 | 19 | ```bash 20 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0 21 | $ pip install ftfy regex tqdm 22 | $ pip install git+https://github.com/openai/CLIP.git 23 | ``` 24 | 25 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU. 26 | 27 | ```python 28 | import torch 29 | import clip 30 | from PIL import Image 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | model, preprocess = clip.load("ViT-B/32", device=device) 34 | 35 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device) 36 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 37 | 38 | with torch.no_grad(): 39 | image_features = model.encode_image(image) 40 | text_features = model.encode_text(text) 41 | 42 | logits_per_image, logits_per_text = model(image, text) 43 | probs = logits_per_image.softmax(dim=-1).cpu().numpy() 44 | 45 | print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]] 46 | ``` 47 | 48 | 49 | ## API 50 | 51 | The CLIP module `clip` provides the following methods: 52 | 53 | #### `clip.available_models()` 54 | 55 | Returns the names of the available CLIP models. 56 | 57 | #### `clip.load(name, device=..., jit=True)` 58 | 59 | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint. 
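For example, a minimal sketch of loading from a local checkpoint (the path below is only a placeholder):

```python
import clip

# `name` may be a model name from `clip.available_models()` or a local checkpoint path
model, preprocess = clip.load("/path/to/ViT-B-32.pt", device="cpu", jit=False)
```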
60 | 61 | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded. 62 | 63 | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` 64 | 65 | Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model 66 | 67 | --- 68 | 69 | The model returned by `clip.load()` supports the following methods: 70 | 71 | #### `model.encode_image(image: Tensor)` 72 | 73 | Given a batch of images, returns the image features encoded by the vision portion of the CLIP model. 74 | 75 | #### `model.encode_text(text: Tensor)` 76 | 77 | Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model. 78 | 79 | #### `model(image: Tensor, text: Tensor)` 80 | 81 | Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100. 82 | 83 | 84 | 85 | ## More Examples 86 | 87 | ### Zero-Shot Prediction 88 | 89 | The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset. 90 | 91 | ```python 92 | import os 93 | import clip 94 | import torch 95 | from torchvision.datasets import CIFAR100 96 | 97 | # Load the model 98 | device = "cuda" if torch.cuda.is_available() else "cpu" 99 | model, preprocess = clip.load('ViT-B/32', device) 100 | 101 | # Download the dataset 102 | cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False) 103 | 104 | # Prepare the inputs 105 | image, class_id = cifar100[3637] 106 | image_input = preprocess(image).unsqueeze(0).to(device) 107 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) 108 | 109 | # Calculate features 110 | with torch.no_grad(): 111 | image_features = model.encode_image(image_input) 112 | text_features = model.encode_text(text_inputs) 113 | 114 | # Pick the top 5 most similar labels for the image 115 | image_features /= image_features.norm(dim=-1, keepdim=True) 116 | text_features /= text_features.norm(dim=-1, keepdim=True) 117 | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) 118 | values, indices = similarity[0].topk(5) 119 | 120 | # Print the result 121 | print("\nTop predictions:\n") 122 | for value, index in zip(values, indices): 123 | print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%") 124 | ``` 125 | 126 | The output will look like the following (the exact numbers may be slightly different depending on the compute device): 127 | 128 | ``` 129 | Top predictions: 130 | 131 | snake: 65.31% 132 | turtle: 12.29% 133 | sweet_pepper: 3.83% 134 | lizard: 1.88% 135 | crocodile: 1.75% 136 | ``` 137 | 138 | Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs. 139 | 140 | 141 | ### Linear-probe evaluation 142 | 143 | The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features. 
144 | 145 | ```python 146 | import os 147 | import clip 148 | import torch 149 | 150 | import numpy as np 151 | from sklearn.linear_model import LogisticRegression 152 | from torch.utils.data import DataLoader 153 | from torchvision.datasets import CIFAR100 154 | from tqdm import tqdm 155 | 156 | # Load the model 157 | device = "cuda" if torch.cuda.is_available() else "cpu" 158 | model, preprocess = clip.load('ViT-B/32', device) 159 | 160 | # Load the dataset 161 | root = os.path.expanduser("~/.cache") 162 | train = CIFAR100(root, download=True, train=True, transform=preprocess) 163 | test = CIFAR100(root, download=True, train=False, transform=preprocess) 164 | 165 | 166 | def get_features(dataset): 167 | all_features = [] 168 | all_labels = [] 169 | 170 | with torch.no_grad(): 171 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)): 172 | features = model.encode_image(images.to(device)) 173 | 174 | all_features.append(features) 175 | all_labels.append(labels) 176 | 177 | return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy() 178 | 179 | # Calculate the image features 180 | train_features, train_labels = get_features(train) 181 | test_features, test_labels = get_features(test) 182 | 183 | # Perform logistic regression 184 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) 185 | classifier.fit(train_features, train_labels) 186 | 187 | # Evaluate using the logistic regression classifier 188 | predictions = classifier.predict(test_features) 189 | accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100. 190 | print(f"Accuracy = {accuracy:.3f}") 191 | ``` 192 | 193 | Note that the `C` value should be determined via a hyperparameter sweep using a validation split. 194 | -------------------------------------------------------------------------------- /clip/model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: CLIP 2 | 3 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying information about the multimodal model. 4 | 5 | ## Model Details 6 | 7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within. 8 | 9 | ### Model Date 10 | 11 | January 2021 12 | 13 | ### Model Type 14 | 15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer. 16 | 17 | ### Model Version 18 | 19 | Initially, we’ve released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50. 
20 | 21 | As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. 22 | 23 | Please see the paper linked below for further details about their specification. 24 | 25 | ### Documents 26 | 27 | - [Blog Post](https://openai.com/blog/clip/) 28 | - [CLIP Paper](https://arxiv.org/abs/2103.00020) 29 | 30 | 31 | 32 | ## Model Use 33 | 34 | ### Intended Use 35 | 36 | The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis. 37 | 38 | #### Primary intended uses 39 | 40 | The primary intended users of these models are AI researchers. 41 | 42 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models. 43 | 44 | ### Out-of-Scope Use Cases 45 | 46 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task specific testing especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful. 47 | 48 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use. 49 | 50 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases. 51 | 52 | 53 | 54 | ## Data 55 | 56 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet which tend to skew towards more developed nations, and younger, male users. 57 | 58 | ### Data Mission Statement 59 | 60 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset. 
61 | 62 | 63 | 64 | ## Performance and Limitations 65 | 66 | ### Performance 67 | 68 | We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets such as OCR to texture recognition to fine-grained classification. The paper describes model performance on the following datasets: 69 | 70 | - Food101 71 | - CIFAR10 72 | - CIFAR100 73 | - Birdsnap 74 | - SUN397 75 | - Stanford Cars 76 | - FGVC Aircraft 77 | - VOC2007 78 | - DTD 79 | - Oxford-IIIT Pet dataset 80 | - Caltech101 81 | - Flowers102 82 | - MNIST 83 | - SVHN 84 | - IIIT5K 85 | - Hateful Memes 86 | - SST-2 87 | - UCF101 88 | - Kinetics700 89 | - Country211 90 | - CLEVR Counting 91 | - KITTI Distance 92 | - STL-10 93 | - RareAct 94 | - Flickr30 95 | - MSCOCO 96 | - ImageNet 97 | - ImageNet-A 98 | - ImageNet-R 99 | - ImageNet Sketch 100 | - ObjectNet (ImageNet Overlap) 101 | - Youtube-BB 102 | - ImageNet-Vid 103 | 104 | ## Limitations 105 | 106 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with respect to certain tasks such as fine grained classification and counting objects. CLIP also poses issues with regards to fairness and bias which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP also has an important limitation- in many cases we have used linear probes to evaluate the performance of CLIP and there is evidence suggesting that linear probes can underestimate model performance. 107 | 108 | ### Bias and Fairness 109 | 110 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper). 111 | 112 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks. 
113 | 114 | 115 | 116 | ## Feedback 117 | 118 | ### Where to send questions or comments about the model 119 | 120 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9) 121 | -------------------------------------------------------------------------------- /cvap/model/audioset_clf.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | import warnings 4 | from typing import Union, List 5 | from collections import defaultdict, OrderedDict 6 | 7 | import copy 8 | import time 9 | import torch 10 | import numpy as np 11 | from torch import nn 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | import torch.distributed as dist 16 | from torch.nn.parallel import data_parallel 17 | from torch.nn.parallel import DistributedDataParallel 18 | 19 | from clip import load 20 | 21 | from ..module import ( 22 | build_image_head, build_audio_head, build_text_head, build_loss_head 23 | ) 24 | from . import ( 25 | load_checkpoint, load_clip, load_meme 26 | ) 27 | 28 | class ASClassifier(nn.Module): 29 | def __init__(self, cfg, echo): 30 | super(ASClassifier, self).__init__() 31 | self.cfg = cfg 32 | self.echo = echo 33 | 34 | def forward(self, images, audios, labels, *args, **kwargs): 35 | device_ids = kwargs.get("device_ids", [0]) 36 | # how to asynchronize the two `data_parallel` 37 | kwargs = {"normalized": self.loss_head.normalized, "names": kwargs.get("names", None)} 38 | if self.image_head is not None and list(images.shape[1:]) != [1, 1, 1]: 39 | image_features = data_parallel( 40 | self.image_head, images, device_ids=device_ids, module_kwargs=kwargs 41 | ) 42 | else: # pre-computed unnormalized features 43 | if self.loss_head.normalized: 44 | images = images / images.norm(dim=-1, keepdim=True) 45 | image_features = images 46 | 47 | audio_features = data_parallel( 48 | self.audio_head, audios, device_ids=device_ids, module_kwargs=kwargs 49 | ) 50 | loss = self.loss_head(audio_features, labels, x3=image_features, **kwargs) 51 | return loss 52 | 53 | def encode_image(self, images, *args, device_ids=[0], **kwargs): 54 | image_features = data_parallel( 55 | self.image_head, images, device_ids=device_ids, module_kwargs=kwargs 56 | ) 57 | return image_features 58 | 59 | def encode_audio(self, audios, *args, device_ids=[0], **kwargs): 60 | audio_features = data_parallel( 61 | self.audio_head, audios, device_ids=device_ids, module_kwargs=kwargs 62 | ) 63 | return audio_features 64 | 65 | def encode_text(self, text, *args, device_ids=[0], **kwargs): 66 | text_features = data_parallel( 67 | self.text_head, text, device_ids=device_ids, module_kwargs=kwargs 68 | ) 69 | return text_features 70 | 71 | def collect_audio_state_dict(self): 72 | return ( 73 | self.audio_head.state_dict(), 74 | self.loss_head.state_dict(), 75 | ) 76 | 77 | def report(self, gold_file=None, **kwargs): 78 | if self.training: 79 | return self.loss_head.stats(**kwargs) if hasattr(self.loss_head, "stats") else "" 80 | if not dist.is_initialized() or dist.get_rank() == 0: 81 | return self.loss_head.report(gold_file=gold_file, **kwargs) 82 | else: 83 | return "" 84 | 85 | def build(self, **kwargs): 86 | tunable_params = dict() 87 | if self.cfg.eval: 88 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 89 | from_scratch, image_head_sd, text_head_sd, _ = load_clip(None, self.cfg, self.echo) 90 | 91 | self.image_head = build_image_head(self.cfg.model.image) 92 | if not from_scratch and 
self.cfg.running.imagine: 93 | self.image_head.copy_state_dict(image_head_sd) 94 | self.echo("Initialize image encoder from `image_head`.") 95 | else: 96 | self.image_head = None 97 | self.echo("Destory image encoder.") 98 | 99 | self.audio_head = build_audio_head(self.cfg.model.audio) 100 | if audio_head_sd is not None: 101 | n_o, o_n = self.audio_head.from_pretrained(audio_head_sd, local_cfg) 102 | msg = f" except {n_o}" if len(n_o) > 0 else "" 103 | self.echo(f"Initialize audio encoder from `audio_head`{msg}.") 104 | else: 105 | self.audio_head.copy_state_dict(image_head_sd) 106 | self.echo("Initialize audio encoder from `image_head`.") 107 | 108 | self.text_head = build_text_head(self.cfg.model.text) # 109 | n_o, o_n = self.text_head.copy_state_dict(text_head_sd) 110 | msg = f" except {n_o}" if len(n_o) > 0 else "" 111 | self.echo(f"Initialize text encoder from `text_head`{msg}.") 112 | 113 | self.loss_head = build_loss_head(self.cfg.model.loss, **kwargs) 114 | try: 115 | self.loss_head.load_state_dict(loss_head_sd) 116 | except Exception as e: 117 | self.echo(f"Failed to load `loss_head` (expected in zero-shot mode) because: {e}") 118 | 119 | self.cuda(self.cfg.rank) 120 | else: 121 | # try pre-trained model! 122 | local_cfg, _, audio_head_sd, _, loss_head_sd = load_checkpoint(self.cfg, self.echo) 123 | # try clip! TODO do we always have to load CLIP? 124 | from_scratch, image_head_sd, _, model = load_clip(None, self.cfg, self.echo) 125 | # try meme! 126 | with_meme, meme_image_head_sd = load_meme(self.cfg, self.echo) 127 | 128 | self.image_head = build_image_head(self.cfg.model.image) 129 | if not from_scratch and not self.cfg.model.image.from_scratch and image_head_sd is not None: 130 | self.image_head.copy_state_dict(image_head_sd) 131 | self.echo("Initialize image encoder from `image_head`.") 132 | if not self.cfg.running.imagine or self.cfg.running.frame_emb is not None: 133 | self.image_head = None 134 | self.echo("Destory image encoder.") 135 | 136 | self.audio_head = build_audio_head(self.cfg.model.audio) 137 | if not self.cfg.model.audio.from_scratch: 138 | if local_cfg is not None: 139 | # TODO better to use `from_pretrained()` 140 | self.audio_head.load_state_dict(audio_head_sd) 141 | self.echo("Initialize audio encoder from `audio_head`.") 142 | elif not from_scratch: 143 | if with_meme: # higher priority 144 | msg = " `meme_image_head`" 145 | n_o, o_n = self.audio_head.copy_state_dict(meme_image_head_sd) 146 | else: 147 | msg = " `image_head`" 148 | n_o, o_n = self.audio_head.copy_state_dict(image_head_sd) 149 | msg += f" except {n_o}" if len(n_o) > 0 else "" 150 | self.echo(f"Initialize audio encoder from{msg}.") 151 | else: 152 | self.echo("Have to learn from scratch.") 153 | 154 | self.loss_head = build_loss_head(self.cfg.model.loss, **kwargs) 155 | tunable_params = { 156 | f"loss_head.{k}": v for k, v in self.loss_head.named_parameters() 157 | } 158 | if not self.cfg.model.image.freeze and self.image_head is not None: 159 | tunable_params.update({ 160 | f"image_head.{k}": v for k, v in self.image_head.named_parameters() 161 | }) 162 | elif self.image_head is not None: 163 | self.echo("Freeze image encoder.") 164 | if not self.cfg.model.audio.freeze: 165 | excl_modules = set(self.cfg.running.excl_modules.amodules) 166 | pattern = "|".join([f"^{m}\." 
for m in excl_modules]) 167 | tunable_params.update({ 168 | f"audio_head.{k}": v for k, v in self.audio_head.named_parameters() 169 | if pattern == "" or not re.match(pattern, k)}) # filter out excluded parameters 170 | self.echo(f"Tune audio encoder (excl. {excl_modules}).") 171 | else: 172 | self.echo("Freeze audio encoder.") 173 | self.cuda(self.cfg.rank) 174 | return tunable_params 175 | -------------------------------------------------------------------------------- /cvap/data/image/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image, ImageOps, ImageFilter 7 | from torchvision.transforms import ( 8 | InterpolationMode, Compose, Resize, CenterCrop, ToTensor, Normalize 9 | ) 10 | 11 | def make_clip_image_transform(n_px): 12 | return Compose([ 13 | Resize(n_px, interpolation=InterpolationMode.BICUBIC), 14 | CenterCrop(n_px), 15 | lambda image: image.convert("RGB"), 16 | ToTensor(), 17 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 18 | ]) 19 | 20 | class GaussianBlur(object): 21 | def __init__(self, p): 22 | self.p = p 23 | 24 | def __call__(self, image): 25 | if random.random() < self.p: 26 | sigma = random.random() * 1.9 + 0.1 27 | return image.filter(ImageFilter.GaussianBlur(sigma)) 28 | else: 29 | return image 30 | 31 | class Solarization(object): 32 | def __init__(self, p): 33 | self.p = p 34 | 35 | def __call__(self, image): 36 | if random.random() < self.p: 37 | return ImageOps.solarize(image) 38 | else: 39 | return image 40 | 41 | class SharedImageTransform: 42 | def __init__(self, n_px): 43 | self.transform = transforms.Compose([ 44 | transforms.RandomResizedCrop( 45 | n_px, interpolation=InterpolationMode.BICUBIC 46 | ), 47 | transforms.RandomHorizontalFlip(p=0.5), 48 | transforms.RandomApply([ 49 | transforms.ColorJitter( 50 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 51 | ) 52 | ], p=0.8), 53 | transforms.RandomGrayscale(p=0.2), 54 | ]) 55 | 56 | def __call__(self, x): 57 | return self.transform(x) 58 | 59 | class SecretImageTransform: 60 | def __init__(self, p_g, p_s): 61 | self.transform = transforms.Compose([ 62 | GaussianBlur(p=p_g), 63 | Solarization(p=p_s), 64 | transforms.ToTensor(), 65 | transforms.Normalize( 66 | #mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 67 | mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] 68 | ) 69 | ]) 70 | def __call__(self, x): 71 | return self.transform(x) 72 | 73 | class AuthenticCLIPImageTransform: 74 | def __init__(self, n_px): 75 | self.transform = transforms.Compose([ 76 | transforms.Resize(n_px, interpolation=InterpolationMode.BICUBIC), 77 | transforms.CenterCrop(n_px), 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 80 | ]) 81 | self.transform_prime = transforms.Compose([ 82 | transforms.Resize(n_px, interpolation=InterpolationMode.BICUBIC), 83 | transforms.CenterCrop(n_px), 84 | transforms.ToTensor(), 85 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | self.transform_eval = self.transform_prime 88 | 89 | def __call__(self, x, both, train): 90 | x = x.convert("RGB") 91 | if not train: 92 | return self.transform_eval(x), np.array([[[1]]]) 93 | else: 94 | y1 = self.transform_prime(x) 95 | y2 = 
self.transform(x) if both else np.array([[[1]]]) 96 | return y1, y2 97 | 98 | class CLIPImageTransform: 99 | def __init__(self, n_px): 100 | self.transform = transforms.Compose([ 101 | transforms.Resize(n_px, interpolation=InterpolationMode.BICUBIC), 102 | transforms.CenterCrop(n_px), 103 | transforms.RandomHorizontalFlip(p=0.5), 104 | transforms.RandomApply([ 105 | transforms.ColorJitter( 106 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 107 | ) 108 | ], p=0.8), 109 | transforms.RandomGrayscale(p=0.2), 110 | GaussianBlur(p=1.0), 111 | Solarization(p=0.0), 112 | transforms.ToTensor(), 113 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 114 | ]) 115 | self.transform_prime = transforms.Compose([ 116 | transforms.Resize(n_px, interpolation=InterpolationMode.BICUBIC), 117 | transforms.CenterCrop(n_px), 118 | transforms.RandomHorizontalFlip(p=0.5), 119 | transforms.RandomApply([ 120 | transforms.ColorJitter( 121 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 122 | ) 123 | ], p=0.8), 124 | transforms.RandomGrayscale(p=0.2), 125 | GaussianBlur(p=0.1), 126 | Solarization(p=0.2), 127 | transforms.ToTensor(), 128 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 129 | ]) 130 | self.transform_eval = transforms.Compose([ 131 | transforms.Resize(n_px, interpolation=InterpolationMode.BICUBIC), 132 | transforms.CenterCrop(n_px), 133 | transforms.ToTensor(), 134 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 135 | ]) 136 | 137 | def __call__(self, x, both, train): 138 | x = x.convert("RGB") 139 | if not train: 140 | return self.transform_eval(x), np.array([[[1]]]) 141 | else: 142 | y1 = self.transform_prime(x) 143 | y2 = self.transform(x) if both else np.array([[[1]]]) 144 | return y1, y2 145 | 146 | class BarlowImageTransform: 147 | def __init__(self, n_px): 148 | self.transform = transforms.Compose([ 149 | transforms.RandomResizedCrop( 150 | n_px, interpolation=InterpolationMode.BICUBIC 151 | ), 152 | transforms.RandomHorizontalFlip(p=0.5), 153 | transforms.RandomApply([ 154 | transforms.ColorJitter( 155 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 156 | ) 157 | ], p=0.8), 158 | transforms.RandomGrayscale(p=0.2), 159 | GaussianBlur(p=1.0), 160 | Solarization(p=0.0), 161 | transforms.ToTensor(), 162 | transforms.Normalize( 163 | #mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 164 | mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] 165 | ) 166 | ]) 167 | self.transform_prime = transforms.Compose([ 168 | transforms.RandomResizedCrop( 169 | n_px, interpolation=InterpolationMode.BICUBIC 170 | ), 171 | transforms.RandomHorizontalFlip(p=0.5), 172 | transforms.RandomApply([ 173 | transforms.ColorJitter( 174 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 175 | ) 176 | ], p=0.8), 177 | transforms.RandomGrayscale(p=0.2), 178 | GaussianBlur(p=0.1), 179 | Solarization(p=0.2), 180 | transforms.ToTensor(), 181 | transforms.Normalize( 182 | #mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 183 | mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] 184 | ) 185 | ]) 186 | self.transform_eval = transforms.Compose([ 187 | transforms.Resize(n_px, interpolation=InterpolationMode.BICUBIC), 188 | transforms.CenterCrop(n_px), 189 | transforms.ToTensor(), 190 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 191 | ]) 
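        # The two training pipelines above follow the Barlow Twins two-view recipe:
        # `transform` applies GaussianBlur with p=1.0 and no solarization, while
        # `transform_prime` uses GaussianBlur p=0.1 and Solarization p=0.2, so the two
        # views of the same image receive asymmetric augmentations; `transform_eval`
        # is the deterministic CLIP-style resize / center-crop pipeline used at test time.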
192 | 193 | def __call__(self, x, both, train): 194 | x = x.convert("RGB") 195 | if not train: 196 | return self.transform_eval(x), np.array([[[1]]]) 197 | else: 198 | y1 = self.transform_prime(x) 199 | y2 = self.transform(x) if both else np.array([[[1]]]) 200 | return y1, y2 201 | -------------------------------------------------------------------------------- /cvap/data/audioset_clf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import json 5 | import torch 6 | import itertools 7 | import torchaudio 8 | import numpy as np 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | from itertools import cycle, islice, chain 12 | from einops import rearrange, repeat 13 | from collections import defaultdict 14 | from tabulate import tabulate 15 | from termcolor import colored 16 | 17 | import multiprocessing as mp 18 | import torch.utils.data as data 19 | import torch.nn.functional as F 20 | 21 | from .audio import ( 22 | make_transform, _extract_kaldi_spectrogram 23 | ) 24 | from .audioset_cls import print_label_dist, AudiosetNpz, AudiosetSrc 25 | from clip import tokenize 26 | 27 | class AudiosetDatasetNpz(data.Dataset): 28 | """ `__getitem__' loads .npz from disk. 29 | """ 30 | def __init__(self, cfg, data_name, train, label_map, weighted): 31 | data_path = f"{cfg.data_root}/{data_name}.csv" 32 | assert os.path.isfile(data_path), f"{data_path} is not a file." 33 | self.label_map = label_map 34 | self.num_label = len(label_map) 35 | label_counts = np.zeros(self.num_label) 36 | self.dataset = list() 37 | with open(data_path, "r") as fr: 38 | for iline, line in enumerate(fr): 39 | record = json.loads(line) 40 | self.dataset.append(record) 41 | if not train and iline + 1 == cfg.eval_samples: 42 | break 43 | if weighted: # save label distribution 44 | for category in record["labels"]: 45 | label_counts[ 46 | self.label_map[category][0] 47 | ] += 1 48 | self.length = len(self.dataset) 49 | if weighted: # compute sample weight 50 | lid2label = {v[0]: re.sub(f"^{cfg.prompt}", "", v[1]).strip() for _, v in label_map.items()} 51 | print_label_dist(cfg, print, label_counts, lid2label, ncol=18) 52 | self.sample_weights = np.zeros(self.length) 53 | label_counts = 1000.0 / (label_counts + 1.) 
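            # inverse-frequency weighting: rarer labels receive larger weights
            # (1000.0 is only a scale factor; the `+ 1.` guards against empty classes);
            # each sample's weight below is the sum of its labels' weights and is
            # consumed by the WeightedRandomSampler in the dataloader builder.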
54 | for i, record in enumerate(self.dataset): 55 | for category in record["labels"]: 56 | self.sample_weights[i] += label_counts[ 57 | self.label_map[category][0] 58 | ] 59 | self.audio_norms = cfg.audio.norms 60 | self.train = train 61 | self.cfg = cfg 62 | 63 | self.transform_audio, self.transform_fbank = make_transform(cfg.audio) 64 | 65 | def _shuffle(self): 66 | pass 67 | 68 | def __getitem__(self, index): 69 | name = self.dataset[index]["id"] 70 | aclip = self.dataset[index]["aclip"] 71 | frame = self.dataset[index]["frame"] 72 | categories = self.dataset[index]["labels"] 73 | 74 | aclip_file = f"{self.cfg.data_root}/{aclip}" 75 | frame_file = f"{self.cfg.data_root}/{frame}" 76 | 77 | images = np.load(frame_file) 78 | images = [images[key] for key in images.files if len(images[key]) != 0] 79 | assert len(images) != 0, f"no frame exist: |images| = {len(images)}" 80 | if self.train: 81 | idx = np.random.choice(len(images), 1)[0] 82 | ict = np.random.choice(len(categories), 1)[0] 83 | else: 84 | idx = int(np.ceil(len(images) / 2)) - 1 85 | ict = 0 # 1st label 86 | image = images[idx] 87 | 88 | max_audio_len = self.cfg.max_audio_len 89 | audio = np.load(aclip_file)["flag"] # (..., time, freq): `flag' is used as the key accidentally 90 | 91 | if self.cfg.audio.normalized: # normalize along feature dim 92 | audio /= np.max(np.abs(audio), axis=1)[:, None] 93 | 94 | if not self.cfg.audio.eval_norms and len(self.audio_norms) == 2: 95 | mean, std = self.audio_norms 96 | audio = (audio - mean) / std 97 | 98 | if not self.cfg.audio.eval_norms and self.train and self.transform_fbank is not None: 99 | audio = self.transform_fbank(audio) 100 | 101 | npad = max_audio_len - audio.shape[0] 102 | if npad > 0: 103 | audio = np.pad(audio, ((0, npad), (0, 0)), "constant", constant_values=(0., 0.)) 104 | 105 | image = image[None] 106 | audio = audio[None] 107 | 108 | if not self.cfg.clf: 109 | category = categories[ict] 110 | label, _, text_int = self.label_map[category] 111 | else: # classification task 112 | label_set = set([self.label_map[category][0] for category in categories]) 113 | label = [1 if i in label_set else 0 for i in range(self.num_label)] 114 | text_int = [0] # TODO concatenate all text pieces 115 | 116 | item = {"image": image, "audio": audio, "text": text_int, "label": label, "name": name} 117 | return item 118 | 119 | def __len__(self): 120 | return self.length 121 | 122 | class ImageAudioCollator: 123 | def __init__(self, device=torch.device("cpu")): 124 | # RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned 125 | # when pin_memory is true, the collator has to return CPU tensors 126 | self.device = device 127 | 128 | def __call__(self, records): 129 | union = { 130 | k: [record.get(k) for record in records] for k in set().union(*records) 131 | } 132 | name = union["name"] 133 | label = union["label"] 134 | text_list = union["text"] 135 | if isinstance(text_list[0][0], np.ndarray): # pre-computed text embeddings 136 | text = np.concatenate(text_list, axis=0) # (1, H) -> (b, H) 137 | else: 138 | if isinstance(text_list[0][0], int): # train / test clf 139 | label = np.array(label) 140 | elif isinstance(text_list[0][0], list): # test retrieval 141 | text_list = list(itertools.chain.from_iterable(text_list)) 142 | #name = list(itertools.chain.from_iterable(name)) 143 | #label = # list of label lists 144 | else: 145 | raise ValueError(f"unrecognized `{type(text_list[0][0])}`") 146 | # https://stackoverflow.com/a/38619333 147 | text = 
np.array(list(itertools.zip_longest(*text_list, fillvalue=0))).T 148 | return ( 149 | np.concatenate(union["image"], axis=0), 150 | np.concatenate(union["audio"], axis=0), 151 | text, label, name, 152 | ) 153 | 154 | def build_audioset_clf_dataloader(cfg, data_name, label_map, shuffle=True, train=True, filters=None): 155 | ddp_mode = torch.distributed.is_initialized() 156 | rcfg = cfg.running 157 | weighted = train and rcfg.weighted_sampling 158 | if data_name.startswith("src"): 159 | if not rcfg.force_npz: 160 | dataset = AudiosetSrc(rcfg, data_name, train, label_map, weighted, filters) 161 | else: 162 | dataset = AudiosetNpz(rcfg, data_name, train, label_map, weighted) 163 | elif data_name.startswith("npz"): 164 | dataset = AudiosetDatasetNpz(rcfg, data_name, train, label_map, weighted) 165 | else: 166 | dataset = AudiosetSrc(rcfg, data_name, train, label_map, weighted) 167 | #raise ValueError(f"unrecognized data file `{data_name}`.") 168 | if ddp_mode: 169 | assert cfg.optimizer.batch_size % cfg.num_gpus == 0 170 | sampler = torch.utils.data.distributed.DistributedSampler( 171 | dataset, shuffle=shuffle 172 | ) 173 | per_device_batch_size = cfg.optimizer.batch_size // cfg.num_gpus 174 | else: 175 | if not weighted: 176 | sampler = ( 177 | torch.utils.data.RandomSampler(dataset) if shuffle else 178 | torch.utils.data.SequentialSampler(dataset) 179 | ) 180 | else: 181 | sampler = torch.utils.data.WeightedRandomSampler( 182 | dataset.sample_weights, len(dataset), replacement=True 183 | ) 184 | per_device_batch_size = cfg.optimizer.batch_size 185 | dataloader = torch.utils.data.DataLoader( 186 | dataset, 187 | batch_size=per_device_batch_size, 188 | collate_fn=ImageAudioCollator(), 189 | num_workers=(0 if ddp_mode else cfg.num_proc), 190 | pin_memory=True, 191 | sampler=sampler, 192 | drop_last=(True if ddp_mode else False), 193 | ) 194 | return sampler, dataloader 195 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import InterpolationMode, Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | __all__ = ["available_models", "load", "tokenize", "_tokenizer"] 16 | _tokenizer = _Tokenizer() 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 21 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 22 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 23 | "ViT-B32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 24 | "ViT-B16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 25 | } 26 | 27 | 28 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 29 | 
os.makedirs(root, exist_ok=True) 30 | filename = os.path.basename(url) 31 | 32 | expected_sha256 = url.split("/")[-2] 33 | download_target = os.path.join(root, filename) 34 | 35 | if os.path.exists(download_target) and not os.path.isfile(download_target): 36 | raise RuntimeError(f"{download_target} exists and is not a regular file") 37 | 38 | if os.path.isfile(download_target): 39 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 40 | return download_target 41 | else: 42 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 43 | 44 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 45 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 46 | while True: 47 | buffer = source.read(8192) 48 | if not buffer: 49 | break 50 | 51 | output.write(buffer) 52 | loop.update(len(buffer)) 53 | 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 55 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 56 | 57 | return download_target 58 | 59 | 60 | def _transform(n_px): 61 | return Compose([ 62 | Resize(n_px, interpolation=InterpolationMode.BICUBIC), 63 | CenterCrop(n_px), 64 | lambda image: image.convert("RGB"), 65 | ToTensor(), 66 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 67 | ]) 68 | 69 | 70 | def available_models() -> List[str]: 71 | """Returns the names of available CLIP models""" 72 | return list(_MODELS.keys()) 73 | 74 | 75 | def load( 76 | name: str, 77 | root: str = os.path.expanduser("~/.cache/clip"), 78 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 79 | jit=True 80 | ): 81 | """Load a CLIP model 82 | 83 | Parameters 84 | ---------- 85 | name : str 86 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 87 | 88 | device : Union[str, torch.device] 89 | The device to put the loaded model 90 | 91 | jit : bool 92 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 93 | 94 | Returns 95 | ------- 96 | model : torch.nn.Module 97 | The CLIP model 98 | 99 | preprocess : Callable[[PIL.Image], torch.Tensor] 100 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 101 | """ 102 | if name in _MODELS: 103 | model_path = _download(_MODELS[name], root=root) 104 | elif os.path.isfile(name): 105 | model_path = name 106 | else: 107 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 108 | 109 | try: 110 | # loading JIT archive 111 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 112 | state_dict = None 113 | except RuntimeError: 114 | # loading saved state dict 115 | if jit: 116 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 117 | jit = False 118 | state_dict = torch.load(model_path, map_location="cpu") 119 | 120 | if not jit: 121 | model = build_model(state_dict or model.state_dict()).to(device) 122 | if str(device) == "cpu": 123 | model.float() 124 | return model, _transform(model.visual.input_resolution) 125 | 126 | # patch the device names 127 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 128 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 129 | 130 | def patch_device(module): 131 | graphs = [module.graph] if hasattr(module, "graph") else [] 132 | if hasattr(module, "forward1"): 133 | graphs.append(module.forward1.graph) 134 | 135 | for graph in graphs: 136 | for node in graph.findAllNodes("prim::Constant"): 137 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 138 | node.copyAttributes(device_node) 139 | 140 | model.apply(patch_device) 141 | #patch_device(model.encode_image) 142 | #patch_device(model.encode_text) 143 | 144 | # patch dtype to float32 on CPU 145 | if str(device) == "cpu": 146 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 147 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 148 | float_node = float_input.node() 149 | 150 | def patch_float(module): 151 | graphs = [module.graph] if hasattr(module, "graph") else [] 152 | if hasattr(module, "forward1"): 153 | graphs.append(module.forward1.graph) 154 | 155 | for graph in graphs: 156 | for node in graph.findAllNodes("aten::to"): 157 | inputs = list(node.inputs()) 158 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 159 | if inputs[i].node()["value"] == 5: 160 | inputs[i].node().copyAttributes(float_node) 161 | 162 | model.apply(patch_float) 163 | patch_float(model.encode_image) 164 | patch_float(model.encode_text) 165 | 166 | model.float() 167 | 168 | return model, _transform(model.input_resolution.item()) 169 | 170 | 171 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, as_list=False, truncate: bool = False) -> torch.LongTensor: 172 | """ 173 | Returns the tokenized representation of given input string(s) 174 | 175 | Parameters 176 | ---------- 177 | texts : Union[str, List[str]] 178 | An input string or a list of input strings to tokenize 179 | 180 | context_length : int 181 | The context length to use; all CLIP models use 77 as the context length 182 | 183 | Returns 184 | ------- 185 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 186 | """ 187 | if isinstance(texts, str): 188 | texts = [texts] 189 | 190 | sot_token = _tokenizer.encoder["<|startoftext|>"] 191 | eot_token = _tokenizer.encoder["<|endoftext|>"] 192 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 193 | if as_list: 194 | return all_tokens 195 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 196 | 197 | for i, tokens in enumerate(all_tokens): 198 | if len(tokens) > context_length: 199 | if truncate: 200 | tokens = tokens[:context_length] 201 | tokens[-1] = eot_token 202 | else: 203 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 204 | result[i, :len(tokens)] = torch.tensor(tokens) 205 | 206 | return result 207 | -------------------------------------------------------------------------------- /cvap/monitor/siamese_va.py: 
-------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import os, re 3 | from collections import defaultdict 4 | 5 | import time 6 | import torch 7 | import numpy as np 8 | from torch import nn 9 | 10 | import torch.distributed as dist 11 | import torch.nn.functional as F 12 | from torch.nn.parallel import data_parallel 13 | from torch.nn.parallel import DistributedDataParallel 14 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 15 | 16 | from .cvap import Monitor 17 | from ..module import adjust_learning_rate 18 | 19 | class Monitor(Monitor): 20 | def __init__(self, cfg, echo, device): 21 | super(Monitor, self).__init__(cfg, echo, device) 22 | 23 | def make_batch(self, batch): 24 | def scale_images(images): 25 | if images.dim() != 2 and images.shape[-1] != self.cfg.running.resolution and \ 26 | list(images.shape[1:]) != [1, 1, 1]: 27 | images = F.interpolate( 28 | images, 29 | self.cfg.running.resolution, 30 | mode="bilinear", 31 | align_corners=False, 32 | ) 33 | return images 34 | #print(batch[0].shape, batch[1].shape, batch[2].shape, batch[3].shape, batch[4].shape) 35 | images = (torch.tensor(batch[0], device=self.device) 36 | if self.cfg.model.loss.vp or self.cfg.model.loss.ap else None 37 | ) 38 | images_v1 = scale_images( 39 | torch.tensor(batch[1], device=self.device) # (c, h, w) 40 | ) 41 | images_v2 = scale_images( 42 | torch.tensor(batch[2], device=self.device) # (c, h, w) 43 | ) if self.cfg.model.loss.vv else None 44 | audios_v1 = torch.tensor(batch[3], device=self.device) 45 | audios_v2 = (torch.tensor(batch[4], device=self.device) 46 | if self.cfg.model.loss.aa else None 47 | ) 48 | """ 49 | print( 50 | images.shape if images is not None else None, 51 | self.cfg.running.resolution, 52 | images_v1.shape, 53 | images_v2.shape if images_v2 is not None else None, 54 | audios_v1.shape, 55 | audios_v2.shape if audios_v2 is not None else None, 56 | ) 57 | """ 58 | #import sys; sys.exit(0) 59 | batch = ( 60 | images, images_v1, images_v2, audios_v1, audios_v2, batch[-1], # sample id or name 61 | ) 62 | return batch # bare tensors 63 | 64 | def epoch(self, iepoch): 65 | all_time = defaultdict(list) 66 | self.timeit(all_time) 67 | device_ids = [i for i in range(self.cfg.num_gpus)] 68 | nchunk = dist.get_world_size() if torch.distributed.is_initialized() else 1 69 | warmup_step_rate = max(self.cfg.optimizer.warmup_steps // 20, 1) 70 | for step, batch in enumerate(self.dataloader, start=iepoch * len(self.dataloader)): 71 | images, images_v1, images_v2, audios_v1, audios_v2, _ = self.make_batch(batch) 72 | self.timeit(all_time, key="data") 73 | 74 | if self.cfg.optimizer.use_lars: 75 | adjust_learning_rate(self.cfg.optimizer, self.optimizer, self.dataloader, step) 76 | 77 | inc = 0 78 | force_eval = False # recommended by SGDR 79 | warmup = not self.cfg.optimizer.use_lars and self.cfg.optimizer.warmup and \ 80 | (self.total_step + inc) <= self.cfg.optimizer.warmup_steps 81 | # it is important to always warm up lr at the first step otherwise 82 | # the optimizer will use the default / initial lr 83 | if warmup and ((self.total_step + inc) % warmup_step_rate == 0 or self.total_step == 0): 84 | ratio = ((self.total_step + inc) / self.cfg.optimizer.warmup_steps) # * self.cfg.optimizer.lr 85 | for param_group in self.optimizer.param_groups: 86 | param_group['lr'] = ratio * param_group["initial_lr"] 87 | lrs = [param_group['lr'] for param_group in self.optimizer.param_groups] 88 | force_eval = lrs == 
self.scheduler.base_lrs 89 | lrs = [f"{lr:.2e}" for lr in lrs] 90 | self.echo(f"warmup lr: {' '.join(lrs)} @ {self.total_step}") 91 | 92 | self.optimizer.zero_grad(set_to_none=True) 93 | with torch.cuda.amp.autocast(): 94 | loss = self.model( 95 | images, images_v1, audios_v1, device_ids=device_ids, 96 | text_v1=None, images_v2=images_v2, audios_v2=audios_v2 97 | ) 98 | self.scaler.scale(loss).backward() 99 | self.scaler.step(self.optimizer) 100 | self.scaler.update() 101 | 102 | if not self.cfg.optimizer.use_lars and self.cfg.optimizer.batch_sch and not warmup: 103 | old_lrs = " ".join([f"{x:.2e}" for x in self.scheduler.get_last_lr()]) 104 | self.scheduler.step() # after all warmup is completed 105 | if isinstance(self.scheduler, (CosineAnnealingWarmRestarts,)): 106 | force_eval = self.scheduler.get_last_lr() == self.scheduler.base_lrs 107 | #self.echo(f"do step lr {old_lrs}") 108 | 109 | self.timeit(all_time, key="model") 110 | 111 | if False and self.cfg.rank == 0: 112 | print(f"doing some check on unused params... {dist.get_world_size()}") 113 | for k, v in self.model.named_parameters(): 114 | if v.requires_grad and v.grad is None: 115 | print(f"--> {k}") 116 | 117 | self.total_step += 1 118 | self.total_loss += loss.detach() 119 | self.total_inst += images.shape[0] * nchunk if images is not None else audios_v1.shape[0] * nchunk 120 | if force_eval or (self.cfg.rank == 0 and self.total_step % self.cfg.running.peep_rate == 0): 121 | def grad_norm(): 122 | return sum( 123 | [p.grad.norm(p=2) ** 2 for p in self.params if p.grad is not None] 124 | ).item() ** 0.5 125 | lr_w = self.optimizer.param_groups[0]['lr'] 126 | lr_b = self.optimizer.param_groups[1]['lr'] 127 | msg = self.model.report(**{"nstep": self.total_step}) 128 | self.echo( 129 | f"epoch {iepoch:>4} step {self.total_step}\t" + #gnorm {grad_norm():.2f} " + 130 | f"lr_w {lr_w:.2e} lr_b {lr_b:.2e} loss {self.total_loss / self.total_step:.3f} " + 131 | f"{msg} {self.total_inst / (time.time() - self.start_time):.2f} samples/s" 132 | ) 133 | if force_eval or self.total_step % self.cfg.running.save_rate == 0 or ( 134 | self.cfg.running.save_epoch and self.total_step % len(self.dataloader) == 0 135 | ): # distributed eval 136 | report = "" 137 | if self.evalloader is not None: 138 | self.model.train(False) 139 | with torch.no_grad(): 140 | report = self.infer( 141 | self.evalloader, samples=self.cfg.running.eval_samples, iepoch=iepoch 142 | ) 143 | self.model.train(True) 144 | if report != "": 145 | self.echo(f"{report}") 146 | if self.cfg.rank == 0: 147 | self.save() 148 | self.timeit(all_time, key="report") 149 | 150 | if not self.cfg.optimizer.use_lars and not self.cfg.optimizer.batch_sch: 151 | self.scheduler.step() 152 | self.timeit(all_time, show=True) 153 | 154 | def infer(self, dataloader, samples=float("inf"), iepoch=0): 155 | losses, nsample, nchunk, nbatch = 0, 0, 1, len(dataloader) 156 | device_ids = [i for i in range(self.cfg.num_gpus)] 157 | if isinstance(self.model, DistributedDataParallel): 158 | dataloader.sampler.set_epoch(iepoch) 159 | nchunk = self.cfg.num_gpus 160 | peep_rate = max(10, (len(dataloader) // 10)) 161 | start_time = time.time() 162 | for ibatch, batch in enumerate(dataloader): 163 | if nsample >= samples: 164 | #print(f"{nsample}\t{ibatch}/{nbatch} continue") 165 | break #continue # iterate through every batch 166 | images, images_v1, _, audios_v1, _, names = self.make_batch(batch) 167 | #msg = f"{images[0, 0, 50, 50:55]} {audios[0, 0, 50, 50:55]}" # if ibatch == 0 else "" 168 | 
#print(f"{nsample}\t{ibatch}/{nbatch} done {msg}") 169 | loss = self.model(images, images_v1, audios_v1, device_ids=device_ids, names=names) 170 | nsample += images.shape[0] * nchunk if images is not None else audios_v1.shape[0] * nchunk 171 | losses += loss or 0. 172 | if self.cfg.rank == 0 and (ibatch + 1) % peep_rate == 0: 173 | self.echo( 174 | f"step {ibatch}\t" + #gnorm {grad_norm():.2f} " + 175 | f"loss {losses / (ibatch + 1):.8f} " + 176 | f"{nsample / (time.time() - start_time):.2f} samples/s" 177 | ) 178 | model = self.model.module if isinstance(self.model, DistributedDataParallel) else self.model 179 | self.echo(f"# sample {nsample}; {nsample / (time.time() - start_time):.2f} samples/s") 180 | return model.report(gold_file=self.gold_file) 181 | -------------------------------------------------------------------------------- /cvap/data/audio_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | import json 5 | import torch 6 | import itertools 7 | import torchaudio 8 | import numpy as np 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | from itertools import cycle, islice, chain 12 | from einops import rearrange, repeat 13 | 14 | import multiprocessing as mp 15 | import torch.utils.data as data 16 | import torch.nn.functional as F 17 | 18 | from .audio import ( 19 | make_transform, _extract_kaldi_spectrogram 20 | ) 21 | from clip import tokenize 22 | 23 | class AudioTextDatasetSrc(data.Dataset): 24 | """ `__getitem__' loads raw file from disk. 25 | """ 26 | def __init__(self, cfg, data_list, train): 27 | self.dataset = list() 28 | for iline, record in enumerate(data_list): 29 | self.dataset.append(record) 30 | if not train and iline + 1 == cfg.eval_samples: 31 | break 32 | self.audio_norms = cfg.audio.norms 33 | self.length = len(self.dataset) 34 | self.train = train 35 | self.cfg = cfg 36 | 37 | self.aclip_key = "clip" if "clip" in self.dataset[0] else "aclip" 38 | acfg = cfg.audio 39 | self.transform_audio, self.transform_fbank = make_transform(acfg) 40 | self.kaldi_params = { 41 | "htk_compat": True, 42 | "use_energy": False, 43 | "window_type": 'hanning', 44 | "num_mel_bins": acfg.num_mel_bins, 45 | "dither": 0.0, 46 | "frame_shift": 10 47 | } 48 | 49 | def _shuffle(self): 50 | pass 51 | 52 | def _audio2numpy_cst(self, aclip_file): 53 | max_audio_len = self.cfg.max_audio_len 54 | audio = _extract_kaldi_spectrogram( 55 | aclip_file, 56 | self.kaldi_params, 57 | train=self.train, 58 | max_audio_len=max_audio_len, 59 | zero_mean_wf=self.cfg.audio.zero_mean_wf, 60 | transform_audio=( 61 | self.transform_audio if self.train and not self.cfg.audio.eval_norms else None 62 | ) 63 | ) # (..., time, freq) 64 | 65 | npad = max_audio_len - audio.shape[0] 66 | if npad > 0: 67 | audio = np.pad(audio, ((0, npad), (0, 0)), "constant", constant_values=(0., 0.)) 68 | return audio 69 | 70 | def __getitem__(self, index): 71 | akey = self.aclip_key 72 | name = self.dataset[index]["id"] 73 | sub_dir = self.dataset[index]["dir"] 74 | label_str = self.dataset[index]["label_str"] 75 | label_int = self.dataset[index]["label_int_bpe"] 76 | aclip = self.dataset[index][akey][0] 77 | 78 | sub_dir = "" if len(sub_dir) == 0 else f"{sub_dir}/" 79 | aclip = aclip if aclip == name else f"{akey}/{name}.{aclip}" 80 | aclip_file = f"{self.cfg.data_root}/{sub_dir}{aclip}" 81 | 82 | audio = self._audio2numpy_cst(aclip_file) 83 | 84 | if not self.cfg.audio.eval_norms and len(self.audio_norms) == 2: 85 | mean, std = 
self.audio_norms 86 | audio = (audio - mean) / std 87 | 88 | #if self.train and self.transform_fbank is not None: 89 | if not self.cfg.audio.eval_norms and self.train and self.transform_fbank is not None: 90 | audio = self.transform_fbank(audio) 91 | 92 | if self.train: 93 | idx = np.random.choice(len(label_int), 1)[0] 94 | text = label_int[idx] 95 | else: 96 | text = label_int 97 | 98 | audio = audio[None] 99 | item = {"audio": audio, "text": text, "name": name} 100 | return item 101 | 102 | def __len__(self): 103 | return self.length 104 | 105 | class AudioTextCollator: 106 | def __init__(self, device=torch.device("cpu")): 107 | # RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned 108 | # when pin_memory is true, the collator has to return CPU tensors 109 | self.device = device 110 | 111 | def __call__(self, records): 112 | union = { 113 | k: [record.get(k) for record in records] for k in set().union(*records) 114 | } 115 | name = union["name"] 116 | text_list = union["text"] 117 | if isinstance(text_list[0][0], int): # train 118 | pass 119 | """ https://stackoverflow.com/a/43149308 120 | lengths = [len(x) for x in text_list] 121 | max_len = max(lengths) 122 | text = np.zeros((len(text_list), max_len), int) 123 | mask = np.arange(max_len) < np.array(lengths)[:, None] 124 | text[mask] = np.concatenate(text_list) 125 | """ 126 | elif isinstance(text_list[0][0], list): # test 127 | text_list = list(itertools.chain.from_iterable(text_list)) 128 | #name = list(itertools.chain.from_iterable(name)) 129 | else: 130 | raise ValueError(f"unrecognized `{type(text_list[0][0])}`") 131 | # https://stackoverflow.com/a/38619333 132 | text = np.array(list(itertools.zip_longest(*text_list, fillvalue=0))).T 133 | return ( 134 | np.concatenate(union["audio"], axis=0), 135 | text, 136 | name, 137 | ) 138 | 139 | def build_dataloader(cfg, data_list, dataset_cls, shuffle=True, train=True, collator_cls=AudioTextCollator): 140 | ddp_mode = torch.distributed.is_initialized() 141 | rcfg = cfg.running 142 | if isinstance(dataset_cls, str): 143 | dataset = eval(dataset_cls)(rcfg, data_list, train) 144 | else: 145 | dataset = dataset_cls(rcfg, data_list, train) 146 | if ddp_mode: 147 | assert cfg.optimizer.batch_size % cfg.num_gpus == 0 148 | sampler = torch.utils.data.distributed.DistributedSampler( 149 | dataset, shuffle=shuffle 150 | ) 151 | per_device_batch_size = cfg.optimizer.batch_size // cfg.num_gpus 152 | else: 153 | sampler = ( 154 | torch.utils.data.RandomSampler(dataset) if shuffle else 155 | torch.utils.data.SequentialSampler(dataset) 156 | ) 157 | per_device_batch_size = cfg.optimizer.batch_size 158 | dataloader = torch.utils.data.DataLoader( 159 | dataset, 160 | batch_size=per_device_batch_size, 161 | collate_fn=collator_cls(), 162 | num_workers=(0 if ddp_mode else cfg.num_proc), 163 | pin_memory=True, 164 | sampler=sampler, 165 | drop_last=(True if ddp_mode else False), 166 | ) 167 | return sampler, dataloader 168 | 169 | def build_clotho_data_list(cfg, data_name): 170 | fold = data_name.rsplit("_", 1)[-1] # {development, validation, evaluation} 171 | data_path = f"{cfg.data_root}/{data_name}.csv" 172 | assert os.path.isfile(data_path), f"{data_path} is not a file."
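    # each Clotho CSV row pairs one audio file (`file_name`) with five captions (`caption_1` ... `caption_5`), parsed below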
173 | prompt = cfg.prompt.strip() 174 | prompt = "" if len(prompt) == 0 else f"{prompt} " 175 | dataset = list() 176 | with open(data_path, "r") as fr: 177 | meta = csv.DictReader(fr) 178 | for i, row in enumerate(meta): 179 | filename = row["file_name"] 180 | captions = [prompt + row[f"caption_{icap}"] for icap in range(1, 6)] 181 | label_int_bpe = tokenize(captions, as_list=True) 182 | item = { 183 | "id": filename, 184 | "dir": fold, 185 | "aclip": [filename], 186 | "label_int_bpe": label_int_bpe, 187 | "label_int_w2v": [], 188 | "label_str": captions 189 | } 190 | dataset.append(item) 191 | if i > 10: 192 | pass #break 193 | #print(dataset) 194 | return dataset 195 | 196 | def build_audiocaps_data_list(cfg, data_name): 197 | data_path = f"{cfg.data_root}/{data_name}.csv" 198 | assert os.path.isfile(data_path), f"{data_path} is not a file." 199 | prompt = cfg.prompt.strip() 200 | prompt = "" if len(prompt) == 0 else f"{prompt} " 201 | dataset = list() 202 | with open(data_path, "r") as fr: 203 | for iline, line in enumerate(fr): 204 | record = json.loads(line) 205 | captions = [prompt + caption for caption in record["captions"]] 206 | record["label_int_w2v"] = [] 207 | record["label_int_bpe"] = tokenize( 208 | captions, as_list=True 209 | ) # add bpe captions 210 | record["label_str"] = captions 211 | dataset.append(record) 212 | if iline > 10: 213 | pass #break 214 | print(dataset[:2]) 215 | return dataset 216 | 217 | def build_dataloader_clotho(cfg, data_name, shuffle=True, train=True): 218 | name_list = data_name.split(",") 219 | dataset = list() 220 | for name in name_list: 221 | subset = build_clotho_data_list(cfg.running, name) 222 | dataset.extend(subset) 223 | return build_dataloader(cfg, dataset, AudioTextDatasetSrc, shuffle=shuffle, train=train) 224 | 225 | def build_dataloader_audiocaps(cfg, data_name, shuffle=True, train=True): 226 | name_list = data_name.split(",") 227 | dataset = list() 228 | for name in name_list: 229 | subset = build_audiocaps_data_list(cfg.running, name) 230 | dataset.extend(subset) 231 | return build_dataloader(cfg, dataset, AudioTextDatasetSrc, shuffle=shuffle, train=train) 232 | 233 | def build_audio_text_dataloader(cfg, data_name, *args, shuffle=True, train=True, **kwargs): 234 | if data_name.startswith("clotho"): 235 | return build_dataloader_clotho( 236 | cfg, data_name, shuffle=shuffle, train=train 237 | ) 238 | elif data_name.startswith("audiocaps"): 239 | #from .audioset import build_audioset_dataloader 240 | #return build_audioset_dataloader(cfg, data_name, dict(), shuffle=shuffle, train=train) 241 | return build_dataloader_audiocaps( 242 | cfg, data_name, shuffle=shuffle, train=train 243 | ) 244 | else: 245 | raise ValueError(f"unrecognized dataset `{data_name}`.") 246 | 247 | -------------------------------------------------------------------------------- /cvap/data/audiocaps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import json 5 | import torch 6 | import random 7 | import warnings 8 | import itertools 9 | import torchaudio 10 | import numpy as np 11 | from pathlib import Path 12 | from tqdm import tqdm 13 | from PIL import Image as PILImage 14 | from itertools import cycle, islice, chain 15 | from einops import rearrange, repeat 16 | from collections import defaultdict 17 | from tabulate import tabulate 18 | from termcolor import colored 19 | 20 | import multiprocessing as mp 21 | import torch.utils.data as data 22 | import torch.nn.functional as F 
23 | from torchvision.transforms import ( 24 | InterpolationMode, Compose, Resize, CenterCrop, ToTensor, Normalize 25 | ) 26 | 27 | from .audio import ( 28 | make_transform, _extract_kaldi_spectrogram 29 | ) 30 | from .image import make_clip_image_transform as make_image_transform 31 | from clip import tokenize 32 | 33 | class AudioCapDatasetSrc(data.Dataset): 34 | """ `__getitem__' loads raw file from disk. 35 | """ 36 | def __init__(self, cfg, data_name, train, label_map): 37 | data_path = f"{cfg.data_root}/{data_name}.csv" 38 | assert os.path.isfile(data_path), f"{data_path} is not a file." 39 | self.label_map = label_map 40 | self.num_label = len(label_map) 41 | label_counts = np.zeros(self.num_label) 42 | self.dataset = list() 43 | with open(data_path, "r") as fr: 44 | for iline, line in enumerate(fr): 45 | record = json.loads(line) 46 | record["captions_bpe"] = tokenize( 47 | record["captions"], as_list=True 48 | ) # add bpe captions 49 | self.dataset.append(record) 50 | if not train and iline + 1 == cfg.eval_samples: 51 | break 52 | if train and cfg.train_samples > 0. and cfg.train_samples < 1.: 53 | k = int(len(self.dataset) * cfg.train_samples) 54 | #self.dataset = np.random.choice(self.dataset, k, replace=False) 55 | shuffled_indice = np.random.permutation(np.random.permutation(len(self.dataset))) 56 | self.dataset = [self.dataset[i] for i in shuffled_indice[:k]] 57 | self.length = len(self.dataset) 58 | self.audio_norms = cfg.audio.norms 59 | self.aclip_key = "clip" if "clip" in self.dataset[0] else "aclip" 60 | self.frame_key = cfg.frame_key 61 | self.train = train 62 | self.cfg = cfg 63 | 64 | self.rnd_cap = getattr(cfg, "rnd_cap", False) # random AL fine-tuning baseline 65 | 66 | acfg = cfg.audio 67 | self.transform_image = make_image_transform(cfg.resolution) 68 | self.transform_audio, self.transform_fbank = make_transform(acfg) 69 | self.kaldi_params = { 70 | "htk_compat": True, 71 | "use_energy": False, 72 | "window_type": 'hanning', 73 | "num_mel_bins": acfg.num_mel_bins, 74 | "dither": 0.0, 75 | "frame_shift": 10 76 | } 77 | 78 | def _shuffle(self): 79 | pass 80 | 81 | def _process_item(self, index): 82 | akey = self.aclip_key 83 | fkey = self.frame_key 84 | sub_dir = self.dataset[index]["dir"] 85 | name = self.dataset[index]["id"] 86 | aclip = self.dataset[index][akey][0] 87 | frame = images = self.dataset[index][fkey] 88 | categories = self.dataset[index]["labels"] 89 | 90 | sub_dir = "" if len(sub_dir) == 0 else f"{sub_dir}/" 91 | aclip_file = f"{self.cfg.data_root}/{sub_dir}{akey}/{name}.{aclip}" 92 | 93 | frame_file = frame_emb_file = None 94 | if self.cfg.imagine: 95 | if isinstance(frame, str): 96 | frame_file = f"{self.cfg.data_root}/{sub_dir}{fkey}/{name}.{frame}" 97 | else: 98 | idx = np.random.choice(len(images), 1)[0] if self.train else int(np.ceil(len(images) / 2)) - 1 99 | frame_file = f"{self.cfg.data_root}/{sub_dir}{fkey}/{name}.{images[idx]}" 100 | if self.cfg.frame_emb is not None: 101 | frame_emb_file = f"{self.cfg.data_root}/{self.cfg.frame_emb}/{name}.{images[idx].rsplit('.', 1)[0]}.npz" 102 | 103 | label = [0] 104 | captions_bpe = self.dataset[index]["captions_bpe"] 105 | if self.train: 106 | if self.rnd_cap: # random baseline 107 | rnd_idx = np.random.randint(0, self.length) 108 | captions_bpe = self.dataset[rnd_idx]["captions_bpe"] 109 | icp = np.random.choice(len(captions_bpe), 1)[0] 110 | text_bpe = captions_bpe[icp] 111 | else: 112 | text_bpe = captions_bpe 113 | 114 | item = {"text": text_bpe, "name": name} 115 | return item, label, 
aclip_file, frame_file, frame_emb_file 116 | 117 | def _image2embed(self, fname): 118 | try: 119 | image = np.load(fname)["v"] 120 | except Exception as e: 121 | image = np.random.rand(self.cfg.embed_dim).astype("float32") 122 | warnings.warn(f"use random image instead because `{e}` {fname}.") 123 | return image 124 | 125 | def _image2numpy(self, fname): 126 | if fname is not None: 127 | try: 128 | if fname.endswith(".npz"): 129 | images = np.load(fname) 130 | images = [images[key] for key in images.files if len(images[key]) != 0] 131 | idx = np.random.choice(len(images), 1)[0] if self.train else int(np.ceil(len(images) / 2)) - 1 132 | image = images[idx] 133 | else: 134 | image = PILImage.open(fname) 135 | image = self.transform_image(image).cpu().numpy() 136 | except Exception as e: 137 | h = w = self.cfg.resolution 138 | image = PILImage.fromarray( 139 | (np.random.rand(h, w, 3) * 256).astype(np.uint8) 140 | ) 141 | warnings.warn(f"use random image instead because `{e}` {fname}.") 142 | image = self.transform_image(image).cpu().numpy() 143 | else: 144 | image = np.array([[[1]]]) 145 | return image 146 | 147 | def _audio2numpy_clf(self, aclip_file, label): 148 | wf, sr = torchaudio.load(aclip_file) 149 | wf = wf[:1] #wf.mean(0, keepdim=True) 150 | wf = wf - wf.mean() 151 | 152 | sampler = np.random if self.cfg.np_rnd else random 153 | 154 | #if self.train and sampler.random() < self.cfg.mixup_rate: 155 | if not self.cfg.audio.eval_norms and self.train and sampler.random() < self.cfg.mixup_rate: 156 | idx_mix = sampler.randint(0, self.length if self.cfg.np_rnd else self.length - 1) 157 | _, label_mix, aclip_file, _, _ = self._process_item(idx_mix) 158 | wf_mix, _ = torchaudio.load(aclip_file) 159 | wf_mix = wf_mix[:1] #wf_mix.mean(0, keepdim=True) 160 | wf_mix = wf_mix - wf_mix.mean() 161 | 162 | wf_len = wf.shape[1] 163 | wf_mix = wf_mix[:, :wf_len] 164 | npad = wf_len - wf_mix.shape[1] 165 | if npad > 0: 166 | wf_mix = F.pad(wf_mix, (0, npad), mode='constant', value=0.) 167 | 168 | lambd = np.random.beta(10, 10) # sample lambda from beta distribtion 169 | wf_mixed = lambd * wf + (1 - lambd) * wf_mix 170 | wf_mixed = wf_mixed - wf_mixed.mean() 171 | wf = wf_mixed 172 | 173 | label = lambd * np.array(label) + (1 - lambd) * np.array(label_mix) 174 | label = label.tolist() 175 | 176 | audio = torchaudio.compliance.kaldi.fbank( 177 | wf, 178 | sample_frequency=sr, 179 | **self.kaldi_params 180 | ) 181 | 182 | max_audio_len = self.cfg.max_audio_len 183 | audio = audio[:max_audio_len] 184 | npad = max_audio_len - audio.shape[0] 185 | if npad > 0: 186 | audio = F.pad(audio, (0, 0, 0, npad), mode='constant', value=0.) 
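        # kaldi fbank features are (time, num_mel_bins); the F.pad above zero-pads the time axis up to max_audio_len frames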
187 | return audio 188 | 189 | def _audio2numpy_cst(self, aclip_file): 190 | max_audio_len = self.cfg.max_audio_len 191 | audio = _extract_kaldi_spectrogram( 192 | aclip_file, 193 | self.kaldi_params, 194 | train=self.train, 195 | max_audio_len=max_audio_len, 196 | zero_mean_wf=self.cfg.audio.zero_mean_wf, 197 | transform_audio=( 198 | self.transform_audio if self.train and not self.cfg.audio.eval_norms else None 199 | ) 200 | ) # (..., time, freq) 201 | 202 | npad = max_audio_len - audio.shape[0] 203 | if npad > 0: 204 | audio = np.pad(audio, ((0, npad), (0, 0)), "constant", constant_values=(0., 0.)) 205 | return audio 206 | 207 | def __getitem__(self, index): 208 | item, label, aclip_file, frame_file, frame_emb_file = self._process_item(index) 209 | 210 | # higher priority for pre-computed frame embeddings 211 | image = (self._image2embed(frame_emb_file) 212 | if frame_emb_file is not None and self.cfg.imagine else self._image2numpy(frame_file) 213 | ) 214 | audio = (self._audio2numpy_clf(aclip_file, label) 215 | if self.cfg.clf else self._audio2numpy_cst(aclip_file) 216 | ) 217 | 218 | if not self.cfg.audio.eval_norms and len(self.audio_norms) == 2: 219 | mean, std = self.audio_norms 220 | audio = (audio - mean) / std 221 | 222 | #if self.train and self.transform_fbank is not None: 223 | if not self.cfg.audio.eval_norms and self.train and self.transform_fbank is not None: 224 | audio = self.transform_fbank(audio) 225 | 226 | image = image[None] 227 | audio = audio[None] 228 | item.update({"image": image, "audio": audio, "label": label}) 229 | return item 230 | 231 | def __len__(self): 232 | return self.length 233 | -------------------------------------------------------------------------------- /cvap/data/audio/transform.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import torchaudio 4 | import numpy as np 5 | from omegaconf.listconfig import ListConfig 6 | from omegaconf.dictconfig import DictConfig 7 | 8 | import torchvision.transforms as transforms 9 | from torchvision.transforms import Compose, ToTensor, Normalize 10 | from torchaudio.transforms import FrequencyMasking, TimeMasking 11 | 12 | def _extract_kaldi_spectrogram( 13 | filename, params, train=True, mean_channel=False, zero_mean_wf=False, max_audio_len=1000, transform_audio=None, tile_audio=False, 14 | ): 15 | waveform, sample_rate = torchaudio.load(filename) 16 | if mean_channel: # mean along channel # TODO else branch should take a specific channel 17 | waveform = waveform.mean(0, keepdim=True) 18 | desired_len = int((max_audio_len / 100) * sample_rate) 19 | if tile_audio and desired_len > waveform.shape[-1]: 20 | ntile = int(np.ceil(desired_len / waveform.shape[-1])) 21 | waveform = torch.tile(waveform, (1, ntile))[:desired_len] 22 | if transform_audio is not None: 23 | waveform = transform_audio(waveform) 24 | waveform = RandomCrop.random_crop( 25 | waveform, int((max_audio_len / 100 + 0.05) * sample_rate), train=train 26 | ) # divided by 100 because kaldi has a frame shift of 10, additional 0.05s 27 | if zero_mean_wf: # TODO should extract the 1st channel before the mean 28 | waveform = waveform - waveform.mean() 29 | fbank_feat = torchaudio.compliance.kaldi.fbank( 30 | waveform, 31 | sample_frequency=sample_rate, 32 | **params, 33 | ) 34 | fbank_feat = fbank_feat[:max_audio_len] 35 | return fbank_feat.numpy() 36 | 37 | def make_transform(cfg): 38 | transform_fbank = transform_audio = None 39 | if cfg.transform_audio: 40 | tfm_list = list() 41 | 
for name, params in cfg.audio_transforms: 42 | if isinstance(params, DictConfig): 43 | tfm_list.append(eval(name)(**params)) 44 | else: 45 | tfm_list.append(eval(name)(*params)) 46 | if len(tfm_list) > 0: 47 | transform_audio = Compose(tfm_list) 48 | if cfg.transform_fbank: 49 | tfm_list = list() 50 | for name, params in cfg.fbank_transforms: 51 | if isinstance(params, DictConfig): 52 | tfm_list.append(eval(name)(**params)) 53 | else: 54 | tfm_list.append(eval(name)(*params)) 55 | if len(tfm_list) > 0: 56 | tfm_list = [lambda x: x.T, ToTensorKeepdim()] + tfm_list + [lambda x: x.T] 57 | transform_fbank = Compose(tfm_list) 58 | #print(transform_audio, transform_fbank) 59 | return transform_audio, transform_fbank 60 | 61 | class ToTensorKeepdim(ToTensor): 62 | def __call__(self, x): 63 | if isinstance(x, torch.Tensor): 64 | return x 65 | x = super(ToTensorKeepdim, self).__call__(x[..., None]) 66 | return x.squeeze_(0) 67 | 68 | class AbstractTransform(abc.ABC): 69 | @abc.abstractmethod 70 | def __call__(self, x): 71 | pass 72 | def __repr__(self): 73 | return self.__class__.__name__ + '()' 74 | 75 | class RandomFlip(AbstractTransform): 76 | def __init__(self, p=0.5): 77 | super(RandomFlip, self).__init__() 78 | self.p = p 79 | 80 | @staticmethod 81 | def random_flip(x, p): 82 | if x.dim() > 2: 83 | flip_mask = torch.rand(x.shape[0], device=x.device) <= p 84 | x[flip_mask] = x[flip_mask].flip(-1) 85 | else: 86 | if torch.rand(1) <= p: 87 | x = x.flip(-1) 88 | return x 89 | 90 | def __call__(self, x): 91 | return self.random_flip(x, self.p) 92 | 93 | class RandomScale(AbstractTransform): 94 | def __init__(self, scale=1.5, keep_len=False): 95 | super(RandomScale, self).__init__() 96 | self.scale = scale 97 | self.keep_len = keep_len 98 | 99 | @staticmethod 100 | def random_scale(x, scale, keep_len): 101 | scaling = np.power(scale, np.random.uniform(-1, 1)) 102 | output_len = int(x.shape[-1] * scaling) 103 | base = torch.arange(output_len, device=x.device, dtype=x.dtype).div_(scaling) 104 | 105 | ref1 = base.clone().type(torch.int64) 106 | ref2 = torch.min(ref1 + 1, torch.full_like(ref1, x.shape[-1] - 1, dtype=torch.int64)) 107 | r = base - ref1.type(base.type()) 108 | scaled_x = (1 - r) * x[..., ref1] + r * x[..., ref2] 109 | if keep_len: 110 | scaled_x = RandomCrop.random_crop(scaled_x, x.shape[-1], True) # keep the same length 111 | return scaled_x 112 | 113 | def __call__(self, x): 114 | return self.random_scale(x, self.scale, self.keep_len) 115 | 116 | class RandomCrop(AbstractTransform): 117 | def __init__(self, output_len=44100, train=True): 118 | super(RandomCrop, self).__init__() 119 | self.output_len = output_len 120 | self.train = train 121 | 122 | @staticmethod 123 | def random_crop(x, output_len, train): 124 | if x.shape[-1] <= output_len: 125 | return x 126 | if train: 127 | left = np.random.randint(0, x.shape[-1] - output_len) 128 | else: # center 129 | left = int(round(0.5 * (x.shape[-1] - output_len))) 130 | 131 | old_std = x.float().std() * 0.5 132 | cropped_x = x[..., left : left + output_len] 133 | 134 | new_std = cropped_x.float().std() 135 | if new_std < old_std: 136 | cropped_x = x[..., : output_len] 137 | 138 | out_std = cropped_x.float().std() 139 | if old_std > new_std > out_std: 140 | cropped_x = x[..., -output_len:] 141 | return cropped_x 142 | 143 | def __call__(self, x): 144 | return self.random_crop(x, self.output_len, self.train) 145 | 146 | class RandomPad(AbstractTransform): 147 | def __init__(self, output_len=88200, train=True, padding_value=None): 148 | 
super(RandomPad, self).__init__() 149 | self.output_len = output_len 150 | self.train = train 151 | self.padding_value = padding_value 152 | 153 | @staticmethod 154 | def random_pad(x, output_len, train, padding_value=None): 155 | if x.shape[-1] >= output_len: 156 | return x 157 | if train: 158 | left = np.random.randint(0, output_len - x.shape[-1]) 159 | else: # center 160 | left = int(round(0.5 * (output_len - x.shape[-1]))) 161 | 162 | right = output_len - (left + x.shape[-1]) 163 | if padding_value is not None: 164 | pad_value_left = pad_value_right = padding_value 165 | else: # mean over channel? 166 | pad_value_left = x[..., 0].float().mean().to(x.dtype) 167 | pad_value_right = x[..., -1].float().mean().to(x.dtype) 168 | padded_x = torch.cat(( 169 | torch.zeros(x.shape[:-1] + (left,), dtype=x.dtype, device=x.device).fill_(pad_value_left), 170 | x, 171 | torch.zeros(x.shape[:-1] + (right,), dtype=x.dtype, device=x.device).fill_(pad_value_right) 172 | ), dim=-1) 173 | return padded_x 174 | 175 | def __call__(self, x): 176 | return self.random_pad(x, self.output_len, self.train, self.padding_value) 177 | 178 | class RandomNoise(AbstractTransform): 179 | def __init__(self, snr_min_db=10.0, snr_max_db=120.0, p=0.25): 180 | super(RandomNoise, self).__init__() 181 | self.snr_min_db = snr_min_db 182 | self.snr_max_db = snr_max_db 183 | self.p = p 184 | 185 | @staticmethod 186 | def random_noise(x, snr_min_db, snr_max_db, p): 187 | if np.random.rand() > p: 188 | return x 189 | target_snr = np.random.rand() * (snr_max_db - snr_min_db + 1.0) + snr_min_db 190 | 191 | x_watts = torch.mean(x ** 2, dim=(-1, -2)) 192 | x_db = 10 * torch.log10(x_watts) 193 | 194 | noise_db = x_db - target_snr 195 | noise_watts = 10 ** (noise_db / 10) + 1e-7 196 | noise = torch.normal(0.0, noise_watts.item() ** 0.5, x.shape) 197 | 198 | noise_x = x + noise 199 | return noise_x 200 | 201 | def __call__(self, x): 202 | return self.random_noise(x, self.snr_min_db, self.snr_max_db, self.p) 203 | 204 | class SimpleRandomNoise(AbstractTransform): 205 | def __init__(self, scale=10.0, shift=10, p=0.25): 206 | super(SimpleRandomNoise, self).__init__() 207 | self.scale = scale 208 | self.shift = shift 209 | self.p = p 210 | 211 | @staticmethod 212 | def random_noise(x, scale, shift, p): 213 | # expect a 2d tensor 214 | if np.random.rand() > p: 215 | return x 216 | noise_x = x + torch.rand(x.shape) * np.random.rand() / scale 217 | noise_x = torch.roll(noise_x, np.random.randint(-shift, shift), -1) 218 | return noise_x 219 | 220 | def __call__(self, x): 221 | return self.random_noise(x, self.scale, self.shift, self.p) 222 | 223 | class FbankTransform: 224 | def __init__(self): 225 | self.transform = transforms.Compose([ 226 | lambda x: x.T, 227 | transforms.ToTensor(), 228 | transforms.Normalize( 229 | mean=[-4.93839311], std=[5.75751113] 230 | ), 231 | FrequencyMasking(48), 232 | TimeMasking(300), 233 | lambda x: x.transpose(-1, -2) 234 | ]) 235 | self.transform_prime = transforms.Compose([ 236 | lambda x: x.T, 237 | transforms.ToTensor(), 238 | transforms.Normalize( 239 | mean=[-4.93839311], std=[5.75751113] 240 | ), 241 | FrequencyMasking(32), 242 | TimeMasking(200), 243 | lambda x: x.transpose(-1, -2) 244 | ]) 245 | self.transform_eval = transforms.Compose([ 246 | transforms.ToTensor(), 247 | transforms.Normalize( 248 | mean=[-4.93839311], std=[5.75751113] 249 | ), 250 | ]) 251 | 252 | def __call__(self, x, both, train): 253 | if not train: 254 | return self.transform_eval(x), np.array([[[1]]]) 255 | else: 256 | y1 = 
self.transform_prime(x) 257 | y2 = self.transform(x) if both else np.array([[[1]]]) 258 | return y1, y2 259 | -------------------------------------------------------------------------------- /cvap/module/encoder/audio_head.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | from fvcore.common.registry import Registry 4 | from omegaconf.listconfig import ListConfig 5 | from functools import partial 6 | 7 | import re 8 | import math 9 | import copy 10 | import warnings 11 | import threading 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn 16 | 17 | from timm.models.layers import to_2tuple 18 | from .. import ModifiedResNet, VisualTransformer, DistilledVisionTransformer, PatchEmbed 19 | 20 | AUDIO_HEADS_REGISTRY = Registry("AUDIO_HEADS") 21 | AUDIO_HEADS_REGISTRY.__doc__ = """ 22 | Registry for audio encoders. 23 | """ 24 | 25 | def build_audio_head(cfg, **kwargs): 26 | return AUDIO_HEADS_REGISTRY.get(cfg.name)(cfg, **kwargs) 27 | 28 | def position_resolution(input_resolution, patch_size, stride): 29 | input_resolution = list(to_2tuple(input_resolution)) 30 | patch_size = list(to_2tuple(patch_size)) 31 | 32 | stride = stride or patch_size 33 | if isinstance(stride, int): 34 | stride = [stride] * 2 35 | stride = list(stride) 36 | 37 | row_stride, col_stride = stride[:2] 38 | nrow = (input_resolution[0] - patch_size[0]) // row_stride + 1 39 | ncol = (input_resolution[1] - patch_size[1]) // col_stride + 1 40 | return nrow, ncol 41 | 42 | def interp_conv_weight(old_dict, new_dict, key): 43 | old_conv_weight = old_dict[key] 44 | new_conv_weight = new_dict[key] 45 | if new_conv_weight.shape[2:] != old_conv_weight.shape[2:]: 46 | old_conv_weight = F.interpolate( 47 | old_conv_weight, 48 | new_conv_weight.shape[2:], 49 | mode="bilinear", 50 | align_corners=False, 51 | ) 52 | return old_conv_weight 53 | 54 | def interp_pos_embedding(state_dict, old_dict, new_dict, key, bop, pos_resolution): 55 | """bop: start position of the postional embeddings""" 56 | add_leading_dim = False 57 | old_pos_emb = state_dict[key] 58 | if old_pos_emb.dim() == 3: # ensure of rank-2 tensor 59 | assert old_pos_emb.shape[0] == 1 60 | old_pos_emb = old_pos_emb.squeeze(0) 61 | add_leading_dim = True 62 | num_pos, pos_dim = old_pos_emb.shape[-2:] 63 | 64 | num_pos = int(np.sqrt(num_pos - bop)) 65 | ptensor = old_pos_emb[bop:].reshape( 66 | -1, num_pos, num_pos, pos_dim 67 | ).permute(0, 3, 1, 2) 68 | 69 | new_pos_emb = F.interpolate( 70 | ptensor, 71 | pos_resolution, 72 | mode="bilinear", 73 | align_corners=False, 74 | ).permute(0, 2, 3, 1).flatten(1, 2) 75 | new_pos_emb = torch.cat(( 76 | old_pos_emb[:bop], new_pos_emb.view(-1, pos_dim) 77 | ), dim=0) 78 | new_pos_emb = new_pos_emb.unsqueeze(0) if add_leading_dim else new_pos_emb 79 | old_dict[key] = new_pos_emb 80 | 81 | new_keys = set(new_dict.keys()) 82 | old_keys = set(old_dict.keys()) 83 | new_dict.update(old_dict) 84 | n_o = new_keys - old_keys 85 | o_n = old_keys - new_keys 86 | #print(f"{n_o}\n{o_n}") 87 | return n_o, o_n 88 | 89 | def load_pos_embedding( 90 | state_dict, old_dict, new_dict, key, bop, old_pos_shape, new_pos_shape, use_slice=True 91 | ): 92 | add_leading_dim = False 93 | old_pos_emb = state_dict[key] 94 | if old_pos_emb.dim() == 3: # ensure of rank-2 tensor 95 | assert old_pos_emb.shape[0] == 1 96 | old_pos_emb = old_pos_emb.squeeze(0) 97 | add_leading_dim = True 98 | num_pos, pos_dim = 
old_pos_emb.shape[-2:] 99 | num_pos_required = np.prod(new_pos_shape) 100 | 101 | if new_pos_shape == old_pos_shape: 102 | new_pos_emb = old_pos_emb # do nothing 103 | elif use_slice and new_pos_shape[-1] == old_pos_shape[-1] and num_pos_required + bop <= num_pos: 104 | extra = old_pos_shape[-2] - new_pos_shape[-2] 105 | if extra == 0: 106 | new_pos_emb = old_pos_emb[:num_pos_required + bop] # first k time steps 107 | else: 108 | start = 6 # [0, extra] 109 | start = start * old_pos_shape[-1] + bop 110 | new_pos_emb = torch.cat(( 111 | old_pos_emb[:bop], old_pos_emb[start : start + num_pos_required] 112 | ), 0) 113 | else: # interpolate 114 | shape = (-1,) + old_pos_shape + (pos_dim,) 115 | ptensor = old_pos_emb[bop:].reshape(shape).permute(0, 3, 1, 2) 116 | new_pos_emb = F.interpolate( 117 | ptensor, 118 | new_pos_shape, 119 | mode="bilinear", 120 | align_corners=False, 121 | ).permute(0, 2, 3, 1).flatten(1, 2) 122 | new_pos_emb = torch.cat(( 123 | old_pos_emb[:bop], new_pos_emb.view(-1, pos_dim) 124 | ), dim=0) 125 | new_pos_emb = new_pos_emb.unsqueeze(0) if add_leading_dim else new_pos_emb 126 | old_dict[key] = new_pos_emb 127 | 128 | new_keys = set(new_dict.keys()) 129 | old_keys = set(old_dict.keys()) 130 | new_dict.update(old_dict) 131 | n_o = new_keys - old_keys 132 | o_n = old_keys - new_keys 133 | #print(f"{n_o}\n{o_n}") 134 | return n_o, o_n 135 | 136 | @AUDIO_HEADS_REGISTRY.register() 137 | class NaiveCLIPAudioHead(nn.Module): 138 | def __init__(self, cfg, **kwargs): 139 | super().__init__() 140 | if isinstance(cfg.layers, (tuple, list, ListConfig)): 141 | heads = cfg.width * 32 // 64 142 | self.encoder = ModifiedResNet( 143 | in_channels=getattr(cfg, "in_channel", 1), 144 | input_resolution=cfg.resolution, 145 | output_dim=cfg.embed_dim, 146 | layers=cfg.layers, 147 | width=cfg.width, 148 | heads=heads, 149 | ) 150 | else: 151 | heads = cfg.width // 64 152 | self.encoder = VisualTransformer( 153 | in_channels=getattr(cfg, "in_channel", 1), 154 | stride=cfg.stride, 155 | input_resolution=cfg.resolution, 156 | output_dim=cfg.embed_dim, 157 | patch_size=cfg.patch_size, 158 | layers=cfg.layers, 159 | width=cfg.width, 160 | heads=heads, 161 | ) 162 | 163 | def from_pretrained(self, state_dict, cfg, *args, **kwargs): 164 | if (list(state_dict.keys())[0]).startswith("encoder."): 165 | audio_head_sd_new = OrderedDict() 166 | for k, v in state_dict.items(): 167 | k = re.sub("^encoder\.", "", k) 168 | audio_head_sd_new[k] = v 169 | state_dict = audio_head_sd_new 170 | excluded = ["positional_embedding", "attnpool.positional_embedding"] 171 | new_dict = self.encoder.state_dict() 172 | old_dict = {k: v for k, v in state_dict.items() if k not in excluded} 173 | # interpolate positional embedding 174 | key = ("attnpool.positional_embedding" 175 | if isinstance(self.encoder, ModifiedResNet) else "positional_embedding" 176 | ) 177 | new_pos_shape = self.encoder.position_resolution 178 | old_pos_shape = position_resolution( 179 | cfg.model.audio.resolution, cfg.model.audio.patch_size, cfg.model.audio.stride 180 | ) # nrow always indicates the time dimenstion 181 | n_o, o_n = load_pos_embedding( 182 | state_dict, old_dict, new_dict, key, 1, old_pos_shape, new_pos_shape 183 | ) 184 | self.encoder.load_state_dict(new_dict) 185 | return n_o, o_n 186 | 187 | def copy_state_dict(self, state_dict): 188 | excluded = ["conv1.weight", "positional_embedding", "attnpool.positional_embedding"] 189 | new_dict = self.encoder.state_dict() 190 | old_dict = {k: v for k, v in state_dict.items() if k not in 
excluded} 191 | # conv1: 3 channels -> 1 channel 192 | conv_key = "conv1.weight" 193 | old_conv_weight = interp_conv_weight(state_dict, new_dict, conv_key) 194 | old_dict[conv_key] = (old_conv_weight.mean(1, keepdim=True) 195 | if new_dict[conv_key].shape[1] != old_conv_weight.shape[1] else old_conv_weight 196 | ) 197 | # interpolate positional embedding 198 | key = ("attnpool.positional_embedding" 199 | if isinstance(self.encoder, ModifiedResNet) else "positional_embedding" 200 | ) 201 | n_o, o_n = interp_pos_embedding( 202 | state_dict, old_dict, new_dict, key, 1, self.encoder.position_resolution 203 | ) 204 | self.encoder.load_state_dict(new_dict) 205 | return n_o, o_n 206 | 207 | def forward(self, audios, *args, **kwargs): 208 | z = self.encoder(audios) 209 | if kwargs.get("normalized", False): 210 | z = z / z.norm(dim=-1, keepdim=True) 211 | #print(f"{threading.current_thread().ident} audio --{kwargs.get('normalized', False)}") 212 | return z 213 | 214 | @AUDIO_HEADS_REGISTRY.register() 215 | class NaiveDeiTAudioHead(nn.Module): 216 | def __init__(self, cfg, **kwargs): 217 | super().__init__() 218 | heads = cfg.width // 64 219 | self.encoder = DistilledVisionTransformer( 220 | img_size=cfg.resolution, 221 | # hack and has to be used with the customized `PatchEmbed` 222 | patch_size={"patch_size": cfg.patch_size, "stride": cfg.stride}, 223 | representation_size=False, 224 | output_dim=cfg.embed_dim, 225 | embed_dim=cfg.width, 226 | depth=cfg.layers, 227 | num_heads=heads, 228 | mlp_ratio=4, 229 | qkv_bias=True, 230 | in_chans=getattr(cfg, "in_channel", 1), 231 | num_classes=-1, 232 | embed_layer=PatchEmbed, 233 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 234 | **kwargs 235 | ) 236 | 237 | def from_pretrained(self, state_dict, cfg, *args, **kwargs): 238 | if (list(state_dict.keys())[0]).startswith("encoder."): 239 | audio_head_sd_new = OrderedDict() 240 | for k, v in state_dict.items(): 241 | k = re.sub("^encoder\.", "", k) 242 | audio_head_sd_new[k] = v 243 | state_dict = audio_head_sd_new 244 | excluded = ["pos_embed"] 245 | new_dict = self.encoder.state_dict() 246 | old_dict = {k: v for k, v in state_dict.items() if k not in excluded} 247 | # interpolate positional embedding 248 | key = "pos_embed" 249 | new_pos_shape = self.encoder.patch_embed.grid_size 250 | old_pos_shape = position_resolution( 251 | cfg.model.audio.resolution, cfg.model.audio.patch_size, cfg.model.audio.stride 252 | ) # nrow always indicates the time dimenstion 253 | n_o, o_n = load_pos_embedding( 254 | state_dict, old_dict, new_dict, key, 2, old_pos_shape, new_pos_shape 255 | ) 256 | self.encoder.load_state_dict(new_dict) 257 | return n_o, o_n 258 | 259 | def copy_state_dict(self, state_dict): 260 | excluded = ["patch_embed.proj.weight", "pos_embed"] 261 | new_dict = self.encoder.state_dict() 262 | old_dict = {k: v for k, v in state_dict.items() if k not in excluded and k in new_dict} 263 | # conv1: 3 channels -> 1 channel 264 | conv_key = "patch_embed.proj.weight" 265 | old_conv_weight = interp_conv_weight(state_dict, new_dict, conv_key) 266 | old_dict[conv_key] = (old_conv_weight.mean(1, keepdim=True) 267 | if new_dict[conv_key].shape[1] != old_conv_weight.shape[1] else old_conv_weight 268 | ) 269 | # interpolate positional embedding 270 | key = "pos_embed" 271 | n_o, o_n = interp_pos_embedding( 272 | state_dict, old_dict, new_dict, key, 2, self.encoder.patch_embed.grid_size 273 | ) 274 | self.encoder.load_state_dict(new_dict) 275 | return n_o, o_n 276 | 277 | def forward(self, audios, *args, **kwargs): 278 | 
cls_z, distilled_z = self.encoder.forward_features(audios) # DeiT class and distillation token embeddings 279 | z = (cls_z + distilled_z) / 2 # average the two token embeddings 280 | if kwargs.get("normalized", False): 281 | z = z / z.norm(dim=-1, keepdim=True) 282 | #print(f"{threading.current_thread().ident} audio --{kwargs.get('normalized', False)}") 283 | return z 284 | --------------------------------------------------------------------------------