├── MedleyMDPrompts ├── LICENSE ├── captions_sources.csv └── captions_targets.csv ├── README.md ├── code ├── LICENSE ├── LICENSE_StableAudioOpen ├── LICENSE_audioldm ├── __init__.py ├── audioldm │ ├── __init__.py │ ├── __main__.py │ ├── audio │ │ ├── __init__.py │ │ ├── audio_processing.py │ │ ├── stft.py │ │ └── tools.py │ ├── clap │ │ ├── __init__.py │ │ ├── encoders.py │ │ ├── open_clip │ │ │ ├── __init__.py │ │ │ ├── bert.py │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ ├── factory.py │ │ │ ├── feature_fusion.py │ │ │ ├── htsat.py │ │ │ ├── linear_probe.py │ │ │ ├── loss.py │ │ │ ├── model.py │ │ │ ├── model_configs │ │ │ │ ├── HTSAT-base.json │ │ │ │ ├── HTSAT-large.json │ │ │ │ ├── HTSAT-tiny-win-1536.json │ │ │ │ ├── HTSAT-tiny.json │ │ │ │ ├── PANN-10.json │ │ │ │ ├── PANN-14-fmax-18k.json │ │ │ │ ├── PANN-14-fmax-8k-20s.json │ │ │ │ ├── PANN-14-tiny-transformer.json │ │ │ │ ├── PANN-14-win-1536.json │ │ │ │ ├── PANN-14.json │ │ │ │ ├── PANN-6.json │ │ │ │ ├── RN101-quickgelu.json │ │ │ │ ├── RN101.json │ │ │ │ ├── RN50-quickgelu.json │ │ │ │ ├── RN50.json │ │ │ │ ├── RN50x16.json │ │ │ │ ├── RN50x4.json │ │ │ │ ├── ViT-B-16.json │ │ │ │ ├── ViT-B-32-quickgelu.json │ │ │ │ ├── ViT-B-32.json │ │ │ │ └── ViT-L-14.json │ │ │ ├── openai.py │ │ │ ├── pann_model.py │ │ │ ├── pretrained.py │ │ │ ├── timm_model.py │ │ │ ├── tokenizer.py │ │ │ ├── transform.py │ │ │ ├── utils.py │ │ │ └── version.py │ │ └── training │ │ │ ├── __init__.py │ │ │ ├── audioset_textmap.npy │ │ │ ├── data.py │ │ │ ├── distributed.py │ │ │ ├── imagenet_zeroshot_data.py │ │ │ ├── infer_demo.py │ │ │ ├── logger.py │ │ │ ├── lp_main.py │ │ │ ├── lp_train.py │ │ │ ├── main.py │ │ │ ├── params.py │ │ │ ├── scheduler.py │ │ │ ├── train.py │ │ │ └── zero_shot.py │ ├── hifigan │ │ ├── __init__.py │ │ ├── models.py │ │ └── utilities.py │ ├── latent_diffusion │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── ema.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── ldm.py │ ├── pipeline.py │ ├── utils.py │ └── variational_autoencoder │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ ├── distributions.py │ │ └── modules.py ├── ddm_inversion │ ├── ddim_inversion.py │ └── inversion_utils.py ├── images_pc_apply_drift.py ├── images_pc_extract_inv.py ├── images_run_sdedit.py ├── main_pc_apply_drift.py ├── main_pc_extract_inv.py ├── main_run.py ├── main_run_sdedit.py ├── models.py ├── pc_drift.py └── utils.py ├── docs ├── favicon.ico ├── functionality.js ├── index.html ├── resources │ ├── audio │ │ ├── ddim │ │ │ ├── CuteCat_10sec_ddim.mp3 │ │ │ ├── CuteCat_10sec_ddim_10.mp3 │ │ │ ├── CuteCat_10sec_ddim_25.mp3 │ │ │ ├── CuteCat_10sec_ddim_50.mp3 │ │ │ ├── DogBarking_10sec_ddim.mp3 │ │ │ ├── DogBarking_10sec_ddim_10.mp3 │ │ │ ├── DogBarking_10sec_ddim_25.mp3 │ │ │ ├── DogBarking_10sec_ddim_50.mp3 │ │ │ ├── MusicDelta_Beatles_MIX_ddim_10.mp3 │ │ │ ├── MusicDelta_Beatles_MIX_ddim_25.mp3 │ │ │ ├── MusicDelta_Beatles_MIX_ddim_50.mp3 │ │ │ ├── MusicDelta_Beethoven_MIX_ddim_10.mp3 │ │ │ ├── MusicDelta_Beethoven_MIX_ddim_25.mp3 │ │ │ ├── MusicDelta_Beethoven_MIX_ddim_50.mp3 │ │ │ ├── MusicDelta_ModalJazz_MIX_ddim_10.mp3 │ │ │ ├── MusicDelta_ModalJazz_MIX_ddim_25.mp3 │ │ │ ├── MusicDelta_ModalJazz_MIX_ddim_50.mp3 │ │ │ ├── Shouting_ddim.mp3 │ │ │ ├── Shouting_ddim_10.mp3 │ │ │ ├── Shouting_ddim_25.mp3 │ │ │ ├── Shouting_ddim_50.mp3 │ │ │ ├── arabic_MDDBBritpop_ddim.mp3 │ │ │ ├── arcade_MDDBBeatles_ddim.mp3 │ │ │ ├── arcade_MDDBBebopJazz_ddim.mp3 │ │ │ ├── arcade_MDDBChineseYaoZu_ddim.mp3 │ │ │ ├── 
arcade_MDDBLatinJazz_ddim.mp3 │ │ │ ├── country_MDDBModalJazz_ddim.mp3 │ │ │ ├── hiphop_MDDBFunkJazz_ddim.mp3 │ │ │ ├── instrument_MDDBBeethoven_ddim.mp3 │ │ │ ├── jazz_MDDBPunk_ddim.mp3 │ │ │ ├── jazz_MDDBZeppelin_ddim.mp3 │ │ │ ├── metal_MDDBChineseDrama_ddim.mp3 │ │ │ ├── orchestra_MDDBFreeJazz_ddim.mp3 │ │ │ ├── techno_MDDBBeethoven_ddim.mp3 │ │ │ ├── techno_MDDBGospel_ddim.mp3 │ │ │ ├── techno_MDDBRockabilly_ddim.mp3 │ │ │ └── tone_MDDBChineseChaoZhou_ddim.mp3 │ │ ├── musicgen │ │ │ ├── arabic_MDDBBritpop_musicgen.mp3 │ │ │ ├── arcade_MDDBBeatles_musicgen.mp3 │ │ │ ├── arcade_MDDBBebopJazz_musicgen.mp3 │ │ │ ├── arcade_MDDBChineseYaoZu_musicgen.mp3 │ │ │ ├── arcade_MDDBLatinJazz_musicgen.mp3 │ │ │ ├── country_MDDBModalJazz_musicgen.mp3 │ │ │ ├── hiphop_MDDBFunkJazz_musicgen.mp3 │ │ │ ├── instrument_MDDBBeethoven_musicgen.mp3 │ │ │ ├── jazz_MDDBPunk_musicgen.mp3 │ │ │ ├── jazz_MDDBZeppelin_musicgen.mp3 │ │ │ ├── metal_MDDBChineseDrama_musicgen.mp3 │ │ │ ├── orchestra_MDDBFreeJazz_musicgen.mp3 │ │ │ ├── techno_MDDBBeethoven_musicgen.mp3 │ │ │ ├── techno_MDDBGospel_musicgen.mp3 │ │ │ ├── techno_MDDBRockabilly_musicgen.mp3 │ │ │ └── tone_MDDBChineseChaoZhou_musicgen.mp3 │ │ ├── orig │ │ │ ├── CuteCat_10sec.mp3 │ │ │ ├── DogBarking_10sec.mp3 │ │ │ ├── MDDBBeatles.mp3 │ │ │ ├── MDDBBeatles_8secs.mp3 │ │ │ ├── MDDBBebopJazz.mp3 │ │ │ ├── MDDBBeethoven.mp3 │ │ │ ├── MDDBBritpop.mp3 │ │ │ ├── MDDBBritpop_8secs.mp3 │ │ │ ├── MDDBChineseChaoZhou.mp3 │ │ │ ├── MDDBChineseDrama.mp3 │ │ │ ├── MDDBChineseXinJing.mp3 │ │ │ ├── MDDBChineseYaoZu.mp3 │ │ │ ├── MDDBCoolJazz.mp3 │ │ │ ├── MDDBCountry1.mp3 │ │ │ ├── MDDBCountry1_8secs.mp3 │ │ │ ├── MDDBFreeJazz.mp3 │ │ │ ├── MDDBFunkJazz.mp3 │ │ │ ├── MDDBGospel.mp3 │ │ │ ├── MDDBGospel_8secs.mp3 │ │ │ ├── MDDBHendrix.mp3 │ │ │ ├── MDDBHendrix_8secs.mp3 │ │ │ ├── MDDBInTheHalloftheMountainKing.mp3 │ │ │ ├── MDDBLatinJazz.mp3 │ │ │ ├── MDDBLatinJazz_8secs.mp3 │ │ │ ├── MDDBModalJazz.mp3 │ │ │ ├── MDDBPunk.mp3 │ │ │ ├── MDDBReggae_8secs.mp3 │ │ │ ├── MDDBRock_8secs.mp3 │ │ │ ├── MDDBRockabilly.mp3 │ │ │ ├── MDDBRockabilly_8secs.mp3 │ │ │ ├── MDDBSpeedMetal.mp3 │ │ │ ├── MDDBVivaldi.mp3 │ │ │ ├── MDDBZeppelin.mp3 │ │ │ ├── ManSpeakingBGNoise.mp3 │ │ │ └── Shouting.mp3 │ │ ├── sdedit │ │ │ ├── CuteCat_10sec_sdedit_100.mp3 │ │ │ ├── CuteCat_10sec_sdedit_130.mp3 │ │ │ ├── CuteCat_10sec_sdedit_50.mp3 │ │ │ ├── CuteCat_10sec_sdedit_80.mp3 │ │ │ ├── DogBarking_10sec_sdedit_100.mp3 │ │ │ ├── DogBarking_10sec_sdedit_130.mp3 │ │ │ ├── DogBarking_10sec_sdedit_50.mp3 │ │ │ ├── DogBarking_10sec_sdedit_80.mp3 │ │ │ ├── Shouting_sdedit_100.mp3 │ │ │ ├── Shouting_sdedit_130.mp3 │ │ │ ├── Shouting_sdedit_50.mp3 │ │ │ ├── Shouting_sdedit_80.mp3 │ │ │ ├── arabic_MDDBBritpop_sdedit_100.mp3 │ │ │ ├── arabic_MDDBBritpop_sdedit_130.mp3 │ │ │ ├── arabic_MDDBBritpop_sdedit_160.mp3 │ │ │ ├── arcade_MDDBBeatles_sdedit_100.mp3 │ │ │ ├── arcade_MDDBBeatles_sdedit_130.mp3 │ │ │ ├── arcade_MDDBBeatles_sdedit_160.mp3 │ │ │ ├── arcade_MDDBBebopJazz_sdedit_100.mp3 │ │ │ ├── arcade_MDDBBebopJazz_sdedit_130.mp3 │ │ │ ├── arcade_MDDBBebopJazz_sdedit_160.mp3 │ │ │ ├── arcade_MDDBChineseYaoZu_sdedit_100.mp3 │ │ │ ├── arcade_MDDBChineseYaoZu_sdedit_130.mp3 │ │ │ ├── arcade_MDDBChineseYaoZu_sdedit_160.mp3 │ │ │ ├── arcade_MDDBLatinJazz_sdedit_100.mp3 │ │ │ ├── arcade_MDDBLatinJazz_sdedit_130.mp3 │ │ │ ├── arcade_MDDBLatinJazz_sdedit_160.mp3 │ │ │ ├── country_MDDBModalJazz_sdedit_100.mp3 │ │ │ ├── country_MDDBModalJazz_sdedit_130.mp3 │ │ │ ├── country_MDDBModalJazz_sdedit_160.mp3 │ │ │ ├── 
hiphop_MDDBFunkJazz_sdedit_100.mp3 │ │ │ ├── hiphop_MDDBFunkJazz_sdedit_130.mp3 │ │ │ ├── hiphop_MDDBFunkJazz_sdedit_160.mp3 │ │ │ ├── instrument_MDDBBeethoven_sdedit_100.mp3 │ │ │ ├── instrument_MDDBBeethoven_sdedit_130.mp3 │ │ │ ├── instrument_MDDBBeethoven_sdedit_160.mp3 │ │ │ ├── jazz_MDDBPunk_sdedit_100.mp3 │ │ │ ├── jazz_MDDBPunk_sdedit_130.mp3 │ │ │ ├── jazz_MDDBPunk_sdedit_160.mp3 │ │ │ ├── jazz_MDDBZeppelin_sdedit_100.mp3 │ │ │ ├── jazz_MDDBZeppelin_sdedit_130.mp3 │ │ │ ├── jazz_MDDBZeppelin_sdedit_160.mp3 │ │ │ ├── metal_MDDBChineseDrama_sdedit_100.mp3 │ │ │ ├── metal_MDDBChineseDrama_sdedit_130.mp3 │ │ │ ├── metal_MDDBChineseDrama_sdedit_160.mp3 │ │ │ ├── orchestra_MDDBFreeJazz_sdedit_100.mp3 │ │ │ ├── orchestra_MDDBFreeJazz_sdedit_130.mp3 │ │ │ ├── orchestra_MDDBFreeJazz_sdedit_160.mp3 │ │ │ ├── techno_MDDBBeethoven_sdedit_100.mp3 │ │ │ ├── techno_MDDBBeethoven_sdedit_130.mp3 │ │ │ ├── techno_MDDBBeethoven_sdedit_160.mp3 │ │ │ ├── techno_MDDBGospel_sdedit_100.mp3 │ │ │ ├── techno_MDDBGospel_sdedit_130.mp3 │ │ │ ├── techno_MDDBGospel_sdedit_160.mp3 │ │ │ ├── techno_MDDBRockabilly_sdedit_100.mp3 │ │ │ ├── techno_MDDBRockabilly_sdedit_130.mp3 │ │ │ ├── techno_MDDBRockabilly_sdedit_160.mp3 │ │ │ ├── tone_MDDBChineseChaoZhou_sdedit_100.mp3 │ │ │ ├── tone_MDDBChineseChaoZhou_sdedit_130.mp3 │ │ │ └── tone_MDDBChineseChaoZhou_sdedit_160.mp3 │ │ ├── sup_samples │ │ │ ├── CuteCat_10sec_ours_50.mp3 │ │ │ ├── DogBarking_10sec_ours_100.mp3 │ │ │ ├── MDDBBeethoven_skipex2_ours_100.mp3 │ │ │ ├── MDDBBeethoven_skipex2_ours_110.mp3 │ │ │ ├── MDDBBeethoven_skipex2_ours_120.mp3 │ │ │ ├── MDDBBeethoven_skipex2_ours_130.mp3 │ │ │ ├── MDDBBeethoven_skipex2_ours_90.mp3 │ │ │ ├── MDDBLatinJazz_skipex1_ours_100.mp3 │ │ │ ├── MDDBLatinJazz_skipex1_ours_110.mp3 │ │ │ ├── MDDBLatinJazz_skipex1_ours_120.mp3 │ │ │ ├── MDDBLatinJazz_skipex1_ours_130.mp3 │ │ │ ├── MDDBLatinJazz_skipex1_ours_90.mp3 │ │ │ ├── MDDBModalJazz_skipex3_ours_100.mp3 │ │ │ ├── MDDBModalJazz_skipex3_ours_110.mp3 │ │ │ ├── MDDBModalJazz_skipex3_ours_120.mp3 │ │ │ ├── MDDBModalJazz_skipex3_ours_130.mp3 │ │ │ ├── MDDBModalJazz_skipex3_ours_90.mp3 │ │ │ ├── Shouting_ours_90.mp3 │ │ │ ├── arabic_MDDBBritpop_ours_90.mp3 │ │ │ ├── arcade_MDDBBeatles_ours_100.mp3 │ │ │ ├── arcade_MDDBBebopJazz_ours_120.mp3 │ │ │ ├── arcade_MDDBChineseYaoZu_ours_90.mp3 │ │ │ ├── arcade_MDDBLatinJazz_ours_110.mp3 │ │ │ ├── arcade_MDDBVivaldi_ours_100.mp3 │ │ │ ├── country_MDDBModalJazz_ours_110.mp3 │ │ │ ├── hiphop_MDDBChineseXinJing_ours_100.mp3 │ │ │ ├── hiphop_MDDBFunkJazz_ours_90.mp3 │ │ │ ├── instrument_MDDBBeethoven_ours_130.mp3 │ │ │ ├── jazz_MDDBPunk_ours_110.mp3 │ │ │ ├── jazz_MDDBSpeedMetal_ours_100.mp3 │ │ │ ├── jazz_MDDBZeppelin_ours_100.mp3 │ │ │ ├── metal_MDDBChineseDrama_ours_160.mp3 │ │ │ ├── orchestra_MDDBFreeJazz_ours_90.mp3 │ │ │ ├── techno_MDDBBeethoven_ours_110.mp3 │ │ │ ├── techno_MDDBGospel_ours_100.mp3 │ │ │ ├── techno_MDDBRockabilly_ours_110.mp3 │ │ │ └── tone_MDDBChineseChaoZhou_ours_120.mp3 │ │ ├── unsup_samples │ │ │ ├── MDDBBeatles_ours_115-80_pc1_a-40.mp3 │ │ │ ├── MDDBBeatles_ours_115-80_pc1_a40.mp3 │ │ │ ├── MDDBBeatles_ours_115-80_pc2_a-40.mp3 │ │ │ ├── MDDBBeatles_ours_115-80_pc2_a40.mp3 │ │ │ ├── MDDBBeethoven_ours_135-95_pcs_a-20.mp3 │ │ │ ├── MDDBBeethoven_ours_135-95_pcs_a-40.mp3 │ │ │ ├── MDDBBeethoven_ours_200--1_spts135_pc3_a1.mp3 │ │ │ ├── MDDBBeethoven_ours_200--1_spts135_pc3_a2.mp3 │ │ │ ├── MDDBBritpop_ours_200--1_spts95_pcs_a1.mp3 │ │ │ ├── MDDBBritpop_ours_200--1_spts95_pcs_a2.mp3 │ │ │ ├── 
MDDBCoolJazz_ours_135-95_pcs_a20.mp3 │ │ │ ├── MDDBCoolJazz_ours_135-95_pcs_a40.mp3 │ │ │ ├── MDDBCountry1_8sec_ours_150--1_spts115_pcs_a2.mp3 │ │ │ ├── MDDBCountry1_ours_150--1_spts115_pcs_a1.mp3 │ │ │ ├── MDDBCountry1_ours_150--1_spts115_pcs_a2.mp3 │ │ │ ├── MDDBGospel_ours_150--1_spts120_pc3_a-2.mp3 │ │ │ ├── MDDBGospel_ours_150--1_spts120_pc3_a2.mp3 │ │ │ ├── MDDBHendrix_8sec_ours_200--1_spts80_pc1_a2.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_RAND_a-12.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_RAND_a-2.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_RAND_a-8.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_RAND_a12.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_RAND_a2.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_RAND_a8.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_a-1.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_a-2.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_a-3.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_a1.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_a2.mp3 │ │ │ ├── MDDBHendrix_ours_200--1_spts80_pc1_a3.mp3 │ │ │ ├── MDDBInTheHalloftheMountainKing_ours_150-50_spts95_pcs_a1.mp3 │ │ │ ├── MDDBInTheHalloftheMountainKing_ours_150-50_spts95_pcs_a2.mp3 │ │ │ ├── MDDBLatinJazz_8sec_ours_200--1_spts80_pcs_a-2.mp3 │ │ │ ├── MDDBLatinJazz_ours_200--1_spts80_pcs_a-1.mp3 │ │ │ ├── MDDBLatinJazz_ours_200--1_spts80_pcs_a-2.mp3 │ │ │ ├── MDDBReggae_8sec_ours_150--1_spts80_pcs_a-2.mp3 │ │ │ ├── MDDBReggae_ours_150--1_spts80_pcs_a-1.mp3 │ │ │ ├── MDDBReggae_ours_150--1_spts80_pcs_a-2.mp3 │ │ │ ├── MDDBRock_8sec_ours_200--1_spts65_pcs_a-2.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_RAND_a-12.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_RAND_a-2.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_RAND_a-8.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_RAND_a12.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_RAND_a2.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_RAND_a8.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_a-0.5.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_a-1.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_a-2.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_a0.5.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_a1.mp3 │ │ │ ├── MDDBRock_ours_200--1_spts65_pcs_a2.mp3 │ │ │ ├── MDDBRockabilly_8sec_ours_200--1_spts65_pcs_a-2.mp3 │ │ │ ├── MDDBRockabilly_ours_200--1_spts65_pcs_a-1.mp3 │ │ │ ├── MDDBRockabilly_ours_200--1_spts65_pcs_a-2.mp3 │ │ │ ├── MDDBVivaldi_ours_95-80_pcs_a-20.mp3 │ │ │ ├── MDDBVivaldi_ours_95-80_pcs_a-40.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_RAND_a-120.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_RAND_a-240.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_RAND_a-40.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_RAND_a120.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_RAND_a240.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_RAND_a40.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_a-20.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_a-40.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_a-60.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_a20.mp3 │ │ │ ├── ManSpeakingBGNoise_ours_115-95_pc1_a40.mp3 │ │ │ └── ManSpeakingBGNoise_ours_115-95_pc1_a60.mp3 │ │ └── unsup_sdedit │ │ │ ├── MDDBBeatles_skip100.mp3 │ │ │ ├── MDDBBeatles_skip115.mp3 │ │ │ ├── MDDBBeatles_skip130.mp3 │ │ │ ├── MDDBBeatles_skip85.mp3 │ │ │ ├── MDDBBeethoven_skip100.mp3 │ │ │ ├── MDDBBeethoven_skip115.mp3 │ │ │ ├── MDDBBeethoven_skip130.mp3 │ │ │ ├── MDDBBeethoven_skip85.mp3 │ │ │ ├── MDDBCoolJazz_skip100.mp3 │ │ │ ├── MDDBCoolJazz_skip115.mp3 │ │ │ ├── 
MDDBCoolJazz_skip130.mp3 │ │ │ ├── MDDBCoolJazz_skip85.mp3 │ │ │ ├── MDDBCountry1_8sec_skip100.mp3 │ │ │ ├── MDDBCountry1_8sec_skip115.mp3 │ │ │ ├── MDDBCountry1_8sec_skip130.mp3 │ │ │ ├── MDDBCountry1_8sec_skip85.mp3 │ │ │ ├── MDDBHendrix_8sec_skip100.mp3 │ │ │ ├── MDDBHendrix_8sec_skip115.mp3 │ │ │ ├── MDDBHendrix_8sec_skip130.mp3 │ │ │ ├── MDDBHendrix_8sec_skip85.mp3 │ │ │ ├── MDDBInTheHalloftheMountainKing_skip100.mp3 │ │ │ ├── MDDBInTheHalloftheMountainKing_skip115.mp3 │ │ │ ├── MDDBInTheHalloftheMountainKing_skip130.mp3 │ │ │ ├── MDDBInTheHalloftheMountainKing_skip85.mp3 │ │ │ ├── MDDBLatinJazz_8sec_skip100.mp3 │ │ │ ├── MDDBLatinJazz_8sec_skip115.mp3 │ │ │ ├── MDDBLatinJazz_8sec_skip130.mp3 │ │ │ ├── MDDBLatinJazz_8sec_skip85.mp3 │ │ │ ├── MDDBReggae_8sec_skip100.mp3 │ │ │ ├── MDDBReggae_8sec_skip115.mp3 │ │ │ ├── MDDBReggae_8sec_skip130.mp3 │ │ │ ├── MDDBReggae_8sec_skip85.mp3 │ │ │ ├── MDDBRock_8sec_skip100.mp3 │ │ │ ├── MDDBRock_8sec_skip115.mp3 │ │ │ ├── MDDBRock_8sec_skip130.mp3 │ │ │ ├── MDDBRock_8sec_skip85.mp3 │ │ │ ├── MDDBRockabilly_8sec_skip100.mp3 │ │ │ ├── MDDBRockabilly_8sec_skip115.mp3 │ │ │ ├── MDDBRockabilly_8sec_skip130.mp3 │ │ │ ├── MDDBRockabilly_8sec_skip85.mp3 │ │ │ ├── MDDBVivaldi_skip100.mp3 │ │ │ ├── MDDBVivaldi_skip115.mp3 │ │ │ ├── MDDBVivaldi_skip130.mp3 │ │ │ └── MDDBVivaldi_skip85.mp3 │ ├── github_logo.png │ ├── overview.mp4 │ ├── paper.pdf │ ├── paper.png │ └── teaser.png ├── stylefile.css └── supp.html ├── evals ├── LICENSE ├── LICENSE_CLAP-WEIGHTS ├── LICENSE_LPAPS ├── SupEval.ipynb ├── UnsupEval.ipynb ├── fadtk_utils.py ├── lpaps.py ├── meta_clap_consistency.py ├── pretrained_networks.py └── utils.py ├── requirements.txt └── requirements_paper.txt /code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Hila Manor 4 | Copyright (c) 2023 inbarhub 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/code/__init__.py -------------------------------------------------------------------------------- /code/audioldm/__init__.py: -------------------------------------------------------------------------------- 1 | from .ldm import LatentDiffusion 2 | from .utils import seed_everything, save_wave, get_time, get_duration 3 | from .pipeline import * 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /code/audioldm/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration 4 | import argparse 5 | 6 | CACHE_DIR = os.getenv( 7 | "AUDIOLDM_CACHE_DIR", 8 | os.path.join(os.path.expanduser("~"), ".cache/audioldm")) 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument( 13 | "--mode", 14 | type=str, 15 | required=False, 16 | default="generation", 17 | help="generation: text-to-audio generation; transfer: style transfer", 18 | choices=["generation", "transfer"] 19 | ) 20 | 21 | parser.add_argument( 22 | "-t", 23 | "--text", 24 | type=str, 25 | required=False, 26 | default="", 27 | help="Text prompt to the model for audio generation", 28 | ) 29 | 30 | parser.add_argument( 31 | "-f", 32 | "--file_path", 33 | type=str, 34 | required=False, 35 | default=None, 36 | help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio", 37 | ) 38 | 39 | parser.add_argument( 40 | "--transfer_strength", 41 | type=float, 42 | required=False, 43 | default=0.5, 44 | help="A value between 0 and 1. 
0 means original audio without transfer, 1 means completely transfer to the audio indicated by text", 45 | ) 46 | 47 | parser.add_argument( 48 | "-s", 49 | "--save_path", 50 | type=str, 51 | required=False, 52 | help="The path to save model output", 53 | default="./output", 54 | ) 55 | 56 | parser.add_argument( 57 | "--model_name", 58 | type=str, 59 | required=False, 60 | help="The checkpoint you are going to use", 61 | default="audioldm-m-full", 62 | choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"] 63 | ) 64 | 65 | parser.add_argument( 66 | "-ckpt", 67 | "--ckpt_path", 68 | type=str, 69 | required=False, 70 | help="The path to the pretrained .ckpt model", 71 | default=None, 72 | ) 73 | 74 | parser.add_argument( 75 | "-b", 76 | "--batchsize", 77 | type=int, 78 | required=False, 79 | default=1, 80 | help="Generate how many samples at the same time", 81 | ) 82 | 83 | parser.add_argument( 84 | "--ddim_steps", 85 | type=int, 86 | required=False, 87 | default=200, 88 | help="The sampling step for DDIM", 89 | ) 90 | 91 | parser.add_argument( 92 | "-gs", 93 | "--guidance_scale", 94 | type=float, 95 | required=False, 96 | default=2.5, 97 | help="Guidance scale (Large => better quality and relevance to text; Small => better diversity)", 98 | ) 99 | 100 | parser.add_argument( 101 | "-dur", 102 | "--duration", 103 | type=float, 104 | required=False, 105 | default=10.0, 106 | help="The duration of the samples", 107 | ) 108 | 109 | parser.add_argument( 110 | "-n", 111 | "--n_candidate_gen_per_text", 112 | type=int, 113 | required=False, 114 | default=3, 115 | help="Automatic quality control. This number controls the number of candidates (e.g., generate three audios and choose the best to show you). A larger value usually leads to better quality with heavier computation", 116 | ) 117 | 118 | parser.add_argument( 119 | "--seed", 120 | type=int, 121 | required=False, 122 | default=42, 123 | help="Changing this value (any integer number) will lead to a different generation result.", 124 | ) 125 | 126 | args = parser.parse_args() 127 | 128 | if(args.ckpt_path is not None): 129 | print("Warning: ckpt_path has no effect after version 0.0.20.") 130 | 131 | assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5" 132 | 133 | mode = args.mode 134 | if(mode == "generation" and args.file_path is not None): 135 | mode = "generation_audio_to_audio" 136 | if(len(args.text) > 0): 137 | print("Warning: You have specified the --file_path. 
--text will be ignored") 138 | args.text = "" 139 | 140 | save_path = os.path.join(args.save_path, mode) 141 | 142 | if(args.file_path is not None): 143 | save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0])) 144 | 145 | text = args.text 146 | random_seed = args.seed 147 | duration = args.duration 148 | guidance_scale = args.guidance_scale 149 | n_candidate_gen_per_text = args.n_candidate_gen_per_text 150 | 151 | os.makedirs(save_path, exist_ok=True) 152 | audioldm = build_model(model_name=args.model_name) 153 | 154 | if(args.mode == "generation"): 155 | waveform = text_to_audio( 156 | audioldm, 157 | text, 158 | args.file_path, 159 | random_seed, 160 | duration=duration, 161 | guidance_scale=guidance_scale, 162 | ddim_steps=args.ddim_steps, 163 | n_candidate_gen_per_text=n_candidate_gen_per_text, 164 | batchsize=args.batchsize, 165 | ) 166 | 167 | elif(args.mode == "transfer"): 168 | assert args.file_path is not None 169 | assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path 170 | waveform = style_transfer( 171 | audioldm, 172 | text, 173 | args.file_path, 174 | args.transfer_strength, 175 | random_seed, 176 | duration=duration, 177 | guidance_scale=guidance_scale, 178 | ddim_steps=args.ddim_steps, 179 | batchsize=args.batchsize, 180 | ) 181 | waveform = waveform[:,None,:] 182 | 183 | save_wave(waveform, save_path, name="%s_%s" % (get_time(), text)) 184 | -------------------------------------------------------------------------------- /code/audioldm/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import wav_to_fbank, read_wav_file 2 | from .stft import TacotronSTFT 3 | -------------------------------------------------------------------------------- /code/audioldm/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 
39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return normalize_fun(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /code/audioldm/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchaudio 4 | 5 | 6 | def get_mel_from_wav(audio, _stft): 7 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 8 | audio = torch.autograd.Variable(audio, requires_grad=False) 9 | melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) 10 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 11 | log_magnitudes_stft = ( 12 | torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) 13 | ) 14 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 15 | return melspec, log_magnitudes_stft, energy 16 | 17 | 18 | def _pad_spec(fbank, target_length=1024): 19 | n_frames = fbank.shape[0] 20 | p = target_length - n_frames 21 | # cut and pad 22 | if p > 0: 23 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 24 | fbank = m(fbank) 25 | elif p < 0: 26 | fbank = fbank[0:target_length, :] 27 | 28 | if fbank.size(-1) % 2 != 0: 29 | fbank = fbank[..., :-1] 30 | 31 | return fbank 32 | 33 | 34 | def pad_wav(waveform, segment_length): 35 | waveform_length = waveform.shape[-1] 36 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length 37 | if segment_length is None or waveform_length == segment_length: 38 | return waveform 39 | elif waveform_length > segment_length: 40 | return waveform[:segment_length] 41 | elif waveform_length < segment_length: 42 | temp_wav = np.zeros((1, segment_length)) 43 | temp_wav[:, :waveform_length] = waveform 44 | return temp_wav 45 | 46 | def normalize_wav(waveform): 47 | waveform = waveform - 
np.mean(waveform) 48 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) 49 | return waveform * 0.5 50 | 51 | 52 | def read_wav_file(filename, segment_length): 53 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower 54 | waveform, sr = torchaudio.load(filename) # Faster!!! 55 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 56 | waveform = waveform.numpy()[0, ...] 57 | waveform = normalize_wav(waveform) 58 | waveform = waveform[None, ...] 59 | waveform = pad_wav(waveform, segment_length) 60 | 61 | waveform = waveform / np.max(np.abs(waveform)) 62 | waveform = 0.5 * waveform 63 | 64 | return waveform 65 | 66 | 67 | def wav_to_fbank(filename, target_length=1024, fn_STFT=None): 68 | assert fn_STFT is not None 69 | 70 | # mixup 71 | waveform = read_wav_file(filename, target_length * 160) # hop size is 160 72 | 73 | waveform = waveform[0, ...] 74 | waveform = torch.FloatTensor(waveform) 75 | 76 | fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) 77 | 78 | fbank = torch.FloatTensor(fbank.T) 79 | log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T) 80 | 81 | fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( 82 | log_magnitudes_stft, target_length 83 | ) 84 | 85 | return fbank, log_magnitudes_stft, waveform 86 | -------------------------------------------------------------------------------- /code/audioldm/clap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/code/audioldm/clap/__init__.py -------------------------------------------------------------------------------- /code/audioldm/clap/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from audioldm.clap.open_clip import create_model 4 | from audioldm.clap.training.data import get_audio_features 5 | import torchaudio 6 | from transformers import RobertaTokenizer 7 | import torch.nn.functional as F 8 | 9 | 10 | class CLAPAudioEmbeddingClassifierFreev2(nn.Module): 11 | def __init__( 12 | self, 13 | pretrained_path="", 14 | key="class", 15 | sampling_rate=16000, 16 | embed_mode="audio", 17 | amodel = "HTSAT-tiny", 18 | unconditional_prob=0.1, 19 | random_mute=False, 20 | max_random_mute_portion=0.5, 21 | training_mode=True, 22 | ): 23 | super().__init__() 24 | 25 | self.key = key 26 | self.device = "cpu" 27 | self.precision = "fp32" 28 | self.amodel = amodel # or 'PANN-14' 29 | self.tmodel = "roberta" # the best text encoder in our training 30 | self.enable_fusion = False # False if you do not want to use the fusion model 31 | self.fusion_type = "aff_2d" 32 | self.pretrained = pretrained_path 33 | self.embed_mode = embed_mode 34 | self.embed_mode_orig = embed_mode 35 | self.sampling_rate = sampling_rate 36 | self.unconditional_prob = unconditional_prob 37 | self.random_mute = random_mute 38 | self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") 39 | self.max_random_mute_portion = max_random_mute_portion 40 | self.training_mode = training_mode 41 | self.model, self.model_cfg = create_model( 42 | self.amodel, 43 | self.tmodel, 44 | self.pretrained, 45 | precision=self.precision, 46 | device=self.device, 47 | enable_fusion=self.enable_fusion, 48 | fusion_type=self.fusion_type, 49 | ) 50 | for p in self.model.parameters(): 51 | p.requires_grad = False 52 | 53 | self.model.eval() 54 | 55 
| def get_unconditional_condition(self, batchsize): 56 | self.unconditional_token = self.model.get_text_embedding( 57 | self.tokenizer(["", ""]) 58 | )[0:1] 59 | return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) 60 | 61 | def batch_to_list(self, batch): 62 | ret = [] 63 | for i in range(batch.size(0)): 64 | ret.append(batch[i]) 65 | return ret 66 | 67 | def make_decision(self, probability): 68 | if float(torch.rand(1)) < probability: 69 | return True 70 | else: 71 | return False 72 | 73 | def random_uniform(self, start, end): 74 | val = torch.rand(1).item() 75 | return start + (end - start) * val 76 | 77 | def _random_mute(self, waveform): 78 | # waveform: [bs, t-steps] 79 | t_steps = waveform.size(-1) 80 | for i in range(waveform.size(0)): 81 | mute_size = int( 82 | self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) 83 | ) 84 | mute_start = int(self.random_uniform(0, t_steps - mute_size)) 85 | waveform[i, mute_start : mute_start + mute_size] = 0 86 | return waveform 87 | 88 | def cos_similarity(self, waveform, text): 89 | # waveform: [bs, t_steps] 90 | with torch.no_grad(): 91 | self.embed_mode = "audio" 92 | print(text) 93 | audio_emb = self(waveform.cuda()) 94 | self.embed_mode = "text" 95 | text_emb = self(text) 96 | similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) 97 | return similarity.squeeze() 98 | 99 | def forward(self, batch, key=None): 100 | # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 101 | # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 102 | if self.model.training == True and not self.training_mode: 103 | print( 104 | "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." 
105 | ) 106 | self.model, self.model_cfg = create_model( 107 | self.amodel, 108 | self.tmodel, 109 | self.pretrained, 110 | precision=self.precision, 111 | device="cuda", 112 | enable_fusion=self.enable_fusion, 113 | fusion_type=self.fusion_type, 114 | ) 115 | for p in self.model.parameters(): 116 | p.requires_grad = False 117 | self.model.eval() 118 | 119 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 120 | if self.embed_mode == "audio": 121 | with torch.no_grad(): 122 | audio_dict_list = [] 123 | assert ( 124 | self.sampling_rate == 16000 125 | ), "We only support 16000 sampling rate" 126 | if self.random_mute: 127 | batch = self._random_mute(batch) 128 | # batch: [bs, 1, t-samples] 129 | batch = torchaudio.functional.resample( 130 | batch, orig_freq=self.sampling_rate, new_freq=48000 131 | ) 132 | for waveform in self.batch_to_list(batch): 133 | audio_dict = {} 134 | audio_dict = get_audio_features( 135 | audio_dict, 136 | waveform, 137 | 480000, 138 | data_truncating="fusion", 139 | data_filling="repeatpad", 140 | audio_cfg=self.model_cfg["audio_cfg"], 141 | ) 142 | audio_dict_list.append(audio_dict) 143 | # [bs, 512] 144 | embed = self.model.get_audio_embedding(audio_dict_list) 145 | elif self.embed_mode == "text": 146 | with torch.no_grad(): 147 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 148 | text_data = self.tokenizer(batch) 149 | embed = self.model.get_text_embedding(text_data) 150 | 151 | embed = embed.unsqueeze(1) 152 | self.unconditional_token = self.model.get_text_embedding( 153 | self.tokenizer(["", ""]) 154 | )[0:1] 155 | 156 | for i in range(embed.size(0)): 157 | if self.make_decision(self.unconditional_prob): 158 | embed[i] = self.unconditional_token 159 | 160 | # [bs, 1, 512] 161 | return embed.detach() 162 | 163 | def tokenizer(self, text): 164 | result = self.tokenize( 165 | text, 166 | padding="max_length", 167 | truncation=True, 168 | max_length=512, 169 | return_tensors="pt", 170 | ) 171 | return {k: v.squeeze(0) for k, v in result.items()} 172 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import ( 2 | list_models, 3 | create_model, 4 | create_model_and_transforms, 5 | add_model_config, 6 | ) 7 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics 8 | from .model import ( 9 | CLAP, 10 | CLAPTextCfg, 11 | CLAPVisionCfg, 12 | CLAPAudioCfp, 13 | convert_weights_to_fp16, 14 | trace_model, 15 | ) 16 | from .openai import load_openai_model, list_openai_models 17 | from .pretrained import ( 18 | list_pretrained, 19 | list_pretrained_tag_models, 20 | list_pretrained_model_tags, 21 | get_pretrained_url, 22 | download_pretrained, 23 | ) 24 | from .tokenizer import SimpleTokenizer, tokenize 25 | from .transform import image_transform 26 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | 3 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 4 | model = BertModel.from_pretrained("bert-base-uncased") 5 | text = "Replace me by any text you'd like." 6 | 7 | 8 | def bert_embeddings(text): 9 | # text = "Replace me by any text you'd like." 
10 | encoded_input = tokenizer(text, return_tensors="pt") 11 | output = model(**encoded_input) 12 | return output 13 | 14 | 15 | from transformers import RobertaTokenizer, RobertaModel 16 | 17 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 18 | model = RobertaModel.from_pretrained("roberta-base") 19 | text = "Replace me by any text you'd like." 20 | 21 | 22 | def Roberta_embeddings(text): 23 | # text = "Replace me by any text you'd like." 24 | encoded_input = tokenizer(text, return_tensors="pt") 25 | output = model(**encoded_input) 26 | return output 27 | 28 | 29 | from transformers import BartTokenizer, BartModel 30 | 31 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 32 | model = BartModel.from_pretrained("facebook/bart-base") 33 | text = "Replace me by any text you'd like." 34 | 35 | 36 | def bart_embeddings(text): 37 | # text = "Replace me by any text you'd like." 38 | encoded_input = tokenizer(text, return_tensors="pt") 39 | output = model(**encoded_input) 40 | return output 41 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/code/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from .model import MLPLayers 5 | 6 | 7 | class LinearProbe(nn.Module): 8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): 9 | """ 10 | Args: 11 | model: nn.Module 12 | mlp: bool, if True, then use the MLP layer as the linear probe module 13 | freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe 14 | in_ch: int, the output channel from CLAP model 15 | out_ch: int, the output channel from linear probe (class_num) 16 | act: torch.nn.functional, the activation function before the loss function 17 | """ 18 | super().__init__() 19 | in_ch = 512 20 | self.clap_model = model 21 | self.clap_model.text_branch = None # to save memory 22 | self.freeze = freeze 23 | if mlp: 24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) 25 | else: 26 | self.lp_layer = nn.Linear(in_ch, out_ch) 27 | 28 | if self.freeze: 29 | for param in self.clap_model.parameters(): 30 | param.requires_grad = False 31 | 32 | if act == "None": 33 | self.act = None 34 | elif act == "relu": 35 | self.act = nn.ReLU() 36 | elif act == "elu": 37 | self.act = nn.ELU() 38 | elif act == "prelu": 39 | self.act = nn.PReLU(num_parameters=in_ch) 40 | elif act == "softmax": 41 | self.act = nn.Softmax(dim=-1) 42 | elif act == "sigmoid": 43 | self.act = nn.Sigmoid() 44 | 45 | def forward(self, x, mix_lambda=None, device=None): 46 | """ 47 | Args: 48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list 49 | mix_lambda: torch.tensor [batch], the mixup lambda 50 | Returns: 51 | class_prob: torch.tensor [batch, class_num] 52 | 53 | """ 54 | # batchnorm cancel grandient 55 | if self.freeze: 56 | self.clap_model.eval() 57 | 58 | x = self.clap_model.audio_projection( 59 | self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[ 60 | "embedding" 61 | ] 62 | ) 
63 | out = self.lp_layer(x) 64 | if self.act is not None: 65 | out = self.act(out) 66 | return out 67 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/HTSAT-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } 
-------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- 
/code/audioldm/clap/open_clip/model_configs/PANN-6.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import ( 14 | get_pretrained_url, 15 | list_pretrained_tag_models, 16 | download_pretrained, 17 | ) 18 | 19 | __all__ = ["list_openai_models", "load_openai_model"] 20 | 21 | CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache") 22 | 23 | 24 | 25 | def list_openai_models() -> List[str]: 26 | """Returns the names of available CLIP models""" 27 | return list_pretrained_tag_models("openai") 28 | 29 | 30 | def load_openai_model( 31 | name: str, 32 | model_cfg, 33 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 34 | jit=True, 35 | cache_dir=os.path.expanduser(f"{CACHE_DIR}/clip"), 36 | enable_fusion: bool = False, 37 | fusion_type: str = "None", 38 | ): 39 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 40 | 41 | Parameters 42 | ---------- 43 | name : str 44 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 45 | device : Union[str, torch.device] 46 | The device to put the loaded model 47 | jit : bool 48 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 49 | 50 | Returns 51 | ------- 52 | model : torch.nn.Module 53 | The CLAP model 54 | preprocess : Callable[[PIL.Image], torch.Tensor] 55 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 56 | """ 57 | if get_pretrained_url(name, "openai"): 58 | model_path = download_pretrained( 59 | get_pretrained_url(name, "openai"), root=cache_dir 60 | ) 61 | elif os.path.isfile(name): 62 | model_path = name 63 | else: 64 | raise RuntimeError( 65 | f"Model {name} not found; available models = {list_openai_models()}" 66 | ) 67 | 68 | try: 69 | # loading JIT archive 70 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 71 | state_dict = None 72 | except RuntimeError: 73 | # loading saved state dict 74 | if jit: 75 | warnings.warn( 76 | f"File {model_path} is not a JIT archive. 
Loading as a state dict instead" 77 | ) 78 | jit = False 79 | state_dict = torch.load(model_path, map_location="cpu") 80 | 81 | if not jit: 82 | try: 83 | model = build_model_from_openai_state_dict( 84 | state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type 85 | ).to(device) 86 | except KeyError: 87 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 88 | model = build_model_from_openai_state_dict( 89 | sd, model_cfg, enable_fusion, fusion_type 90 | ).to(device) 91 | 92 | if str(device) == "cpu": 93 | model.float() 94 | return model 95 | 96 | # patch the device names 97 | device_holder = torch.jit.trace( 98 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 99 | ) 100 | device_node = [ 101 | n 102 | for n in device_holder.graph.findAllNodes("prim::Constant") 103 | if "Device" in repr(n) 104 | ][-1] 105 | 106 | def patch_device(module): 107 | try: 108 | graphs = [module.graph] if hasattr(module, "graph") else [] 109 | except RuntimeError: 110 | graphs = [] 111 | 112 | if hasattr(module, "forward1"): 113 | graphs.append(module.forward1.graph) 114 | 115 | for graph in graphs: 116 | for node in graph.findAllNodes("prim::Constant"): 117 | if "value" in node.attributeNames() and str(node["value"]).startswith( 118 | "cuda" 119 | ): 120 | node.copyAttributes(device_node) 121 | 122 | model.apply(patch_device) 123 | patch_device(model.encode_audio) 124 | patch_device(model.encode_text) 125 | 126 | # patch dtype to float32 on CPU 127 | if str(device) == "cpu": 128 | float_holder = torch.jit.trace( 129 | lambda: torch.ones([]).float(), example_inputs=[] 130 | ) 131 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 132 | float_node = float_input.node() 133 | 134 | def patch_float(module): 135 | try: 136 | graphs = [module.graph] if hasattr(module, "graph") else [] 137 | except RuntimeError: 138 | graphs = [] 139 | 140 | if hasattr(module, "forward1"): 141 | graphs.append(module.forward1.graph) 142 | 143 | for graph in graphs: 144 | for node in graph.findAllNodes("aten::to"): 145 | inputs = list(node.inputs()) 146 | for i in [ 147 | 1, 148 | 2, 149 | ]: # dtype can be the second or third argument to aten::to() 150 | if inputs[i].node()["value"] == 5: 151 | inputs[i].node().copyAttributes(float_node) 152 | 153 | model.apply(patch_float) 154 | patch_float(model.encode_audio) 155 | patch_float(model.encode_text) 156 | model.float() 157 | 158 | model.audio_branch.audio_length = model.audio_cfg.audio_length 159 | return model 160 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import ( 14 | AttentionPool2d as AbsAttentionPool2d, 15 | ) 16 | except ImportError as e: 17 | timm = None 18 | 19 | from .utils import freeze_batch_norm_2d 20 | 21 | 22 | class TimmModel(nn.Module): 23 | """timm model adapter 24 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model_name, 30 | embed_dim, 31 | image_size=224, 32 | pool="avg", 33 | proj="linear", 34 | drop=0.0, 35 | pretrained=False, 36 | ): 37 | super().__init__() 38 | if timm is None: 39 | raise RuntimeError("Please `pip install timm` to use timm models.") 40 | 41 | self.image_size = to_2tuple(image_size) 42 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 43 | feat_size = self.trunk.default_cfg.get("pool_size", None) 44 | feature_ndim = 1 if not feat_size else 2 45 | if pool in ("abs_attn", "rot_attn"): 46 | assert feature_ndim == 2 47 | # if attn pooling used, remove both classifier and default pool 48 | self.trunk.reset_classifier(0, global_pool="") 49 | else: 50 | # reset global pool if pool config set, otherwise leave as network default 51 | reset_kwargs = dict(global_pool=pool) if pool else {} 52 | self.trunk.reset_classifier(0, **reset_kwargs) 53 | prev_chs = self.trunk.num_features 54 | 55 | head_layers = OrderedDict() 56 | if pool == "abs_attn": 57 | head_layers["pool"] = AbsAttentionPool2d( 58 | prev_chs, feat_size=feat_size, out_features=embed_dim 59 | ) 60 | prev_chs = embed_dim 61 | elif pool == "rot_attn": 62 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 63 | prev_chs = embed_dim 64 | else: 65 | assert proj, "projection layer needed if non-attention pooling is used." 
66 | 67 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 68 | if proj == "linear": 69 | head_layers["drop"] = nn.Dropout(drop) 70 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim) 71 | elif proj == "mlp": 72 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 73 | 74 | self.head = nn.Sequential(head_layers) 75 | 76 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 77 | """lock modules 78 | Args: 79 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 80 | """ 81 | if not unlocked_groups: 82 | # lock full model 83 | for param in self.trunk.parameters(): 84 | param.requires_grad = False 85 | if freeze_bn_stats: 86 | freeze_batch_norm_2d(self.trunk) 87 | else: 88 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 89 | try: 90 | # FIXME import here until API stable and in an official release 91 | from timm.models.helpers import group_parameters, group_modules 92 | except ImportError: 93 | raise RuntimeError( 94 | "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" 95 | ) 96 | matcher = self.trunk.group_matcher() 97 | gparams = group_parameters(self.trunk, matcher) 98 | max_layer_id = max(gparams.keys()) 99 | max_layer_id = max_layer_id - unlocked_groups 100 | for group_idx in range(max_layer_id + 1): 101 | group = gparams[group_idx] 102 | for param in group: 103 | self.trunk.get_parameter(param).requires_grad = False 104 | if freeze_bn_stats: 105 | gmodules = group_modules(self.trunk, matcher, reverse=True) 106 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 107 | freeze_batch_norm_2d(self.trunk, gmodules) 108 | 109 | def forward(self, x): 110 | x = self.trunk(x) 111 | x = self.head(x) 112 | return x 113 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import ( 2 | Normalize, 3 | Compose, 4 | RandomResizedCrop, 5 | InterpolationMode, 6 | ToTensor, 7 | Resize, 8 | CenterCrop, 9 | ) 10 | 11 | 12 | def _convert_to_rgb(image): 13 | return image.convert("RGB") 14 | 15 | 16 | def image_transform( 17 | image_size: int, 18 | is_train: bool, 19 | mean=(0.48145466, 0.4578275, 0.40821073), 20 | std=(0.26862954, 0.26130258, 0.27577711), 21 | ): 22 | normalize = Normalize(mean=mean, std=std) 23 | if is_train: 24 | return Compose( 25 | [ 26 | RandomResizedCrop( 27 | image_size, 28 | scale=(0.9, 1.0), 29 | interpolation=InterpolationMode.BICUBIC, 30 | ), 31 | _convert_to_rgb, 32 | ToTensor(), 33 | normalize, 34 | ] 35 | ) 36 | else: 37 | return Compose( 38 | [ 39 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 40 | CenterCrop(image_size), 41 | _convert_to_rgb, 42 | ToTensor(), 43 | normalize, 44 | ] 45 | ) 46 | -------------------------------------------------------------------------------- /code/audioldm/clap/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.1" 2 | -------------------------------------------------------------------------------- /code/audioldm/clap/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/code/audioldm/clap/training/__init__.py 
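As a quick orientation aid, here is a minimal usage sketch for the image_transform helper defined in transform.py above. It is illustrative rather than part of the repository files: it assumes the repository's code/ directory is on the Python path, and the 224 input size, the placeholder image, and the CLIP mean/std defaults are just example choices.

from PIL import Image

from audioldm.clap.open_clip.transform import image_transform  # assumes `code/` is on PYTHONPATH

# Build the train- and eval-time pipelines for 224x224 inputs (CLIP mean/std defaults).
train_tf = image_transform(224, is_train=True)
eval_tf = image_transform(224, is_train=False)

# A solid-color placeholder stands in for a real image.
img = Image.new("RGB", (256, 256), color=(120, 90, 60))

x_train = train_tf(img)  # RandomResizedCrop (scale 0.9-1.0) + ToTensor + Normalize -> (3, 224, 224)
x_eval = eval_tf(img)    # Resize + CenterCrop + ToTensor + Normalize -> (3, 224, 224)
print(x_train.shape, x_eval.shape)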
-------------------------------------------------------------------------------- /code/audioldm/clap/training/audioset_textmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/code/audioldm/clap/training/audioset_textmap.npy -------------------------------------------------------------------------------- /code/audioldm/clap/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import socket 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all( 30 | [var in os.environ for var in pmi_vars] 31 | ): 32 | return True 33 | else: 34 | return False 35 | 36 | 37 | def is_using_distributed(): 38 | if "WORLD_SIZE" in os.environ: 39 | return int(os.environ["WORLD_SIZE"]) > 1 40 | if "SLURM_NTASKS" in os.environ: 41 | return int(os.environ["SLURM_NTASKS"]) > 1 42 | return False 43 | 44 | 45 | def world_info_from_env(): 46 | local_rank = 0 47 | for v in ( 48 | "SLURM_LOCALID", 49 | "MPI_LOCALRANKID", 50 | "OMPI_COMM_WORLD_LOCAL_RANK", 51 | "LOCAL_RANK", 52 | ): 53 | if v in os.environ: 54 | local_rank = int(os.environ[v]) 55 | break 56 | global_rank = 0 57 | for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"): 58 | if v in os.environ: 59 | global_rank = int(os.environ[v]) 60 | break 61 | world_size = 1 62 | for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"): 63 | if v in os.environ: 64 | world_size = int(os.environ[v]) 65 | break 66 | 67 | return local_rank, global_rank, world_size 68 | 69 | 70 | def init_distributed_device(args): 71 | # Distributed training = training on more than one GPU. 72 | # Works in both single and multi-node scenarios. 
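    # Rank/world-size resolution below follows one of three paths:
    #   1. Horovod (args.horovod): hvd.init() plus the OMPI_COMM_WORLD_* variables set by the MPI launcher.
    #   2. SLURM or Open MPI launchers: SLURM_*/OMPI_COMM_WORLD_* variables are read (via world_info_from_env)
    #      and passed explicitly to torch.distributed.init_process_group().
    #   3. torchrun / torch.distributed.launch: only LOCAL_RANK is taken from the environment and
    #      init_process_group() derives rank and world size from the init method given by args.dist_url.
    # Finally the device is chosen: cuda:<local_rank> when distributed (unless args.no_set_device_rank),
    # otherwise cuda:0, or cpu if CUDA is unavailable.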
73 | args.distributed = False 74 | args.world_size = 1 75 | args.rank = 0 # global rank 76 | args.local_rank = 0 77 | if args.horovod: 78 | assert hvd is not None, "Horovod is not installed" 79 | hvd.init() 80 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 81 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 82 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 83 | args.local_rank = local_rank 84 | args.rank = world_rank 85 | args.world_size = world_size 86 | # args.local_rank = int(hvd.local_rank()) 87 | # args.rank = hvd.rank() 88 | # args.world_size = hvd.size() 89 | args.distributed = True 90 | os.environ["LOCAL_RANK"] = str(args.local_rank) 91 | os.environ["RANK"] = str(args.rank) 92 | os.environ["WORLD_SIZE"] = str(args.world_size) 93 | print( 94 | f"Distributed training: local_rank={args.local_rank}, " 95 | f"rank={args.rank}, world_size={args.world_size}, " 96 | f"hostname={socket.gethostname()}, pid={os.getpid()}" 97 | ) 98 | elif is_using_distributed(): 99 | if "SLURM_PROCID" in os.environ: 100 | # DDP via SLURM 101 | args.local_rank, args.rank, args.world_size = world_info_from_env() 102 | # SLURM var -> torch.distributed vars in case needed 103 | os.environ["LOCAL_RANK"] = str(args.local_rank) 104 | os.environ["RANK"] = str(args.rank) 105 | os.environ["WORLD_SIZE"] = str(args.world_size) 106 | torch.distributed.init_process_group( 107 | backend=args.dist_backend, 108 | init_method=args.dist_url, 109 | world_size=args.world_size, 110 | rank=args.rank, 111 | ) 112 | elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster 113 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 114 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 115 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 116 | args.local_rank = local_rank 117 | args.rank = world_rank 118 | args.world_size = world_size 119 | torch.distributed.init_process_group( 120 | backend=args.dist_backend, 121 | init_method=args.dist_url, 122 | world_size=args.world_size, 123 | rank=args.rank, 124 | ) 125 | else: 126 | # DDP via torchrun, torch.distributed.launch 127 | args.local_rank, _, _ = world_info_from_env() 128 | torch.distributed.init_process_group( 129 | backend=args.dist_backend, init_method=args.dist_url 130 | ) 131 | args.world_size = torch.distributed.get_world_size() 132 | args.rank = torch.distributed.get_rank() 133 | args.distributed = True 134 | print( 135 | f"Distributed training: local_rank={args.local_rank}, " 136 | f"rank={args.rank}, world_size={args.world_size}, " 137 | f"hostname={socket.gethostname()}, pid={os.getpid()}" 138 | ) 139 | 140 | if torch.cuda.is_available(): 141 | if args.distributed and not args.no_set_device_rank: 142 | device = "cuda:%d" % args.local_rank 143 | else: 144 | device = "cuda:0" 145 | torch.cuda.set_device(device) 146 | else: 147 | device = "cpu" 148 | args.device = device 149 | device = torch.device(device) 150 | return device 151 | -------------------------------------------------------------------------------- /code/audioldm/clap/training/infer_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import os 4 | import torch 5 | import librosa 6 | from open_clip import create_model 7 | from training.data import get_audio_features 8 | from training.data import int16_to_float32, float32_to_int16 9 | from transformers import RobertaTokenizer 10 | 11 | tokenize = RobertaTokenizer.from_pretrained("roberta-base") 12 | 13 | 14 | def tokenizer(text): 15 | result = tokenize( 16 
| text, 17 | padding="max_length", 18 | truncation=True, 19 | max_length=77, 20 | return_tensors="pt", 21 | ) 22 | return {k: v.squeeze(0) for k, v in result.items()} 23 | 24 | 25 | PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt" 26 | WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav" 27 | 28 | 29 | def infer_text(): 30 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 31 | precision = "fp32" 32 | amodel = "HTSAT-tiny" # or 'PANN-14' 33 | tmodel = "roberta" # the best text encoder in our training 34 | enable_fusion = False # False if you do not want to use the fusion model 35 | fusion_type = "aff_2d" 36 | pretrained = PRETRAINED_PATH 37 | 38 | model, model_cfg = create_model( 39 | amodel, 40 | tmodel, 41 | pretrained, 42 | precision=precision, 43 | device=device, 44 | enable_fusion=enable_fusion, 45 | fusion_type=fusion_type, 46 | ) 47 | # load the text, can be a list (i.e. batch size) 48 | text_data = ["I love the contrastive learning", "I love the pretrain model"] 49 | # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 50 | text_data = tokenizer(text_data) 51 | 52 | text_embed = model.get_text_embedding(text_data) 53 | print(text_embed.size()) 54 | 55 | 56 | def infer_audio(): 57 | 58 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 59 | precision = "fp32" 60 | amodel = "HTSAT-tiny" # or 'PANN-14' 61 | tmodel = "roberta" # the best text encoder in our training 62 | enable_fusion = False # False if you do not want to use the fusion model 63 | fusion_type = "aff_2d" 64 | pretrained = PRETRAINED_PATH 65 | 66 | model, model_cfg = create_model( 67 | amodel, 68 | tmodel, 69 | pretrained, 70 | precision=precision, 71 | device=device, 72 | enable_fusion=enable_fusion, 73 | fusion_type=fusion_type, 74 | ) 75 | 76 | # load the waveform of the shape (T,), should resample to 48000 77 | audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000) 78 | # quantize 79 | audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) 80 | audio_waveform = torch.from_numpy(audio_waveform).float() 81 | audio_dict = {} 82 | 83 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 84 | import ipdb 85 | 86 | ipdb.set_trace() 87 | audio_dict = get_audio_features( 88 | audio_dict, 89 | audio_waveform, 90 | 480000, 91 | data_truncating="fusion", 92 | data_filling="repeatpad", 93 | audio_cfg=model_cfg["audio_cfg"], 94 | ) 95 | # can send a list to the model, to process many audio tracks in one time (i.e. 
batch size) 96 | audio_embed = model.get_audio_embedding([audio_dict]) 97 | print(audio_embed.size()) 98 | import ipdb 99 | 100 | ipdb.set_trace() 101 | 102 | 103 | if __name__ == "__main__": 104 | infer_text() 105 | infer_audio() 106 | -------------------------------------------------------------------------------- /code/audioldm/clap/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | 8 | hostname = socket.gethostname() 9 | formatter = logging.Formatter( 10 | f"%(asctime)s | {hostname} | %(levelname)s | %(message)s", 11 | datefmt="%Y-%m-%d,%H:%M:%S", 12 | ) 13 | else: 14 | formatter = logging.Formatter( 15 | "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S" 16 | ) 17 | 18 | logging.root.setLevel(level) 19 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 20 | for logger in loggers: 21 | logger.setLevel(level) 22 | 23 | stream_handler = logging.StreamHandler() 24 | stream_handler.setFormatter(formatter) 25 | logging.root.addHandler(stream_handler) 26 | 27 | if log_file: 28 | file_handler = logging.FileHandler(filename=log_file) 29 | file_handler.setFormatter(formatter) 30 | logging.root.addHandler(file_handler) 31 | -------------------------------------------------------------------------------- /code/audioldm/clap/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - warmup_length 20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | 24 | return _lr_adjuster 25 | -------------------------------------------------------------------------------- /code/audioldm/clap/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | # NOTE: This script is currently not supported for CLAP. 
2 | import logging 3 | from contextlib import suppress 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | from open_clip import tokenize 10 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template 11 | 12 | 13 | def zero_shot_classifier(model, classnames, templates, args): 14 | with torch.no_grad(): 15 | zeroshot_weights = [] 16 | for classname in tqdm(classnames): 17 | texts = [template(classname) for template in templates] # format with class 18 | texts = tokenize(texts).to(args.device) # tokenize 19 | if args.distributed and not args.horovod: 20 | class_embeddings = model.module.encode_text(texts) 21 | else: 22 | class_embeddings = model.encode_text(texts) 23 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 24 | class_embedding /= class_embedding.norm() 25 | zeroshot_weights.append(class_embedding) 26 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) 27 | return zeroshot_weights 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | pred = output.topk(max(topk), 1, True, True)[1].t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | return [ 34 | float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 35 | for k in topk 36 | ] 37 | 38 | 39 | def run(model, classifier, dataloader, args): 40 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress 41 | with torch.no_grad(): 42 | top1, top5, n = 0.0, 0.0, 0.0 43 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 44 | images = images.to(args.device) 45 | target = target.to(args.device) 46 | 47 | with autocast(): 48 | # predict 49 | if args.distributed and not args.horovod: 50 | image_features = model.module.encode_image(images) 51 | else: 52 | image_features = model.encode_image(images) 53 | image_features = F.normalize(image_features, dim=-1) 54 | logits = 100.0 * image_features @ classifier 55 | 56 | # measure accuracy 57 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 58 | top1 += acc1 59 | top5 += acc5 60 | n += images.size(0) 61 | 62 | top1 = top1 / n 63 | top5 = top5 / n 64 | return top1, top5 65 | 66 | 67 | def zero_shot_eval(model, data, epoch, args): 68 | if "imagenet-val" not in data and "imagenet-v2" not in data: 69 | return {} 70 | if args.zeroshot_frequency == 0: 71 | return {} 72 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 73 | return {} 74 | 75 | logging.info("Starting zero-shot imagenet.") 76 | 77 | logging.info("Building zero-shot classifier") 78 | classifier = zero_shot_classifier( 79 | model, imagenet_classnames, openai_imagenet_template, args 80 | ) 81 | 82 | logging.info("Using classifier") 83 | results = {} 84 | if "imagenet-val" in data: 85 | top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) 86 | results["imagenet-zeroshot-val-top1"] = top1 87 | results["imagenet-zeroshot-val-top5"] = top5 88 | if "imagenet-v2" in data: 89 | top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) 90 | results["imagenetv2-zeroshot-val-top1"] = top1 91 | results["imagenetv2-zeroshot-val-top5"] = top5 92 | 93 | logging.info("Finished zero-shot imagenet.") 94 | 95 | return results 96 | -------------------------------------------------------------------------------- /code/audioldm/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Generator 2 | 3 | 4 | class AttrDict(dict): 5 | def __init__(self, *args, 
**kwargs): 6 | super(AttrDict, self).__init__(*args, **kwargs) 7 | self.__dict__ = self 8 | -------------------------------------------------------------------------------- /code/audioldm/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2**i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | 
self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | # print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) 175 | -------------------------------------------------------------------------------- /code/audioldm/hifigan/utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import audioldm.hifigan as hifigan 8 | 9 | HIFIGAN_16K_64 = { 10 | "resblock": "1", 11 | "num_gpus": 6, 12 | "batch_size": 16, 13 | "learning_rate": 0.0002, 14 | "adam_b1": 0.8, 15 | "adam_b2": 0.99, 16 | "lr_decay": 0.999, 17 | "seed": 1234, 18 | "upsample_rates": [5, 4, 2, 2, 2], 19 | "upsample_kernel_sizes": [16, 16, 8, 4, 4], 20 | "upsample_initial_channel": 1024, 21 | "resblock_kernel_sizes": [3, 7, 11], 22 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 23 | "segment_size": 8192, 24 | "num_mels": 64, 25 | "num_freq": 1025, 26 | "n_fft": 1024, 27 | "hop_size": 160, 28 | "win_size": 1024, 29 | "sampling_rate": 16000, 30 | "fmin": 0, 31 | "fmax": 8000, 32 | "fmax_for_loss": None, 33 | "num_workers": 4, 34 | "dist_config": { 35 | "dist_backend": "nccl", 36 | "dist_url": "tcp://localhost:54321", 37 | "world_size": 1, 38 | }, 39 | } 40 | 41 | 42 | def get_available_checkpoint_keys(model, ckpt): 43 | print("==> Attemp to reload from %s" % ckpt) 44 | state_dict = torch.load(ckpt)["state_dict"] 45 | current_state_dict = model.state_dict() 46 | new_state_dict = {} 47 | for k in state_dict.keys(): 48 | if ( 49 | k in current_state_dict.keys() 50 | and current_state_dict[k].size() == state_dict[k].size() 51 | ): 52 | new_state_dict[k] = state_dict[k] 53 | else: 54 | print("==> WARNING: Skipping %s" % k) 55 | print( 56 | "%s out of %s keys are matched" 57 | % (len(new_state_dict.keys()), len(state_dict.keys())) 58 | ) 59 | return new_state_dict 60 | 61 | 62 | def get_param_num(model): 63 | num_param = sum(param.numel() for param in model.parameters()) 64 | return num_param 65 | 66 | 67 | def get_vocoder(config, device): 68 | config = hifigan.AttrDict(HIFIGAN_16K_64) 69 | vocoder = hifigan.Generator(config) 70 | vocoder.eval() 71 | vocoder.remove_weight_norm() 72 | vocoder.to(device) 73 | return vocoder 74 | 75 | 76 | def vocoder_infer(mels, vocoder, lengths=None): 77 | with torch.no_grad(): 78 | wavs = vocoder(mels).squeeze(1) 79 | 80 | wavs = (wavs.cpu().numpy() * 32768).astype("int16") 81 | 82 | if lengths is not None: 83 | wavs = wavs[:, :lengths] 84 | 85 | return wavs 86 | -------------------------------------------------------------------------------- /code/audioldm/latent_diffusion/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/code/audioldm/latent_diffusion/__init__.py -------------------------------------------------------------------------------- /code/audioldm/latent_diffusion/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_( 47 | one_minus_decay * (shadow_params[sname] - m_param[key]) 48 | ) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 
80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /code/audioldm/variational_autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/code/audioldm/variational_autoencoder/__init__.py -------------------------------------------------------------------------------- /code/audioldm/variational_autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from audioldm.latent_diffusion.ema import * 3 | from audioldm.variational_autoencoder.modules import Encoder, Decoder 4 | from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution 5 | 6 | from audioldm.hifigan.utilities import get_vocoder, vocoder_infer 7 | 8 | 9 | class AutoencoderKL(nn.Module): 10 | def __init__( 11 | self, 12 | ddconfig=None, 13 | lossconfig=None, 14 | image_key="fbank", 15 | embed_dim=None, 16 | time_shuffle=1, 17 | subband=1, 18 | ckpt_path=None, 19 | reload_from_ckpt=None, 20 | ignore_keys=[], 21 | colorize_nlabels=None, 22 | monitor=None, 23 | base_learning_rate=1e-5, 24 | ): 25 | super().__init__() 26 | 27 | self.encoder = Encoder(**ddconfig) 28 | self.decoder = Decoder(**ddconfig) 29 | 30 | self.subband = int(subband) 31 | 32 | if self.subband > 1: 33 | print("Use subband decomposition %s" % self.subband) 34 | 35 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 36 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 37 | 38 | self.vocoder = get_vocoder(None, "cpu") 39 | self.embed_dim = embed_dim 40 | 41 | if monitor is not None: 42 | self.monitor = monitor 43 | 44 | self.time_shuffle = time_shuffle 45 | self.reload_from_ckpt = reload_from_ckpt 46 | self.reloaded = False 47 | self.mean, self.std = None, None 48 | 49 | def encode(self, x): 50 | # x = self.time_shuffle_operation(x) 51 | x = self.freq_split_subband(x) 52 | h = self.encoder(x) 53 | moments = self.quant_conv(h) 54 | posterior = DiagonalGaussianDistribution(moments) 55 | return posterior 56 | 57 | def decode(self, z): 58 | z = self.post_quant_conv(z) 59 | dec = self.decoder(z) 60 | dec = self.freq_merge_subband(dec) 61 | return dec 62 | 63 | def decode_to_waveform(self, dec): 64 | dec = dec.squeeze(1).permute(0, 2, 1) 65 | wav_reconstruction = vocoder_infer(dec, self.vocoder) 66 | return wav_reconstruction 67 | 68 | def forward(self, input, sample_posterior=True): 69 | posterior = self.encode(input) 70 | if sample_posterior: 71 | z = posterior.sample() 72 | else: 73 | z = posterior.mode() 74 | 75 | if self.flag_first_run: 76 | print("Latent size: ", z.size()) 77 | self.flag_first_run = False 78 | 79 | dec = self.decode(z) 80 | 81 | return dec, posterior 82 | 83 | def freq_split_subband(self, fbank): 84 | if self.subband == 1 or self.image_key != "stft": 85 | return fbank 86 | 87 | bs, ch, tstep, fbins = fbank.size() 88 | 89 | assert fbank.size(-1) % self.subband == 0 90 | assert ch == 1 91 | 92 | return ( 93 | fbank.squeeze(1) 94 | .reshape(bs, tstep, self.subband, fbins // self.subband) 95 | .permute(0, 2, 1, 3) 96 | ) 97 | 98 | def freq_merge_subband(self, subband_fbank): 99 | if self.subband == 1 or self.image_key != "stft": 100 | return subband_fbank 101 | assert subband_fbank.size(1) == self.subband # 
Channel dimension 102 | bs, sub_ch, tstep, fbins = subband_fbank.size() 103 | return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1) 104 | -------------------------------------------------------------------------------- /code/audioldm/variational_autoencoder/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.mean( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.mean( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
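    # Per-element closed form computed below:
    #   KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) )
    #     = 0.5 * ( -1 + logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2) )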
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /code/ddm_inversion/ddim_inversion.py: -------------------------------------------------------------------------------- 1 | # Code from inbarhub/DDPM_inversoin and from google/prompt-to-prompt 2 | 3 | from typing import Union, Optional, List 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from utils import get_text_embeddings 8 | 9 | 10 | def next_step(ldm_model, model_output: Union[torch.FloatTensor, np.ndarray], 11 | timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 12 | timestep, next_timestep = min(timestep - ldm_model.model.scheduler.config.num_train_timesteps 13 | // ldm_model.model.scheduler.num_inference_steps, 999), timestep 14 | alpha_prod_t = ldm_model.model.scheduler.alphas_cumprod[timestep] if timestep >= 0 else ldm_model.model.scheduler.final_alpha_cumprod 15 | alpha_prod_t_next = ldm_model.model.scheduler.alphas_cumprod[next_timestep] 16 | beta_prod_t = 1 - alpha_prod_t 17 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 18 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 19 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 20 | return next_sample 21 | 22 | 23 | def get_noise_pred(ldm_model, latent, t, text_emb, uncond_emb, cfg_scale): 24 | noise_pred_uncond, _, _ = ldm_model.unet_forward( 25 | latent, 26 | timestep=t, 27 | encoder_hidden_states=uncond_emb.embedding_hidden_states, 28 | class_labels=uncond_emb.embedding_class_lables, 29 | encoder_attention_mask=uncond_emb.boolean_prompt_mask, 30 | ) 31 | 32 | noise_prediction_text, _, _ = ldm_model.unet_forward( 33 | latent, 34 | timestep=t, 35 | encoder_hidden_states=text_emb.embedding_hidden_states, 36 | class_labels=text_emb.embedding_class_lables, 37 | encoder_attention_mask=text_emb.boolean_prompt_mask, 38 | ) 39 | 40 | noise_pred = noise_pred_uncond.sample + cfg_scale * (noise_prediction_text.sample - noise_pred_uncond.sample) 41 | return noise_pred 42 | 43 | 44 | @torch.no_grad() 45 | def ddim_inversion(ldm_model, w0, prompts, cfg_scale, num_inference_steps, skip): 46 | 47 | _, text_emb, uncond_emb = get_text_embeddings(prompts, [""], ldm_model) 48 | 49 | latent = w0.clone().detach() 50 | for i in tqdm(range(num_inference_steps)): 51 | if num_inference_steps - i <= skip: 52 | break 53 | t = ldm_model.model.scheduler.timesteps[len(ldm_model.model.scheduler.timesteps) - i - 1] 54 | noise_pred = get_noise_pred(ldm_model, latent, t, text_emb, uncond_emb, cfg_scale) 55 | latent = next_step(ldm_model, noise_pred, t, latent) 56 | return latent 57 | 58 | 59 | @torch.no_grad() 60 | def text2image_ldm_stable(ldm_model, prompt: List[str], num_inference_steps: int = 50, 61 | guidance_scale: float = 7.5, xt: Optional[torch.FloatTensor] = None, skip: int = 0): 62 | _, text_emb, uncond_emb = get_text_embeddings(prompt, [""], ldm_model) 63 | 64 | for t in tqdm(ldm_model.model.scheduler.timesteps[skip:]): 65 | noise_pred_uncond, _, _ = ldm_model.unet_forward( 66 | xt, 67 | timestep=t, 68 | encoder_hidden_states=uncond_emb.embedding_hidden_states, 69 | class_labels=uncond_emb.embedding_class_lables, 70 | 
encoder_attention_mask=uncond_emb.boolean_prompt_mask, 71 | ) 72 | 73 | noise_prediction_text, _, _ = ldm_model.unet_forward( 74 | xt, 75 | timestep=t, 76 | encoder_hidden_states=text_emb.embedding_hidden_states, 77 | class_labels=text_emb.embedding_class_lables, 78 | encoder_attention_mask=text_emb.boolean_prompt_mask, 79 | ) 80 | 81 | noise_pred = noise_pred_uncond.sample + guidance_scale * (noise_prediction_text.sample - noise_pred_uncond.sample) 82 | xt = ldm_model.model.scheduler.step(noise_pred, t, xt, eta=0).prev_sample 83 | 84 | return xt 85 | -------------------------------------------------------------------------------- /code/images_run_sdedit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from models import load_model 3 | import os 4 | from torch import inference_mode 5 | import torch 6 | import wandb 7 | from tqdm import tqdm 8 | from utils import set_reproducability, get_text_embeddings, load_image 9 | from pc_drift import forward_directional 10 | import torchvision.transforms as T 11 | import numpy as np 12 | # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--device_num", type=int, default=0, help="GPU device number") 17 | parser.add_argument('-s', "--seed", type=int, default=None, help="Random seed") 18 | parser.add_argument("--model_id", type=str, choices=["CompVis/stable-diffusion-v1-4"], 19 | default="CompVis/stable-diffusion-v1-4", help='Image diffusion model to use') 20 | 21 | parser.add_argument("--init_im", type=str, required=True, help='Image to invert and edit') 22 | parser.add_argument("--cfg_tar", type=float, default=12, 23 | help='Classifier-free guidance strength for reverse process') 24 | parser.add_argument("--num_diffusion_steps", type=int, default=100, 25 | help="Number of diffusion steps. TANGO and AudioLDM2 are recommended to be used with 200 steps" 26 | ", while AudioLDM is recommended to be used with 100 steps") 27 | parser.add_argument("--target_prompt", type=str, nargs='+', default=[""], 28 | help="Prompt to accompany the reverse process. Should describe the wanted edited image.") 29 | parser.add_argument("--target_neg_prompt", type=str, nargs='+', default=[""], 30 | help="Negative prompt to accompany the inversion and generation process") 31 | parser.add_argument('--results_path', default='sdedit', help='path to dump results') 32 | parser.add_argument("--tstart", type=int, default=50, 33 | help="Diffusion timestep to start the reverse process from. Controls editing strength.") 34 | 35 | parser.add_argument('--wandb_name', type=str, default=None) 36 | parser.add_argument('--wandb_group', type=str, default=None) 37 | parser.add_argument('--wandb_disable', action='store_true') 38 | 39 | args = parser.parse_args() 40 | args.eta = 1.
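    # SDEdit flow implemented below:
    #   1. encode the input image with the VAE to obtain the clean latent w0;
    #   2. noise w0 to the first of the remaining timesteps (skip = num_diffusion_steps - tstart steps are
    #      dropped, so a larger --tstart means a noisier starting point and a stronger edit);
    #   3. run the remaining reverse-diffusion steps with classifier-free guidance toward --target_prompt
    #      (scale --cfg_tar) and stochastic sampling (eta = 1);
    #   4. decode the final latent, clamp/rescale it, and save and log both the edited and original images.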
41 | set_reproducability(args.seed, extreme=False) 42 | skip = args.num_diffusion_steps - args.tstart 43 | 44 | image_name_png = f's{args.seed}_skip{skip}_cfg{args.cfg_tar}' 45 | args.image_name_png = image_name_png 46 | 47 | wandb.login(key='') 48 | wandb_run = wandb.init(project="AudInv", entity='', config={}, 49 | name=args.wandb_name if args.wandb_name is not None else image_name_png, 50 | group=args.wandb_group, 51 | mode='disabled' if args.wandb_disable else 'online', 52 | settings=wandb.Settings(_disable_stats=True), 53 | job_type='sdedit_images') 54 | wandb.config.update(args) 55 | 56 | device = f"cuda:{args.device_num}" 57 | torch.cuda.set_device(args.device_num) 58 | 59 | ldm_stable = load_model(args.model_id, device, args.num_diffusion_steps) 60 | with torch.no_grad(): 61 | x0 = load_image(args.init_im, device=device) 62 | torch.cuda.empty_cache() 63 | 64 | with inference_mode(), torch.no_grad(): 65 | w0 = ldm_stable.vae_encode(x0) 66 | 67 | text_embeddings_class_labels, text_emb, uncond_emb = get_text_embeddings( 68 | args.target_prompt, args.target_neg_prompt, ldm_stable) 69 | 70 | timesteps = ldm_stable.model.scheduler.timesteps 71 | latents = [] 72 | for _ in range(len(timesteps) + 1): 73 | shape = (1, ldm_stable.model.unet.config.in_channels, w0.shape[2], 74 | w0.shape[3]) 75 | lat = torch.randn(shape, device=device, dtype=w0.dtype) 76 | 77 | # scale the initial noise by the standard deviation required by the scheduler 78 | lat = lat * ldm_stable.model.scheduler.init_noise_sigma 79 | latents.append(lat) 80 | 81 | timesteps = timesteps[skip:] 82 | latents = latents[skip + 1:] 83 | 84 | noise = torch.randn_like(w0, device=device) 85 | xt = ldm_stable.model.scheduler.add_noise(w0, noise, timesteps[:1].unsqueeze(0)) 86 | 87 | del noise, w0 88 | 89 | for it, t in tqdm(enumerate(timesteps), total=len(timesteps)): 90 | xt, _ = forward_directional( 91 | ldm_stable, xt, t, latents[it], uncond_emb, text_emb, args.cfg_tar, 92 | eta=args.eta) 93 | 94 | del latents, uncond_emb, text_emb 95 | torch.cuda.empty_cache() 96 | 97 | with inference_mode(): 98 | x0_dec = ldm_stable.vae_decode(xt) 99 | if x0_dec.dim() < 4: 100 | x0_dec = x0_dec[None, :, :, :] 101 | x0_dec = x0_dec.clamp(-1, 1) 102 | 103 | with torch.no_grad(): 104 | x0_dec = (x0_dec + 1) / 2 105 | x0 = (x0 + 1) / 2 106 | image = T.functional.to_pil_image(x0_dec[0].cpu().detach()) 107 | orig_image = T.functional.to_pil_image(x0[0].cpu().detach()) 108 | 109 | # same output 110 | save_path = os.path.join(args.results_path, 111 | args.model_id.split('/')[1], os.path.basename(args.init_im).split('.')[0], 112 | 'pmt_' + "__".join([x.replace(" ", "_") for x in args.target_prompt]) + 113 | "__neg__" + "__".join([x.replace(" ", "_") for x in args.target_neg_prompt])) 114 | os.makedirs(save_path, exist_ok=True) 115 | 116 | save_full_path_wave = os.path.join(save_path, image_name_png + ".png") 117 | save_full_path_origwave = os.path.join(save_path, "orig.png") 118 | 119 | image.save(save_full_path_wave) 120 | orig_image.save(save_full_path_origwave) 121 | logging_dict = {'image_orig': wandb.Image(np.array(orig_image), caption='orig'), 122 | 'image_gen': wandb.Image(np.array(image), caption=image_name_png)} 123 | 124 | wandb.log(logging_dict) 125 | wandb_run.finish() 126 | -------------------------------------------------------------------------------- /docs/favicon.ico: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HilaManor/AudioEditingCode/d5e2723768f548104d5bc182622d996b056ee840/docs/favicon.ico -------------------------------------------------------------------------------- /docs/functionality.js: -------------------------------------------------------------------------------- 1 | // javascript file that sets the functionality of some of the components in the index.html file 2 | // imported to the index.html file via the