├── zero-shot_sed_eval ├── eval.sh ├── README.md └── evaluate.py ├── requrirements.txt ├── clotho-moment_generetor ├── config.yaml ├── 2_convert_bg.py ├── README.md ├── 3_clip_bg.py ├── 4_clip_fg.py ├── 1_collect_data.py ├── 5_create_recipe.py └── 6_create_dataset.py ├── feature_extractor ├── README.md ├── extract_feat.sh ├── extract_text_feat.py └── extract_audio_feat.py ├── README.md └── .gitignore /zero-shot_sed_eval/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ref_jsonl=/LIGHTHOUSE/PATH/data/tut2017/tut2017_test_release.jsonl 4 | pred_jsonl=/YOUR/PREDICTED/RESULT/IN/LIGHTHOUSE/hl_val_submission.jsonl 5 | 6 | python evaluate.py $ref_jsonl $pred_jsonl 7 | -------------------------------------------------------------------------------- /requrirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.3 2 | tqdm==4.66.4 3 | PyYAML==6.0.1 4 | librosa==0.10.1 5 | torch==2.1.0 6 | torchaudio==2.1.0 7 | torchlibrosa==0.1.0 8 | torchvision==0.16.0 9 | transformers==4.34.0 10 | msclap==1.3.3 11 | scikit-learn==1.3.1 12 | -------------------------------------------------------------------------------- /zero-shot_sed_eval/README.md: -------------------------------------------------------------------------------- 1 | # Zero-shot SED evaluation 2 | These scripts are used to evaluate the zero-shot SED system. 3 | 4 | ## How to evaluate the zero-shot SED system? 5 | 1. Install the required packages: 6 | ```bash 7 | pip install -r ../requirements.txt 8 | ``` 9 | 2. Set the config in `eval.sh` and run the following command: 10 | ```bash 11 | bash eval.sh 12 | ``` 13 | -------------------------------------------------------------------------------- /clotho-moment_generetor/config.yaml: -------------------------------------------------------------------------------- 1 | root_wwt: /YOUR/WALKING_TOUR/DIRECTORY/WalkingTourVideos 2 | root_clotho: /YOUR/CLOTHO/DIRECTORY/clotho_v2.1 3 | save_dir: ./clotho-moment 4 | 5 | tmp_dir: ./.tmp 6 | 7 | # dataset config 8 | sr: 32000 9 | split_ratio: [7, 1, 2] 10 | clip_duration: 60 11 | clip_interval: 1 12 | clip_db: 5 13 | min_fg_db: 5 14 | max_fg_db: -5 15 | min_bg_db: -25 16 | max_bg_db: -15 17 | avg_interval: 30 18 | -------------------------------------------------------------------------------- /feature_extractor/README.md: -------------------------------------------------------------------------------- 1 | # Feature Extractor 2 | ## What is this? 3 | These scripts provide the procedure to extract features for lighthouse from audio files using [MS-CLAP](https://github.com/microsoft/CLAP). 4 | 5 | If you want to extract features from your own audio files, please set the path to the audio files in `extract_feat.sh`. 6 | 7 | ## How to extract features? 8 | 1. Install the required packages: 9 | ```bash 10 | pip install -r ../requirements.txt 11 | ``` 12 | 2. 
Set the config in `extract_feat.sh` and run the following command:
13 | ```bash
14 | bash extract_feat.sh
15 | ```
16 | 
--------------------------------------------------------------------------------
/feature_extractor/extract_feat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | win_sec=1
4 | hop_sec=1
5 | model_name=clap
6 | 
7 | data_dir=../clotho-moment_generetor/clotho-moment
8 | 
9 | for mode in train valid test; do
10 |     python extract_text_feat.py \
11 |         ${data_dir}/text/${mode}.jsonl \
12 |         ${data_dir}/feature \
13 |         --model_name=${model_name}
14 | done
15 | 
16 | for mode in train valid test; do
17 |     python extract_audio_feat.py \
18 |         ${data_dir}/wav/${mode} \
19 |         ${data_dir}/feature \
20 |         --win_sec ${win_sec} \
21 |         --hop_sec ${hop_sec} \
22 |         --model_name=${model_name}
23 | done
24 | 
--------------------------------------------------------------------------------
/clotho-moment_generetor/2_convert_bg.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | 
5 | import yaml
6 | 
7 | 
8 | def extract_wav(save_dir, mode, tmp_dir):
9 |     os.makedirs(f"{tmp_dir}/{mode}/bg", exist_ok=True)
10 |     with open(save_dir / "json" / f"bg_{mode}.json") as f:
11 |         dict_bg = json.load(f)
12 | 
13 |     # mp4 to wav
14 |     for name, value in dict_bg.items():
15 |         print(value)
16 |         path_bg = value["original_path"]
17 |         command = f"""
18 |         ffmpeg -i {path_bg} -vn -ac 1 -b:a 192k ./{tmp_dir}/{mode}/bg/{name}.wav
19 |         """
20 |         os.system(command)
21 | 
22 | 
23 | if __name__ == "__main__":
24 |     with open("config.yaml") as f:
25 |         config = yaml.safe_load(f)
26 | 
27 |     for mode in ["train", "valid", "test"]:
28 |         extract_wav(Path(config["save_dir"]), mode, config["tmp_dir"])
29 | 
--------------------------------------------------------------------------------
/clotho-moment_generetor/README.md:
--------------------------------------------------------------------------------
1 | # Clotho-Moment Generator
2 | ## What is this?
3 | These scripts provide the procedure to generate Clotho-Moments from Clotho and Walking Tours.
4 | - Clotho: [https://zenodo.org/records/4783391](https://zenodo.org/records/4783391)
5 | - Walking Tours: [https://shashankvkt.github.io/dora](https://shashankvkt.github.io/dora)
6 | 
7 | ## How to generate Clotho-Moments?
8 | 1. Install the required packages:
9 | ```bash
10 | pip install -r ../requirements.txt
11 | ```
12 | 2. Install ffmpeg.
13 | 3. Download the Clotho and Walking Tours datasets.
14 | 4. Set the path to the downloaded datasets and the save directory in "config.yaml".
15 | 5. Run the following commands to generate Clotho-Moments:
16 | ```bash
17 | python 1_collect_data.py
18 | python 2_convert_bg.py
19 | python 3_clip_bg.py
20 | python 4_clip_fg.py
21 | python 5_create_recipe.py
22 | python 6_create_dataset.py
23 | ```
24 | 
25 | ## Reproduce the results
26 | After executing `python 5_create_recipe.py`, you can reproduce the results by overwriting `/SAVE_DIR/json/recipe_*.json` with the recipe files provided in this repository.
27 | 
28 | Note that if you move the save directory, you need to change the paths in the recipe files (a sketch of such a rewrite is shown below).
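For reference, a minimal sketch of such a path rewrite follows. It is not part of the generator scripts; it only relies on the `bg`/`fg` path fields written by `5_create_recipe.py`, and the prefix values and the `./clotho-moment/json` location (the default `save_dir`) are placeholders to adapt to your environment.

```python
# Hypothetical helper (not part of this repository): rewrite the clip paths
# stored in the provided recipe files after the directories have been moved.
import json
from pathlib import Path

old_prefix = "/OLD/PREFIX"  # placeholder: prefix stored in the provided recipes
new_prefix = "/NEW/PREFIX"  # placeholder: prefix of your current directories

for recipe_path in Path("./clotho-moment/json").glob("recipe_*.json"):
    with open(recipe_path) as f:
        recipes = json.load(f)
    for recipe in recipes:
        # Each recipe entry holds one background clip path and a list of foreground events.
        recipe["bg"]["path"] = recipe["bg"]["path"].replace(old_prefix, new_prefix)
        for fg in recipe["fg"]:
            fg["path"] = fg["path"].replace(old_prefix, new_prefix)
    with open(recipe_path, "w") as f:
        json.dump(recipes, f, indent=4)
```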
29 | -------------------------------------------------------------------------------- /clotho-moment_generetor/3_clip_bg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import soundfile as sf 7 | import yaml 8 | from tqdm import tqdm 9 | 10 | 11 | def preprocess(save_dir, mode, tmp_dir, clip_duration, clip_interval): 12 | dict_clip_bg = {} 13 | path_wav = tmp_dir / mode / "bg" 14 | for fname in path_wav.glob("*.wav"): 15 | os.makedirs(path_wav / fname.stem, exist_ok=True) 16 | dict_clip_bg[fname.stem] = {"original_path": str(fname), "clips": []} 17 | 18 | s, sr = sf.read(fname) 19 | for start_sample in tqdm(np.arange(0, len(s), clip_interval * sr)): 20 | end_sample = start_sample + clip_duration * sr 21 | if end_sample > len(s): 22 | break 23 | start_sample = int(start_sample) 24 | end_sample = int(end_sample) 25 | 26 | _s = s[start_sample:end_sample] 27 | start_sec, end_sec = round(start_sample / sr, 1), round(end_sample / sr, 1) 28 | 29 | save_path = str(path_wav / fname.stem / f"{start_sec}_{end_sec}.wav") 30 | 31 | sf.write(save_path, _s, sr) 32 | dict_clip_bg[fname.stem]["clips"].append(save_path) 33 | 34 | with open(save_dir / "json" / f"bg_{mode}.json", "w") as f: 35 | json.dump(dict_clip_bg, f, indent=2) 36 | 37 | 38 | if __name__ == "__main__": 39 | with open("config.yaml") as f: 40 | config = yaml.safe_load(f) 41 | 42 | save_dir = Path(config["save_dir"]) 43 | 44 | for mode in ["train", "valid", "test"]: 45 | preprocess( 46 | Path(config["save_dir"]), 47 | mode, 48 | Path(config["tmp_dir"]), 49 | config["clip_duration"], 50 | config["clip_interval"], 51 | ) 52 | -------------------------------------------------------------------------------- /clotho-moment_generetor/4_clip_fg.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import soundfile as sf 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | 9 | def preprocess(save_dir, mode, tmp_dir, clip_db): 10 | with open(save_dir / "json" / f"fg_{mode}.json") as f: 11 | dict_fg = json.load(f) 12 | 13 | dir_wav = Path(tmp_dir) / mode / "fg" 14 | dir_wav.mkdir(parents=True, exist_ok=True) 15 | 16 | for key, value in tqdm(dict_fg.items()): 17 | s, sr = sf.read(value["original_path"]) 18 | 19 | global_power = (s**2).mean() 20 | threshold = global_power * 10 ** (-clip_db / 10) 21 | 22 | # calc onset 23 | onset = 0 24 | for sec in range(len(s) // sr): 25 | local_power = (s[sec * sr : (sec + 1) * sr] ** 2).mean() 26 | if local_power > threshold: 27 | onset = sec 28 | break 29 | 30 | # calc offset 31 | offset = len(s) // sr 32 | for sec in range(len(s) // sr, 0, -1): 33 | local_power = (s[(sec - 1) * sr : sec * sr] ** 2).mean() 34 | if local_power > threshold: 35 | offset = sec 36 | break 37 | 38 | s = s[onset * sr : offset * sr] 39 | sf.write(dir_wav / f"{key}.wav", s, sr) 40 | 41 | value.update({"duration": len(s) / sr, "clip": str(dir_wav / f"{key}.wav")}) 42 | dict_fg[key] = value 43 | 44 | with open(save_dir / "json" / f"fg_{mode}.json", "w") as f: 45 | json.dump(dict_fg, f, indent=4) 46 | 47 | 48 | if __name__ == "__main__": 49 | with open("config.yaml") as f: 50 | config = yaml.safe_load(f) 51 | 52 | save_dir = Path(config["save_dir"]) 53 | 54 | for mode in ["train", "valid", "test"]: 55 | preprocess( 56 | Path(config["save_dir"]), 57 | mode, 58 | Path(config["tmp_dir"]), 59 | config["clip_db"], 60 | ) 61 | 
--------------------------------------------------------------------------------
/clotho-moment_generetor/1_collect_data.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 | from pathlib import Path
5 | 
6 | import yaml
7 | 
8 | 
9 | def collect_fg(save_dir, root_clotho, mode, clotho_mode):
10 |     dict_fg = {}
11 |     # load captions
12 |     with open(root_clotho / f"clotho_captions_{clotho_mode}.csv") as f:
13 |         reader = csv.reader(f)
14 |         next(reader)
15 |         for row in reader:
16 |             key = row[0][:-4]
17 |             dict_fg[key] = {"captions": [cap for n, cap in enumerate(row[1:])]}
18 | 
19 |     # load wavs
20 |     for key in dict_fg.keys():
21 |         original_path = root_clotho / clotho_mode / f"{key}.wav"
22 |         dict_fg[key].update({"original_path": str(original_path)})
23 | 
24 |     print(f"Number of foreground clips in {mode}: {len(dict_fg)}")
25 | 
26 |     # save data
27 |     with open(save_dir / "json" / f"fg_{mode}.json", "w") as f:
28 |         json.dump(dict_fg, f, indent=4)
29 | 
30 | 
31 | def collect_bg(save_dir, root_wwt, split_ratio):
32 |     split_ratio = [r / sum(split_ratio) for r in split_ratio]
33 |     list_bg = list(root_wwt.glob("*.mp4"))
34 |     num_tr = round(len(list_bg) * split_ratio[0])
35 |     num_vl = round(len(list_bg) * split_ratio[1])
36 | 
37 |     split_list_bg = [
38 |         list_bg[:num_tr],
39 |         list_bg[num_tr : num_tr + num_vl],
40 |         list_bg[num_tr + num_vl :],
41 |     ]
42 | 
43 |     for mode, _list_bg in zip(["train", "valid", "test"], split_list_bg):
44 |         dict_bg = {}
45 |         for path_bg in _list_bg:
46 |             dict_bg[path_bg.stem] = {"original_path": str(path_bg)}
47 | 
48 |         print(f"Number of background clips in {mode}: {len(dict_bg)}")
49 | 
50 |         with open(save_dir / "json" / f"bg_{mode}.json", "w") as f:
51 |             json.dump(dict_bg, f, indent=4)
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     with open("config.yaml") as f:
56 |         config = yaml.safe_load(f)
57 | 
58 |     save_dir = Path(config["save_dir"])
59 | 
60 |     os.makedirs(save_dir / "json", exist_ok=True)
61 | 
62 |     for mode, clotho_mode in zip(
63 |         ["train", "valid", "test"], ["development", "validation", "evaluation"]
64 |     ):
65 |         collect_fg(save_dir, Path(config["root_clotho"]), mode, clotho_mode)
66 | 
67 |     collect_bg(save_dir, Path(config["root_wwt"]), config["split_ratio"])
68 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lighthouse-Wrapper-for-Audio-Moment-Retrieval
2 | 
3 | ## What is this?
4 | This repository provides the procedure to conduct experiments with [Lighthouse](https://github.com/line/lighthouse) for the paper ["Language-based Audio Moment Retrieval" (Munakata et al., ICASSP 2025)](https://arxiv.org/abs/2409.15672).
5 | In addition, it supports the following functionalities:
6 | - Generation of Clotho-Moments from Clotho and Walking Tours
7 | - Extraction of CLAP Features
8 | - Evaluation of Zero-shot Sound Event Detection
9 | The raw audio datasets are provided at the following links:
10 | - HuggingFace
11 |   - [Clotho-Moment](https://huggingface.co/datasets/lighthouse-emnlp2024/Clotho-Moment)
12 | - Zenodo
13 |   - [Clotho-Moment/UnAV100-subset/TUT Sound Events 2017](https://zenodo.org/records/13836117)
14 | The captions are available in Lighthouse:
15 | - [Clotho-Moment](https://github.com/line/lighthouse/tree/main/data/clotho_moment)
16 | - [UnAV100-subset](https://github.com/line/lighthouse/tree/main/data/unav100-subset)
17 | - [TUT Sound Events 2017](https://github.com/line/lighthouse/tree/main/data/tut2017)
18 | 
19 | 
20 | ## How to train/evaluate AMR models with Lighthouse?
21 | 1. Install [Lighthouse](https://github.com/line/lighthouse).
22 | 
23 | 2. Download the extracted CLAP features of Clotho-Moment/UnAV100-subset/TUT Sound Events 2017 from [here](https://zenodo.org/records/13806234).
24 |    - You can also download the wav files from [here](https://zenodo.org/records/13836117).
25 | 
26 | 3. Place the downloaded features under "(LIGHTHOUSE_PATH)/features".
27 |    - For example, if you downloaded the Clotho-Moment features, place them at "(LIGHTHOUSE_PATH)/features/clotho-moment".
28 | 
29 | 4. Run the following command to train the AMR model:
30 | ```bash
31 | python training/train.py --model qd_detr --dataset clotho-moment --feature clap
32 | ```
33 | 
34 | 5. Run the following command to evaluate the AMR model:
35 | ```bash
36 | model=qd_detr
37 | dataset=unav100-subset
38 | feature=clap
39 | model_path={lighthouse_dir}/results/qd_detr/clotho-moment/clap/best.ckpt
40 | eval_split_name=val
41 | eval_path=data/unav100-subset/unav100-subset_test_release.jsonl
42 | 
43 | python training/evaluate.py \
44 |     --model $model \
45 |     --dataset $dataset \
46 |     --feature $feature \
47 |     --model_path $model_path \
48 |     --eval_split_name $eval_split_name \
49 |     --eval_path $eval_path
50 | ```
51 | 
52 | ## Generation of Clotho-Moments
53 | `./clotho-moment_generetor` generates Clotho-Moments from Clotho and Walking Tours.
54 | Please read the README.md in the directory for more details.
55 | 
56 | ## Feature Extraction using CLAP
57 | `./feature_extractor` extracts CLAP features for Lighthouse.
58 | Please read the README.md in the directory for more details.
59 | 
60 | ## Evaluation of Zero-shot Sound Event Detection
61 | `./zero-shot_sed_eval` evaluates the zero-shot SED system.
62 | Please read the README.md in the directory for more details; the jsonl fields expected by the evaluation script are sketched below.
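For orientation, `zero-shot_sed_eval/evaluate.py` compares two Lighthouse-style jsonl files: it reads `query`, `vid`, `duration`, and `relevant_windows` from the reference file, and `query`, `vid`, and `pred_relevant_windows` (each window carrying a confidence score) from the prediction file; windows whose score is below `--threshold` (default 0.5) are ignored. The sketch below only illustrates these fields; the values are made up for illustration, not taken from the released datasets.

```python
# Illustration only: made-up entries showing the fields that
# zero-shot_sed_eval/evaluate.py reads from the reference and prediction jsonl files.
import json

ref_entry = {
    "qid": "00001",
    "query": "Birds are singing",                   # used as the event label
    "vid": "video_0.0_60.0",                        # key shared between reference and prediction
    "duration": 60,                                 # clip length in seconds
    "relevant_windows": [[12.0, 25.0]],             # ground-truth [start_sec, end_sec] segments
}
pred_entry = {
    "qid": "00001",
    "query": "Birds are singing",
    "vid": "video_0.0_60.0",
    "pred_relevant_windows": [[11.5, 26.0, 0.92]],  # predicted [start_sec, end_sec, score]
}

# Each line of the jsonl files is one such JSON object.
print(json.dumps(ref_entry))
print(json.dumps(pred_entry))
```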
63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /feature_extractor/extract_text_feat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from msclap import CLAP 8 | from tqdm import tqdm 9 | 10 | 11 | def save_text(data_path, save_dir, extractor, model_name): 12 | if not data_path.exists(): 13 | print(f"{data_path} does not exist.") 14 | return 15 | 16 | with open(data_path) as f: 17 | data = [json.loads(line) for line in f] 18 | 19 | # Setup the directory to save the text features" 20 | dir_save_feats = save_dir / f"{model_name}_text" 21 | dir_save_feats.mkdir(exist_ok=True, parents=True) 22 | 23 | print("save text data...") 24 | for _d in tqdm(data): 25 | feat, proj_feat = extractor.extract_text_feats(_d["query"]) 26 | 27 | path_feats = dir_save_feats / f"qid{_d['qid']}.npz" 28 | np.savez(path_feats, last_hidden_state=feat) 29 | 30 | 31 | class ClapExtractor: 32 | def __init__(self): 33 | # if gpu is available, use it 34 | self.use_cuda = torch.cuda.is_available() 35 | self.wrapper = CLAP(use_cuda=self.use_cuda, version="2023") 36 | self.text_enc = self.wrapper.clap.caption_encoder 37 | if self.use_cuda: 38 | print("Inference on GPU") 39 | self.text_enc = self.text_enc.cuda() 40 | else: 41 | print("Inference on CPU") 42 | 43 | @torch.no_grad() 44 | def extract_text_feats(self, text): 45 | x = self.wrapper.preprocess_text([text]) 46 | mask = x["attention_mask"] 47 | len_output = torch.sum(mask, dim=-1, keepdims=True) 48 | out = self.text_enc.base(**x) 49 | hidden_states = out[0] 50 | pooled_output = out[1] 51 | 52 | if "clip" in self.text_enc.text_model: 53 | out = self.clip_text_projection(pooled_output) # get CLS token output 54 | elif "gpt" in self.text_enc.text_model: 55 | batch_size = x["input_ids"].shape[0] 56 | sequence_lengths = ( 57 | torch.ne(x["input_ids"], 0).sum(-1) - 1 58 | ) # tensor([13, 14, 18, 17]) 59 | out = hidden_states[ 60 | torch.arange(batch_size, device=hidden_states.device), 61 | sequence_lengths, 62 | ] # [batch_size, 768] = [4, 768] 63 | else: 64 | out = hidden_states[:, 0, :] # get CLS token output 65 | 66 | projected_feat = self.text_enc.projection(out) 67 | 68 | feat = hidden_states[0, :len_output].cpu().numpy() 69 | proj_feat = projected_feat.cpu().numpy() 70 | 71 | return feat, proj_feat 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("data_dir", type=Path, default=None, help="text data path") 77 | parser.add_argument("save_dir", type=Path, help="directory to save the features") 78 | parser.add_argument("--model_name", default="clap", help="model name") 79 | args = parser.parse_args() 80 | 81 | if args.model_name == "clap": 82 | extractor = ClapExtractor() 83 | else: 84 | raise ValueError(f"Invalid model name: {args.model_name}") 85 | 86 | save_text( 87 | args.data_dir, 88 | args.save_dir, 89 | extractor, 90 | args.model_name, 91 | ) 92 | -------------------------------------------------------------------------------- /clotho-moment_generetor/5_create_recipe.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import soundfile as sf 7 | import yaml 8 | from tqdm import tqdm 9 | 10 | 11 | class Loader: 12 | def __init__( 13 | self, 14 | min_fg_db, 15 | max_fg_db, 16 | 
min_bg_db, 17 | max_bg_db, 18 | avg_interval, 19 | ): 20 | self.min_fg_db = min_fg_db 21 | self.max_fg_db = max_fg_db 22 | self.min_bg_db = min_bg_db 23 | self.max_bg_db = max_bg_db 24 | self.avg_interval = avg_interval 25 | self.qid = 0 26 | 27 | def create_recipe(self, path_fg, path_bg, save_dir, mode): 28 | with open(path_fg) as f: 29 | self.dict_fg = json.load(f) 30 | with open(path_bg) as f: 31 | self.dict_bg = json.load(f) 32 | 33 | list_recipe = [] 34 | for name, path_bg in self.load_bg(): 35 | data_name = f"{Path(path_bg).parent.stem}_{Path(path_bg).stem}" 36 | recipe = {"name": data_name, "bg": {"path": path_bg}} 37 | 38 | info = sf.info(path_bg) 39 | duration_bg = info.frames / info.samplerate 40 | 41 | recipe["fg"] = self.fg_sample(duration_bg) 42 | recipe["bg"]["dB"] = ( 43 | random.random() * (self.max_bg_db - self.min_bg_db) + self.min_bg_db 44 | ) 45 | 46 | list_recipe.append(recipe) 47 | 48 | with open(save_dir / "json" / f"recipe_{mode}.json", "w") as f: 49 | json.dump(list_recipe, f, indent=4) 50 | 51 | def load_bg(self): 52 | for name, value in self.dict_bg.items(): 53 | for clip in tqdm(value["clips"]): 54 | yield name, clip 55 | 56 | def fg_sample(self, duration_bg): 57 | keys = list(self.dict_fg.keys()) 58 | random.shuffle(keys) 59 | list_fg = [] 60 | current_time = 0 61 | 62 | for sample_key in keys: 63 | current_time += np.random.exponential(self.avg_interval) 64 | dict_status, duration_fg = self.get_info(sample_key, current_time) 65 | current_time += duration_fg 66 | 67 | if current_time > duration_bg: 68 | break 69 | else: 70 | list_fg.append(dict_status) 71 | self.qid += 1 72 | 73 | list_fg.sort(key=lambda x: x["start_time"]) 74 | 75 | return list_fg 76 | 77 | def get_info(self, sample_key, start_time): 78 | dict_status = {} 79 | 80 | path_fg = self.dict_fg[sample_key]["clip"] 81 | cap = random.sample(self.dict_fg[sample_key]["captions"], 1)[0] 82 | db = random.random() * (self.max_fg_db - self.min_fg_db) + self.min_fg_db 83 | 84 | dict_status["qid"] = self.qid 85 | dict_status["path"] = path_fg 86 | dict_status["caption"] = cap 87 | dict_status["dB"] = db 88 | dict_status["duration"] = self.dict_fg[sample_key]["duration"] 89 | dict_status["start_time"] = start_time 90 | 91 | return dict_status, dict_status["duration"] 92 | 93 | 94 | if __name__ == "__main__": 95 | with open("config.yaml") as f: 96 | config = yaml.safe_load(f) 97 | 98 | save_dir = Path(config["save_dir"]) 99 | loader = Loader( 100 | config["min_fg_db"], 101 | config["max_fg_db"], 102 | config["min_bg_db"], 103 | config["max_bg_db"], 104 | config["avg_interval"], 105 | ) 106 | 107 | for mode in ["train", "valid", "test"]: 108 | fg_json = save_dir / "json" / f"fg_{mode}.json" 109 | bg_json = save_dir / "json" / f"bg_{mode}.json" 110 | loader.create_recipe(fg_json, bg_json, save_dir, mode) 111 | -------------------------------------------------------------------------------- /clotho-moment_generetor/6_create_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from functools import partial 4 | from multiprocessing import Pool 5 | from pathlib import Path 6 | 7 | import librosa 8 | import numpy as np 9 | import soundfile as sf 10 | import yaml 11 | from tqdm import tqdm 12 | 13 | 14 | def _gen_audio(data, wav_dir, duration, sr, bg_db=None): 15 | path_bg = data["bg"]["path"] 16 | if bg_db is None: 17 | bg_db = data["bg"]["dB"] 18 | data_name = f"{Path(path_bg).parent.stem}_{Path(path_bg).stem}" 19 | save_path = wav_dir / 
f"{data_name}.wav" 20 | if os.path.exists(save_path): 21 | return 22 | s_bg, _ = librosa.load(path_bg, sr=sr) 23 | 24 | power_bg = np.mean(s_bg**2) 25 | s_bg = 10 ** (bg_db / 20) * s_bg 26 | 27 | for fg in data["fg"]: 28 | # add fg to bg 29 | s_fg, _ = librosa.load(fg["path"], sr=sr) 30 | s_fg = s_fg / np.max(np.abs(s_fg)) 31 | 32 | weight = 10 ** (fg["dB"] / 20) * np.sqrt(power_bg) 33 | s_fg = weight * s_fg 34 | 35 | start_time = fg["start_time"] 36 | start_sample = int(start_time * sr) 37 | 38 | s_bg[start_sample : start_sample + len(s_fg)] += s_fg 39 | 40 | # save wav 41 | s_bg = 1 / np.max(np.abs(s_bg)) * s_bg 42 | sf.write(str(save_path), s_bg, sr) 43 | 44 | 45 | def _gen_text(data, wav_dir, duration, bg_db=None): 46 | path_bg = data["bg"]["path"] 47 | data_name = f"{Path(path_bg).parent.stem}_{Path(path_bg).stem}" 48 | list_info = [] 49 | 50 | for fg in data["fg"]: 51 | qid = fg["qid"] 52 | start = fg["start_time"] 53 | end = start + fg["duration"] 54 | caption = fg["caption"] 55 | 56 | _info = { 57 | "qid": f"{qid:05d}", 58 | "query": caption, 59 | "duration": duration, 60 | "vid": data_name, 61 | "relevant_windows": [[float(f"{start:.1f}"), float(f"{end:.1f}")]], 62 | "fg_dB": fg["dB"], 63 | } 64 | if bg_db is not None: 65 | _info["bg_dB"] = bg_db 66 | list_info.append(_info) 67 | 68 | return list_info 69 | 70 | 71 | def generate_data(save_dir, mode, duration, sr, bg_db): 72 | path_recipe = Path(config["save_dir"]) / "json" / f"recipe_{mode}.json" 73 | wav_dir = Path(config["save_dir"]) / "wav" / f"{sr}hz" / mode 74 | if bg_db is not None: 75 | wav_dir = wav_dir / f"{bg_db}dB" 76 | text_dir = Path(config["save_dir"]) / "text" 77 | os.makedirs(wav_dir, exist_ok=True) 78 | os.makedirs(text_dir, exist_ok=True) 79 | with open(path_recipe) as f: 80 | recipe = json.load(f) 81 | 82 | # Generate text data 83 | print(f"Generate {mode} text data") 84 | if os.path.exists(text_dir / f"{mode}.jsonl"): 85 | print(f"{mode}.jsonl already exists") 86 | else: 87 | for data in tqdm(recipe): 88 | list_info = _gen_text(data, wav_dir, duration, bg_db) 89 | for _info in list_info: 90 | # Save the query as jsonl 91 | with open(text_dir / f"{mode}.jsonl", "a") as f: 92 | json.dump(_info, f) 93 | f.write("\n") 94 | 95 | # Generate audio data 96 | print(f"Generate {mode} audio data") 97 | map_fn = partial(_gen_audio, wav_dir=wav_dir, duration=duration, bg_db=bg_db, sr=sr) 98 | with Pool(processes=16) as p: 99 | list(tqdm(p.imap_unordered(map_fn, recipe), total=len(recipe))) 100 | 101 | 102 | if __name__ == "__main__": 103 | with open("config.yaml") as f: 104 | config = yaml.safe_load(f) 105 | 106 | for mode in ["train", "valid", "test"]: 107 | generate_data( 108 | config["save_dir"], mode, config["clip_duration"], config["sr"], bg_db=None 109 | ) 110 | -------------------------------------------------------------------------------- /zero-shot_sed_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import numpy as np 5 | from sklearn.metrics import precision_recall_fscore_support 6 | 7 | 8 | def load_ref_as_array(ref_jsonl, label_resolution): 9 | with open(ref_jsonl, "r") as f: 10 | lines = f.readlines() 11 | data = [json.loads(line) for line in lines] 12 | 13 | # get label set 14 | labels = [] 15 | for item in data: 16 | query = item["query"] 17 | if query not in labels: 18 | labels.append(query) 19 | 20 | er_dict = {} 21 | for num, item in enumerate(data): 22 | # load data 23 | vid = item["vid"] 24 | query = 
item["query"] 25 | duration = item["duration"] 26 | relevant_windows = item["relevant_windows"] 27 | 28 | label_idx = labels.index(query) 29 | total_frames = int(np.ceil(duration * label_resolution)) 30 | time_axis = np.arange(total_frames) 31 | 32 | # load nparray 33 | er = er_dict.get(vid, np.zeros((len(labels), total_frames))) 34 | 35 | # load start and end time as nparray 36 | for t in relevant_windows: 37 | ts, te = t 38 | active_segment = (time_axis >= ts) * (time_axis <= te) 39 | er[label_idx, :] += active_segment 40 | 41 | er_dict[vid] = er 42 | 43 | # binarize 44 | for k, v in er_dict.items(): 45 | er_dict[k] = er_dict[k] > 0 46 | 47 | return er_dict, labels 48 | 49 | 50 | def load_pred_as_array(pred_jsonl, ref_er_dict, labels, threshold, label_resolution): 51 | with open(pred_jsonl, "r") as f: 52 | lines = f.readlines() 53 | data = [json.loads(line) for line in lines] 54 | 55 | er_dict = {} 56 | for num, item in enumerate(data): 57 | # load data 58 | vid = item["vid"] 59 | query = item["query"] 60 | relevant_windows = item["pred_relevant_windows"] 61 | 62 | # load nparray 63 | duration = ref_er_dict[vid].shape[1] 64 | er = er_dict.get(vid, np.zeros((len(labels), duration))) 65 | 66 | label_idx = labels.index(query) 67 | total_frames = int(np.ceil(duration / label_resolution)) 68 | time_axis = np.arange(total_frames * label_resolution) 69 | 70 | # load start and end time as nparray 71 | for t in relevant_windows: 72 | ts, te, score = t 73 | active_segment = (time_axis >= ts) * (time_axis <= te) 74 | if score > threshold: 75 | er[label_idx, :] += active_segment 76 | 77 | er_dict[vid] = er 78 | 79 | # binarize 80 | for k, v in er_dict.items(): 81 | er_dict[k] = er_dict[k] > 0 82 | 83 | return er_dict 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("ref_jsonl", type=str, help="target jsonl file") 89 | parser.add_argument("pred_jsonl", type=str, help="prediction jsonl file") 90 | parser.add_argument("--threshold", type=float, default=0.5) 91 | parser.add_argument("--label_resolution", type=float, default=1) 92 | args = parser.parse_args() 93 | 94 | ref_er_dict, labels = load_ref_as_array(args.ref_jsonl, args.label_resolution) 95 | pred_er_dict = load_pred_as_array( 96 | args.pred_jsonl, 97 | ref_er_dict, 98 | labels, 99 | args.threshold, 100 | args.label_resolution, 101 | ) 102 | 103 | # evaluate 104 | ref_frames, pred_frames = [], [] 105 | for vid in ref_er_dict.keys(): 106 | ref_frames.append(ref_er_dict[vid]) 107 | pred_frames.append(pred_er_dict[vid]) 108 | 109 | ref_frames = np.hstack(ref_frames) 110 | pred_frames = np.hstack(pred_frames) 111 | 112 | precision, recall, fscore, _ = precision_recall_fscore_support( 113 | ref_frames, pred_frames, average="micro", zero_division=0 114 | ) 115 | print("Micro Average") 116 | print("Precision, Recall, F1") 117 | print(f"{precision * 100:.2f}, {recall * 100:.2f}, {fscore * 100:.2f}") 118 | -------------------------------------------------------------------------------- /feature_extractor/extract_audio_feat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | from msclap import CLAP 7 | from torch.nn import functional as F 8 | from tqdm import tqdm 9 | 10 | 11 | def dump_audio(data_dir, save_dir, extractor, model_name): 12 | list_data = sorted(data_dir.glob("*.wav")) 13 | print(list_data) 14 | if len(list_data) == 0: 15 | print(f"No audio files found in 
{data_dir}") 16 | return 17 | 18 | # Setup the directory to save the audio features" 19 | dir_save_feats = save_dir / model_name 20 | dir_save_feats.mkdir(exist_ok=True, parents=True) 21 | 22 | # Loop through the audio files and extract the audio featseddings 23 | print("dump audio data...") 24 | for path_wav in tqdm(list_data): 25 | path_feats = dir_save_feats / f"{path_wav.stem}.npz" 26 | if path_feats.exists(): 27 | continue 28 | feat, proj_feat = extractor.extract_audio_feats(str(path_wav)) 29 | 30 | np.savez(path_feats, features=feat) 31 | 32 | 33 | class ClapExtractor: 34 | def __init__(self, win_sec, hop_sec): 35 | # if gpu is available, use it 36 | self.use_cuda = torch.cuda.is_available() 37 | self.wrapper = CLAP(use_cuda=self.use_cuda, version="2023") 38 | if self.use_cuda: 39 | self.wrapper.clap.caption_encoder = self.wrapper.clap.caption_encoder.cuda() 40 | print("Inference on GPU") 41 | else: 42 | print("Inference on CPU") 43 | 44 | self.sl_win = SlidingWindos(win_sec, hop_sec) 45 | 46 | @torch.no_grad() 47 | def extract_audio_feats(self, path_wav): 48 | audio, sr = self.wrapper.read_audio(path_wav, resample=True) 49 | frames = self.sl_win(audio[0], sr) 50 | frames = frames.cuda() if self.use_cuda else frames 51 | 52 | feats = self.wrapper.clap.audio_encoder.base(frames)["embedding"] 53 | proj_feats = self.wrapper.clap.audio_encoder.projection(feats) 54 | 55 | feats = feats.cpu().numpy() 56 | proj_feats = proj_feats.cpu().numpy() 57 | 58 | return feats, proj_feats 59 | 60 | 61 | class SlidingWindos: 62 | def __init__(self, win_sec, hop_sec): 63 | self.win_sec = win_sec 64 | self.hop_sec = hop_sec 65 | 66 | def __call__(self, audio, sr): 67 | """ 68 | Perform sliding window processing on a 1D tensor with center-based cutting. 69 | 70 | Parameters: 71 | audio (torch.tensor): 1D tensor. 72 | win_sec (float): Length of each window. 73 | hop_sec (float): Number of elements to move the window at each step. 74 | sr (int): Sampling rate. 75 | 76 | Returns: 77 | torch.tensor: 2D tensor with shape (num_windows, win_length). 
78 | """ 79 | if audio.ndim != 1: 80 | raise ValueError("Input audio must be 1D tensor.") 81 | 82 | win_length = int(self.win_sec * sr) 83 | hop_length = int(self.hop_sec * sr) 84 | 85 | half_win = win_length // 2 86 | padded_audio = F.pad(audio, (half_win, half_win), mode="constant", value=0) 87 | windows = padded_audio.unfold(0, win_length, hop_length) 88 | 89 | return torch.tensor(windows) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("data_dir", type=Path, help="audio data directory") 95 | parser.add_argument("save_dir", type=Path, help="directory to save the features") 96 | parser.add_argument("--win_sec", type=float, default=1.0, help="window length") 97 | parser.add_argument("--hop_sec", type=float, default=1.0, help="hop length") 98 | parser.add_argument("--model_name", default="clap", help="model name") 99 | args = parser.parse_args() 100 | 101 | if args.model_name == "clap": 102 | extractor = ClapExtractor(args.win_sec, args.hop_sec) 103 | else: 104 | raise ValueError(f"Invalid model name: {args.model_name}") 105 | 106 | dump_audio( 107 | args.data_dir, 108 | args.save_dir, 109 | extractor, 110 | args.model_name, 111 | ) 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 
99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | # Temporary files 179 | clotho-moment_generetor/.tmp 180 | --------------------------------------------------------------------------------