├── LICENSE ├── README.md ├── baseline ├── avse1 │ ├── README.md │ ├── config.py │ ├── dataset.py │ ├── model.py │ ├── requirements.txt │ ├── test.py │ ├── train.py │ └── utils │ │ ├── generic.py │ │ ├── nn.py │ │ ├── resnet.py │ │ └── tcn.py ├── avse2 │ ├── LICENSE │ ├── README.md │ ├── config.py │ ├── dataset.py │ ├── model.py │ ├── requirements.txt │ ├── test.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ └── dnn.py ├── avse3 │ ├── README.md │ ├── config.py │ ├── dataset.py │ ├── loss.py │ ├── model.py │ ├── model_utils │ │ ├── __init__.py │ │ ├── generic.py │ │ ├── nn.py │ │ └── visual.py │ ├── test.py │ ├── train.py │ └── utils.py └── avse4 │ ├── README.md │ ├── conf │ ├── eval.yaml │ └── train.yaml │ ├── dataset.py │ ├── model.py │ ├── test.py │ ├── train.py │ └── utils.py ├── data_preparation ├── avse1 │ ├── README.md │ ├── build_scenes.py │ ├── create_speech_maskers.py │ ├── data_config.yaml │ ├── prepare_avse1_data.py │ ├── requirements.txt │ ├── scene_builder_avse1.py │ ├── scene_renderer_avse1.py │ ├── setup_avse1_data.sh │ ├── speech_weight.mat │ └── utils.py └── avse4 │ ├── build_scenes.py │ ├── clarity │ └── data │ │ ├── HOA_tools_cec2.py │ │ ├── params │ │ └── speech_weight.mat │ │ ├── scene_builder_cec2.py │ │ ├── scene_renderer_cec2.py │ │ └── utils.py │ ├── config.yaml │ ├── create_speech_maskers.py │ ├── hydra │ └── launcher │ │ ├── cec2_submitit_local.yaml │ │ └── cec2_submitit_slurm.yaml │ ├── render_scenes.py │ └── setup_avsec4_data.sh ├── evaluation ├── avse1 │ ├── config.yaml │ └── objective_evaluation.py └── avse4 │ ├── config.yaml │ ├── mbstoi │ ├── __init__.py │ ├── mbstoi.py │ ├── mbstoi_utils.py │ └── parameters.yaml │ └── objective_evaluation.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Audio-Visual Speech Enhancement Challenge (AVSE) 2 | 3 | Human performance in everyday noisy situations is known to be dependent upon both aural and visual senses that are contextually combined by the brain’s multi-level integration strategies. The multimodal nature of speech is well established, with listeners known to unconsciously lip read to improve the intelligibility of speech in a real noisy environment. Studies in neuroscience have shown that the visual aspect of speech has a potentially strong impact on the ability of humans to focus their auditory attention on a particular stimulus. 4 | 5 | Over the last few decades, there have been major advances in machine learning applied to speech technology made possible by Machine Learning related Challenges including CHiME, REVERB, Blizzard, Clarity and Hurricane. However, the aforementioned challenges are based on single and multi-channel audio-only processing and have not exploited the multimodal nature of speech. The aim of this first audio visual (AV) speech enhancement challenge is to bring together the wider computer vision, hearing and speech research communities to explore novel approaches to multimodal speech-in-noise processing. 6 | 7 | In this repository, you will find code to support the AVSE Challenge, including the baseline and scripts for preparing the necessary data. 8 | 9 | More details can be found on the challenge website: 10 | https://challenge.cogmhear.org 11 | 12 | ## Announcements 13 | 14 | Any announcements about the challenge will be made in our mailing list (avse-challenge@mlist.is.ed.ac.uk). 15 | See [here](https://challenge.cogmhear.org/#/docs?id=announcements) on how to subscribe to it. 
16 | 17 | ## Installation 18 | *Instructions to build data from previous AVSEC{1,2,3} editions are [here](data_preparation/avse1/)* 19 | 20 | **We are currently running the fourth edition of AVSEC** 21 | 22 | Follow instructions below to build the **AVSEC-4** dataset 23 | 24 | ```bash 25 | 26 | # Clone repository 27 | git clone https://github.com/cogmhear/avse_challenge.git 28 | cd avse_challenge 29 | 30 | # Create & activate environment with conda, see https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html 31 | conda create --name avse python=3.9 32 | conda activate avse 33 | 34 | # Install ffmpeg 2.8 35 | conda install -c rmg ffmpeg 36 | 37 | # Install requirements 38 | pip install -r requirements.txt 39 | ``` 40 | ## Data preparation 41 | 42 | These scripts should be run in a unix environment and require an installed version of the [ffmpeg](https://www.ffmpeg.org) tool (required version 2.8; see Installation for the correct installation command). 43 | 44 | 1) Download necessary data: 45 | 46 | - target videos: 47 | Lip Reading Sentences 3 (LRS3) Dataset 48 | https://mm.kaist.ac.kr/datasets/lip_reading/ 49 | 50 | Follow the instructions on the website to obtain credentials to download the videos. 51 | 52 | - Noise maskers and metadata (AVSEC-4): 53 | https://data.cstr.ed.ac.uk/cogmhear/protected/avsec4_data.tar [4.1GB] 54 | 55 | Please register for the AVSE challenge to obtain the download credentials: [registration form](https://challenge.cogmhear.org/#/getting-started/register) 56 | 57 | Noise maskers and metadata of previous editions are available [here](data_preparation/avse1/README.md) 58 | 59 | - Room simulation data and impulse responses from the [Clarity Challenge (CEC2)](https://github.com/claritychallenge/clarity/tree/main/recipes/cec2) and Head-Related Transfer Functions from [OlHeaD-HRTF Database](https://uol.de/mediphysik/downloads/hearingdevicehrtfs): 60 | https://data.cstr.ed.ac.uk/cogmhear/protected/clarity_cec2_data.tar [64GB] 61 | 62 |

AVSEC-4 uses a subset of the data released by the Clarity Enhancement Challenge 2 and a subset of HRTFs from the OlHeaD-HRTF Database of Oldenburg University. 63 | Download the tar file above to obtain the HRTFs, room simulation data and resampled (16000 Hz) impulse responses (a minimal download/extract sketch is given below).
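The sketch below is illustrative only, not an official script: it assumes the server accepts your AVSE registration credentials over HTTP authentication, and `/path/to/avsec4_root` is a placeholder for the root path you will later set in `data_preparation/avse4/config.yaml`. Adapt it as needed.

```bash
# Hypothetical download/extract sketch; USER and /path/to/avsec4_root are placeholders.
# Enter your AVSE registration password when prompted.
DATA_ROOT=/path/to/avsec4_root
mkdir -p "$DATA_ROOT"
wget --user USER --ask-password https://data.cstr.ed.ac.uk/cogmhear/protected/avsec4_data.tar
wget --user USER --ask-password https://data.cstr.ed.ac.uk/cogmhear/protected/clarity_cec2_data.tar
tar -xf avsec4_data.tar -C "$DATA_ROOT"
tar -xf clarity_cec2_data.tar -C "$DATA_ROOT"
```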

64 | 65 | 66 | 2) Set up data structure and create speech maskers (see EDIT_THIS to change local paths): 67 | ```bash 68 | cd data_preparation/avse4 69 | ./setup_avsec4_data.sh 70 | ``` 71 | 72 | 3) Change root path defined in [data_preparation/avse4/config.yaml](data_preparation/avse4/config.yaml) to the location of the data. 73 | 74 | 4) Prepare noisy data: 75 | 76 | Data preparation scripts were adapted from original code by [Clarity Enhancement Challenge 2](https://github.com/claritychallenge/clarity/tree/main/recipes/cec2) under MIT License. 77 | 78 | ```bash 79 | cd data_preparation/avse4 80 | python build_scenes.py 81 | ``` 82 | To build the data locally in a single run: 83 | ```bash 84 | python render_scenes.py 85 | ``` 86 | Alternatively, if using multi-run: 87 | 88 | [//]: # (# python render_scenes.py 'render_starting_chunk=range(0, 494, 13)' --multirun ) 89 | ```bash 90 | # 20 subjobs, starting at scene 0 and rendering 400 scenes 91 | python render_scenes.py 'render_starting_chunk=range(0, 400, 20)' --multirun 92 | ``` 93 | **Rendering binaural and/or monoaural signals** 94 | 95 | The scripts can render binaural and/or monoaural signals. To choose which signals to render, set the corresponding parameters in the [config](data_preparation/avse4/config.yaml) file to *True*: 96 | ```yaml 97 | binaural_render: True 98 | monoaural_render: True 99 | ``` 100 | #### Data structure 101 | 102 | ```bash 103 | └── avsec4 104 | ├── dev 105 | │ ├── interferers 106 | │ ├── rooms 107 | │ │ ├─ ac [20 MB] 108 | │ │ ├─ HOA_IRs_16k [18.8 GB] 109 | │ │ ├─ rpf [79 MB] 110 | │ ├── scenes [12 GB] 111 | │ ├── targets 112 | │ └── targets_video 113 | ├── hrir 114 | │ ├─ HRIRs_MAT 115 | ├── maskers_music [607 MB] 116 | ├── maskers_noise [3.9 GB] 117 | ├── maskers_speech [5.3 GB] 118 | ├── metadata 119 | └── train 120 | │ ├── interferers 121 | │ ├── rooms 122 | │ │ ├─ ac [48 MB] 123 | │ │ ├─ HOA_IRs_16k [45.2 GB] 124 | │ │ ├─ rpf [189 MB] 125 | │ ├── scenes [141 GB] 126 | │ ├── targets 127 | │ └── targets_video 128 | ``` 129 | 130 | ## Baseline 131 | 132 | AVSEC-4 baseline coming soon (late March 2025) 133 | 134 | [//]: # ([code](./baseline/avse1/)) 135 | 136 | [//]: # () 137 | [//]: # ([pretrained_model](https://data.cstr.ed.ac.uk/cogmhear/protected/avse1_baseline.ckpt)) 138 | 139 | The credentials to download the pretrained model are the same as the ones used to download the noise maskers and the metadata. 140 | 141 | ## Evaluation 142 | 143 | **Binaural signals** 144 | 145 | We provide a script to compute MBSTOI from binaural signals. We use the MBSTOI scripts from the [Clarity Challenge](https://github.com/claritychallenge/clarity/tree/main/clarity/evaluator/mbstoi). The original MBSTOI Matlab implementation is available [here](http://ah-andersen.net/code/). 146 | 147 | ``` 148 | cd evaluation/avse4/ 149 | python objective_evaluation.py 150 | ``` 151 | Note: before running this script, please edit the paths and file name formats defined in evaluation/avse4/config.yaml (see EDIT_THIS). 152 | 153 | **Monophonic signals** 154 | 155 | To compute objective metrics using monophonic signals (i.e., STOI and PESQ), please use the evaluation scripts from AVSEC-1:
156 | 157 | ``` 158 | cd evaluation/avse1/ 159 | python objective_evaluation.py 160 | ``` 161 | These scripts require the following libraries: 162 | ``` 163 | pip install pystoi==0.3.3 164 | pip install pesq==0.0.4 165 | ``` 166 | 167 | ## Challenges 168 | 169 | Current challenge: 170 | 171 | - The 4th Audio-Visual Speech Enhancement Challenge (AVSEC-4) 172 | [data_preparation](./data_preparation/avse4/) 173 | [baseline](./baseline/avse4/) - TBA 174 | [evaluation](./evaluation/avse4/) 175 | 176 | ## License 177 | 178 | Videos are derived from: 179 | - [LRS3 dataset](https://mm.kaist.ac.kr/datasets/lip_reading/) 180 | Creative Commons BY-NC-ND 4.0 license. 181 | 182 | Interferers are derived from: 183 | - [Clarity Enhancement Challenge (CEC1)](https://github.com/claritychallenge/clarity/tree/main/recipes/cec1) 184 | Creative Commons Attribution Share Alike 4.0 International. 185 | 186 | - [DNS Challenge second edition](https://github.com/microsoft/DNS-Challenge). 187 | Only Freesound clips were selected. 188 | Creative Commons 0 License. 189 | 190 | - [LRS3 dataset](https://mm.kaist.ac.kr/datasets/lip_reading/) 191 | Creative Commons BY-NC-ND 4.0 license. 192 | 193 | - [MedleyDB audio](https://medleydb.weebly.com/) 194 | The dataset is licensed under CC BY-NC-SA 4.0. 195 | 196 | Impulse responses and room simulation data are derived from: 197 | - [Clarity Enhancement Challenge (CEC2)](https://github.com/claritychallenge/clarity/tree/main/recipes/cec2) 198 | The dataset is licensed under CC BY-SA 4.0. 199 | 200 | Head-Related Transfer Functions are derived from: 201 | - [OlHeaD-HRTF Database](https://uol.de/mediphysik/downloads/hearingdevicehrtfs): 202 | The dataset is licensed under CC BY-NC-SA 4.0. 203 | 204 | Scripts: 205 | 206 | Data preparation scripts were adapted from original code by [Clarity Enhancement Challenge 2](https://github.com/claritychallenge/clarity/tree/main/recipes/cec2). Modifications include: extracting the target audio from video, a different sampling rate (16 kHz), no random starting time for the target speaker, and no head rotations.
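For reference, the monophonic metrics listed in the Evaluation section above can also be computed directly with the `pystoi` and `pesq` packages; the sketch below is illustrative only: the file names and the use of `soundfile` for I/O are assumptions, and `evaluation/avse1/objective_evaluation.py` remains the reference script.

```python
# Illustrative only: score one enhanced signal against its clean reference at 16 kHz.
# "S00001_target.wav" / "S00001_enhanced.wav" are placeholder file names.
import soundfile as sf
from pesq import pesq
from pystoi import stoi

clean, fs = sf.read("S00001_target.wav")      # clean reference signal
enhanced, _ = sf.read("S00001_enhanced.wav")  # output of your enhancement model

n = min(len(clean), len(enhanced))            # align lengths before scoring
print("STOI:", stoi(clean[:n], enhanced[:n], fs, extended=False))
print("PESQ (wb):", pesq(fs, clean[:n], enhanced[:n], "wb"))
```
Wide-band PESQ expects 16 kHz input, which matches the challenge's sampling rate.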
207 | 208 | 209 | -------------------------------------------------------------------------------- /baseline/avse1/README.md: -------------------------------------------------------------------------------- 1 | # Baseline model for 1st COG-MHEAR Audio-Visual Speech Enhancement Challenge (AVSE) 2 | 3 | [Challenge link](https://challenge.cogmhear.org/) 4 | 5 | ## Requirements 6 | * Python >= 3.5 (3.6 recommended) 7 | 8 | You can install all requirements using 9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Usage 15 | Update config.py with your dataset path 16 | 17 | ### Train 18 | ```bash 19 | python train.py --log_dir ./logs --a_only False --gpu 1 --max_epochs 15 --loss l1 20 | ``` 21 | 22 | ### Model evaluation - dev set 23 | ```bash 24 | python test.py --ckpt_path MODEL_CKPT_PATH --save_root SAVE_ROOT --model_uid baseline --dev_set True --test_set False --cpu False 25 | ``` 26 | 27 | ### Model evaluation - test set 28 | Extract `avse1_evalset.tar` to `$DATA_ROOT/test/scenes` 29 | 30 | ```bash 31 | python test.py --ckpt_path MODEL_CKPT_PATH --save_root SAVE_ROOT --model_uid baseline --dev_set False --test_set True --cpu False 32 | ``` 33 | 34 | -------------------------------------------------------------------------------- /baseline/avse1/config.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import windows as w 2 | 3 | SEED = 999999 4 | dB_levels = [0, 3, 6, 9] 5 | sampling_rate = 16000 6 | img_rows, img_cols = 224, 224 7 | windows = w.hann 8 | # windows = w.hamming 9 | 10 | max_frames = 75 11 | stft_size = 512 12 | window_size = 512 13 | window_shift = 128 14 | window_length = None 15 | fading = False 16 | 17 | max_utterance_length = 48000 18 | num_frames = int(25 * (max_utterance_length / 16000)) 19 | num_stft_frames = 376#int((max_utterance_length - window_size + window_shift) / window_shift) 20 | 21 | nb_channels, img_height, img_width = 1, img_rows, img_cols 22 | DATA_ROOT = "/home/mgo/Documents/data/avse1_data/" 23 | METADATA_ROOT = "/home/mgo/Documents/data/avse1_data/metadata/" -------------------------------------------------------------------------------- /baseline/avse1/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mandar Gogate, All rights reserved. 2 | import json 3 | import logging 4 | import os 5 | import random 6 | from os.path import join, isfile 7 | 8 | import imageio 9 | import librosa 10 | import numpy as np 11 | import torch 12 | import torchvision.transforms as transforms 13 | from decord import VideoReader 14 | from decord import cpu 15 | from pytorch_lightning import LightningDataModule 16 | from scipy.io import wavfile 17 | from torch.utils.data import Dataset 18 | from tqdm import tqdm 19 | 20 | from config import * 21 | from utils.generic import subsample_list 22 | 23 | 24 | def get_images(mp4_file): 25 | data = [np.array(img)[np.newaxis, ...] 
for img in imageio.mimread(mp4_file)] 26 | return np.concatenate(data, axis=0) 27 | 28 | 29 | def get_transform(): 30 | transform_list = [transforms.ToTensor()] 31 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 32 | return transforms.Compose(transform_list) 33 | 34 | 35 | test_transform = get_transform() 36 | 37 | 38 | class TEDDataset(Dataset): 39 | def __init__(self, scenes_root, shuffle=True, seed=SEED, subsample=1, mask_type="IRM", 40 | add_channel_dim=True, a_only=True, return_stft=False, 41 | clipped_batch=True, sample_items=True): 42 | self.clipped_batch = clipped_batch 43 | self.scenes_root = scenes_root 44 | self.return_stft = return_stft 45 | self.a_only = a_only 46 | self.add_channel_dim = add_channel_dim 47 | self.files_list = self.build_files_list 48 | self.mask_type = mask_type.lower() 49 | self.rgb = True if nb_channels == 3 else False 50 | if shuffle: 51 | random.seed(SEED) 52 | random.shuffle(self.files_list) 53 | if subsample != 1: 54 | self.files_list = subsample_list(self.files_list, sample_rate=subsample) 55 | logging.info("Found {} utterances".format(len(self.files_list))) 56 | self.data_count = len(self.files_list) 57 | self.batch_index = 0 58 | self.total_batches_seen = 0 59 | self.batch_input = {"noisy": None} 60 | self.index = 0 61 | self.max_len = len(self.files_list) 62 | self.max_cache = 0 63 | self.seed = seed 64 | self.window = "hann" 65 | self.fading = False 66 | self.sample_items = sample_items 67 | 68 | @property 69 | def build_files_list(self): 70 | files_list = [] 71 | for file in os.listdir(self.scenes_root): 72 | if file.endswith("mixed.wav"): 73 | files_list.append((join(self.scenes_root, file.replace("mixed", "target")), 74 | join(self.scenes_root, file.replace("mixed", "interferer")), 75 | join(self.scenes_root, file), 76 | join(self.scenes_root, file.replace("_mixed.wav", "_silent.mp4")), 77 | )) 78 | return files_list 79 | 80 | def __len__(self): 81 | return len(self.files_list) 82 | 83 | def __getitem__(self, idx): 84 | data = {} 85 | if self.sample_items: 86 | clean_file, noise_file, noisy_file, mp4_file = random.sample(self.files_list, 1)[0] 87 | else: 88 | clean_file, noise_file, noisy_file, mp4_file = self.files_list[idx] 89 | if self.a_only: 90 | if self.return_stft: 91 | data["noisy_audio_spec"], data["mask"], data["clean"], data["noisy_stft"] = self.get_data(clean_file, 92 | noise_file, 93 | noisy_file, 94 | mp4_file) 95 | else: 96 | data["noisy_audio_spec"], data["mask"] = self.get_data(clean_file, noise_file, noisy_file, mp4_file) 97 | else: 98 | if self.return_stft: 99 | data["noisy_audio_spec"], data["mask"], data["clean"], data["noisy_stft"], data[ 100 | "lip_images"] = self.get_data(clean_file, 101 | noise_file, 102 | noisy_file, 103 | mp4_file) 104 | else: 105 | data["noisy_audio_spec"], data["mask"], data["lip_images"] = self.get_data(clean_file, noise_file, 106 | noisy_file, mp4_file) 107 | 108 | data['scene'] = clean_file.replace(self.scenes_root,"").replace("_target.wav","").replace("/","") 109 | 110 | return data 111 | 112 | def get_noisy_features(self, noisy): 113 | audio_stft = librosa.stft(noisy, win_length=window_size, n_fft=stft_size, hop_length=window_shift, 114 | window=self.window, center=True).T 115 | if self.add_channel_dim: 116 | return np.abs(audio_stft).astype(np.float32)[np.newaxis, ...] 
117 | else: 118 | return np.abs(audio_stft).astype(np.float32) 119 | 120 | def load_wav(self, wav_path): 121 | return wavfile.read(wav_path)[1].astype(np.float32) / (2 ** 15) 122 | 123 | def get_data(self, clean_file, noise_file, noisy_file, mp4_file): 124 | noisy = self.load_wav(noisy_file) 125 | if isfile(clean_file): 126 | clean = self.load_wav(clean_file) 127 | else: 128 | clean = np.zeros(noisy.shape) 129 | # noise, _ = librosa.load(noise_file, sr=None) 130 | if self.clipped_batch: 131 | if clean.shape[0] > 48000: 132 | clip_idx = random.randint(0, clean.shape[0] - 48000) 133 | video_idx = max(int((clip_idx / 16000) * 25) - 2, 0) 134 | clean = clean[clip_idx:clip_idx + 48000] 135 | noisy = noisy[clip_idx:clip_idx + 48000] 136 | # noise = noise[clip_idx:clip_idx + 48000] 137 | else: 138 | video_idx = -1 139 | clean = np.pad(clean, pad_width=[0, 48000 - clean.shape[0]], mode="constant") 140 | noisy = np.pad(noisy, pad_width=[0, 48000 - noisy.shape[0]], mode="constant") 141 | # noise = np.pad(noise, pad_width=[0, 48000 - noise.shape[0]], mode="constant") 142 | if not self.a_only: 143 | vr = VideoReader(mp4_file, ctx=cpu(0)) 144 | 145 | if not self.clipped_batch: 146 | frames = vr.get_batch(list(range(len(vr)))).asnumpy() 147 | else: 148 | if len(vr) < 75: 149 | frames = vr.get_batch(list(range(len(vr)))).asnumpy() 150 | frames = np.concatenate((frames, np.zeros((75 - len(vr), 224, 224, 3)).astype(frames.dtype)), axis=0) 151 | else: 152 | frames = vr.get_batch(list(range(video_idx, video_idx + 75))).asnumpy() 153 | frames = np.moveaxis(frames, -1, 0) 154 | if self.return_stft: 155 | clean_audio = clean 156 | noisy_stft = librosa.stft(noisy, win_length=window_size, n_fft=stft_size, hop_length=window_shift, 157 | window=self.window, center=True).T 158 | if self.a_only: 159 | return self.get_noisy_features(noisy), self.get_noisy_features( 160 | clean), clean_audio, noisy_stft 161 | else: 162 | return self.get_noisy_features(noisy), self.get_noisy_features( 163 | clean), clean_audio, noisy_stft, frames 164 | else: 165 | if self.a_only: 166 | return self.get_noisy_features(noisy), self.get_noisy_features(clean) 167 | else: 168 | return self.get_noisy_features(noisy), self.get_noisy_features(clean), frames 169 | 170 | 171 | class TEDDataModule(LightningDataModule): 172 | def __init__(self, batch_size=16, mask="IRM", add_channel_dim=True, a_only=False): 173 | super(TEDDataModule, self).__init__() 174 | self.train_dataset_batch = TEDDataset(join(DATA_ROOT, "train/scenes"), mask_type=mask, 175 | add_channel_dim=add_channel_dim, a_only=a_only) 176 | self.dev_dataset_batch = TEDDataset(join(DATA_ROOT, "dev/scenes"), mask_type=mask, 177 | add_channel_dim=add_channel_dim, a_only=a_only) 178 | self.dev_dataset = TEDDataset(join(DATA_ROOT, "dev/scenes"), mask_type=mask, 179 | add_channel_dim=add_channel_dim, a_only=a_only, return_stft=True, 180 | clipped_batch=False, sample_items=False) 181 | self.test_dataset = TEDDataset(join(DATA_ROOT, "test/scenes"), mask_type=mask, 182 | add_channel_dim=add_channel_dim, a_only=a_only, return_stft=True, 183 | clipped_batch=False, sample_items=False) 184 | self.batch_size = batch_size 185 | 186 | def train_dataloader(self): 187 | return torch.utils.data.DataLoader(self.train_dataset_batch, batch_size=self.batch_size, num_workers=4, 188 | pin_memory=True, persistent_workers=True) 189 | 190 | def val_dataloader(self): 191 | return torch.utils.data.DataLoader(self.dev_dataset_batch, batch_size=self.batch_size, num_workers=4, pin_memory=True, 192 | 
persistent_workers=True) 193 | 194 | def test_dataloader(self): 195 | return torch.utils.data.DataLoader(self.dev_dataset, batch_size=self.batch_size, num_workers=4) 196 | 197 | 198 | if __name__ == '__main__': 199 | 200 | dataset = TEDDataset(scenes_root=join(DATA_ROOT, "dev/scenes"), 201 | mask_type="mag", a_only=False, return_stft=True, clipped_batch=False, sample_items=False) 202 | print(dataset.files_list[:2]) 203 | for i in tqdm(range(len(dataset)), ascii=True): 204 | data = dataset[i] 205 | -------------------------------------------------------------------------------- /baseline/avse1/requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.5.4 2 | torchvision==0.2.2 3 | decord==0.6.0 4 | librosa==0.8.1 5 | torch==1.9.0 6 | pandas==1.1.5 7 | imageio==2.5.0 8 | tqdm==4.62.1 9 | pytorch_lightning==1.4.9 10 | SoundFile==0.10.2 11 | pypesq==1.2.4 12 | GitPython==3.1.27 13 | numpy==1.22.4 14 | -------------------------------------------------------------------------------- /baseline/avse1/test.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os import makedirs 3 | from os.path import isfile, join 4 | 5 | import librosa 6 | import numpy as np 7 | import soundfile as sf 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from torch.nn import functional as F 12 | 13 | from config import sampling_rate, window_shift, window_size 14 | 15 | from dataset import TEDDataModule 16 | from model import AVNet, FusionNet, build_audiofeat_net, build_visualfeat_net 17 | from utils.generic import str2bool 18 | 19 | 20 | def main(args): 21 | clean_root = join(args.save_root, "clean") 22 | noisy_root = join(args.save_root, "noisy") 23 | enhanced_root = join(args.save_root, args.model_uid) 24 | makedirs(args.save_root, exist_ok=True) 25 | makedirs(clean_root, exist_ok=True) 26 | makedirs(noisy_root, exist_ok=True) 27 | makedirs(enhanced_root, exist_ok=True) 28 | datamodule = TEDDataModule(batch_size=args.batch_size, mask=args.mask, a_only=args.a_only) 29 | if args.dev_set and args.test_set: 30 | raise RuntimeError("Select either dev set or test set") 31 | elif args.dev_set: 32 | dataset = datamodule.dev_dataset 33 | elif args.test_set: 34 | dataset = datamodule.test_dataset 35 | else: 36 | raise RuntimeError("Select one of dev set and test set") 37 | print(args.oracle, not args.oracle) 38 | if not args.oracle: 39 | audiofeat_net = build_audiofeat_net(a_only=args.a_only) 40 | if not args.a_only: 41 | visual_net = build_visualfeat_net(extract_feats=True) 42 | else: 43 | visual_net = None 44 | fusion_net = FusionNet(a_only=args.a_only, mask=args.mask) 45 | print("Loading model components", args.ckpt_path) 46 | if args.ckpt_path.endswith("ckpt") and isfile(args.ckpt_path): 47 | model = AVNet.load_from_checkpoint(args.ckpt_path, nets=(visual_net, audiofeat_net, fusion_net), 48 | loss=args.loss, args=args, 49 | a_only=args.a_only) 50 | print("Model loaded") 51 | else: 52 | raise FileNotFoundError("Cannot load model weights: {}".format(args.ckpt_path)) 53 | if not args.cpu: 54 | model.to("cuda:0") 55 | model.eval() 56 | i = 0 57 | with torch.no_grad(): 58 | for i in tqdm(range(len(dataset))): 59 | 60 | data = dataset[i] 61 | 62 | filename = f"{data['scene']}.wav" 63 | # filename = f"{str(i).zfill(5)}.wav" 64 | clean_path = join(clean_root, filename) 65 | noisy_path = join(noisy_root, filename) 66 | enhanced_path = join(enhanced_root, filename) 67 | 68 | if not 
isfile(clean_path) and not args.test_set: 69 | sf.write(clean_path, data["clean"], samplerate=sampling_rate) 70 | if not isfile(noisy_path): 71 | noisy = librosa.istft(data["noisy_stft"].T, win_length=window_size, hop_length=window_shift, 72 | window="hann", length=len(data["clean"])) 73 | sf.write(noisy_path, noisy, samplerate=sampling_rate) 74 | if not isfile(enhanced_path): 75 | if args.oracle: 76 | pred_mag = np.abs(data["noisy_stft"]) * data["mask"].T 77 | i += 1 78 | else: 79 | inputs = {"noisy_audio_spec": torch.from_numpy(data["noisy_audio_spec"][np.newaxis, ...]).to( 80 | model.device)} 81 | if not args.a_only: 82 | inputs["lip_images"] = torch.from_numpy(data["lip_images"][np.newaxis, ...]).to(model.device) 83 | pred = model(inputs).cpu() 84 | pred_mag = pred.numpy()[0][0] 85 | noisy_phase = np.angle(data["noisy_stft"]) 86 | estimated = pred_mag * (np.cos(noisy_phase) + 1.j * np.sin(noisy_phase)) 87 | estimated_audio = librosa.istft(estimated.T, win_length=window_size, hop_length=window_shift, 88 | window="hann", length=len(data["clean"])) 89 | sf.write(enhanced_path, estimated_audio, samplerate=sampling_rate) 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = ArgumentParser() 94 | parser.add_argument("--a_only", type=str2bool, required=False) 95 | parser.add_argument("--ckpt_path", type=str, required=True) 96 | parser.add_argument("--oracle", type=str2bool, required=False) 97 | parser.add_argument("--save_root", type=str, required=True) 98 | parser.add_argument("--model_uid", type=str, required=True) 99 | parser.add_argument("--dev_set", type=str2bool, required=True) 100 | parser.add_argument("--test_set", type=str2bool, required=False) 101 | parser.add_argument("--cpu", type=str2bool, required=False, help="Evaluate model on CPU") 102 | parser.add_argument("--mask", type=str, default="mag") 103 | parser.add_argument("--batch_size", type=int, default=16) 104 | parser.add_argument("--loss", type=str, default="l1") 105 | args = parser.parse_args() 106 | main(args) 107 | -------------------------------------------------------------------------------- /baseline/avse1/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from dataset import TEDDataModule 4 | from model import AVNet, FusionNet, build_audiofeat_net, build_visualfeat_net 5 | 6 | SEED = 1143 7 | # fix random seeds for reproducibility 8 | torch.manual_seed(SEED) 9 | torch.backends.cudnn.deterministic = False 10 | torch.backends.cudnn.benchmark = True 11 | np.random.seed(SEED) 12 | 13 | from argparse import ArgumentParser 14 | 15 | from pytorch_lightning import Trainer 16 | from pytorch_lightning.callbacks import ModelCheckpoint 17 | 18 | from utils.generic import str2bool 19 | 20 | 21 | def main(args): 22 | checkpoint_callback = ModelCheckpoint(monitor="val_loss_epoch") 23 | datamodule = TEDDataModule(batch_size=args.batch_size, mask=args.mask, a_only=args.a_only) 24 | audiofeat_net = build_audiofeat_net(a_only=args.a_only) 25 | visual_net = build_visualfeat_net(extract_feats=True) 26 | fusion_net = FusionNet(a_only=args.a_only, mask=args.mask) 27 | 28 | if args.a_only: 29 | model = AVNet((None, audiofeat_net, fusion_net), args.loss, a_only=args.a_only, 30 | val_dataset=datamodule.dev_dataset) 31 | else: 32 | model = AVNet((visual_net, audiofeat_net, fusion_net), args.loss, a_only=args.a_only, 33 | val_dataset=datamodule.dev_dataset) 34 | trainer = Trainer.from_argparse_args(args, default_root_dir=args.log_dir, 
callbacks=[checkpoint_callback]) 35 | if args.tune: 36 | trainer.tune(model, datamodule) 37 | else: 38 | trainer.fit(model, datamodule) 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = ArgumentParser() 43 | parser.add_argument("--a_only", type=str2bool, default=False) 44 | parser.add_argument("--tune", type=str2bool, default=False) 45 | parser.add_argument("--batch_size", type=int, default=8) 46 | parser.add_argument("--lr", type=float, default=0.00158) 47 | parser.add_argument("--log_dir", type=str, required=True) 48 | parser.add_argument("--loss", type=str, default="l1") 49 | parser.add_argument("--mask", type=str, default="mag") 50 | parser = Trainer.add_argparse_args(parser) 51 | args = parser.parse_args() 52 | main(args) 53 | -------------------------------------------------------------------------------- /baseline/avse1/utils/generic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mandar Gogate, All rights reserved. 2 | import argparse 3 | import json 4 | import logging 5 | import os 6 | import random 7 | import subprocess 8 | import sys 9 | import tempfile 10 | from collections import OrderedDict 11 | from datetime import datetime 12 | from itertools import repeat 13 | from os import makedirs 14 | from os.path import isdir, isfile, join 15 | from pathlib import Path 16 | 17 | import numpy as np 18 | import pandas as pd 19 | 20 | 21 | def load_json(json_fp: str): 22 | with open(json_fp, 'r') as f: 23 | json_content = json.load(f) 24 | return json_content 25 | 26 | 27 | def subsample_list(inp_list: list, sample_rate: float): 28 | random.shuffle(inp_list) 29 | return [inp_list[i] for i in range(int(len(inp_list) * sample_rate))] 30 | 31 | 32 | def tempdir() -> str: 33 | return tempfile.gettempdir() 34 | 35 | 36 | def ensure_exists(path: str): 37 | makedirs(path, exist_ok=True) 38 | 39 | 40 | def multicore_processing(func, parameters: list, processes=None): 41 | from multiprocessing import Pool 42 | pool = Pool(processes=processes) 43 | result = pool.map(func, parameters) 44 | pool.close() 45 | pool.join() 46 | return result 47 | 48 | 49 | def config_logging(level=logging.INFO): 50 | formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s', 51 | datefmt='%Y-%m-%d %H:%M:%S') 52 | logger = logging.getLogger() 53 | logger.setLevel(level) 54 | screen_handler = logging.StreamHandler(stream=sys.stdout) 55 | screen_handler.setFormatter(formatter) 56 | logger.addHandler(screen_handler) 57 | return logger 58 | 59 | 60 | def get_files(files_root: str) -> list: 61 | assert isdir(files_root) 62 | return [ 63 | join(files_root, file) 64 | for file in sorted(os.listdir(files_root)) 65 | ] 66 | 67 | 68 | def shuffle_arr(a, b): 69 | if a.shape[0] != b.shape[0]: raise RuntimeError("Both arrays should have the same elements in the first axis") 70 | p = np.random.permutation(len(a)) 71 | return a[p], b[p] 72 | 73 | 74 | def shuffle_lists(a: list, b: list, seed: int = None): 75 | c = list(zip(a, b)) 76 | if seed is not None: 77 | random.seed(seed) 78 | random.shuffle(c) 79 | a, b = zip(*c) 80 | return a, b 81 | 82 | 83 | def get_utc_time(): 84 | return datetime.utcnow().strftime("%Y-%m-%d_%H:%M:%S") 85 | 86 | 87 | def str2bool(v: str): 88 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 89 | return True 90 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 91 | return False 92 | else: 93 | raise argparse.ArgumentTypeError('Boolean value expected.') 94 | 95 | 96 | def subsample_list(inp_list: list, sample_rate: float): 
97 |     random.shuffle(inp_list) 98 |     return [inp_list[i] for i in range(int(len(inp_list) * sample_rate))] 99 | 100 | 101 | def save_dict(path: str, dict_obj: dict): 102 |     with open(path, "w") as f: 103 |         json.dump(dict_obj, f, indent=4, sort_keys=True) 104 | 105 | 106 | class DisablePrint: 107 |     def __enter__(self): 108 |         self._original_stdout = sys.stdout 109 |         sys.stdout = open(os.devnull, 'w') 110 | 111 |     def __exit__(self, exc_type, exc_val, exc_tb): 112 |         sys.stdout.close() 113 |         sys.stdout = self._original_stdout 114 | 115 | 116 | def check_repo(): 117 |     from git import Repo 118 |     repo = Repo(os.getcwd()) 119 |     assert not repo.is_dirty(), "Please commit the changes and then run the code" 120 | 121 | 122 | def save_json(json_data, json_path, overwrite=True): 123 |     if isfile(json_path) and not overwrite: 124 |         raise Exception("JSON path: {} already exists".format(json_path)) 125 |     with open(json_path, "w") as f: 126 |         json.dump(json_data, f, indent=4, sort_keys=True) 127 | 128 | 129 | def execute(command): 130 |     subprocess.call(command, shell=True, stdout=None) 131 | 132 | 133 | def inf_loop(data_loader): 134 |     """ wrapper function for endless data loader. """ 135 |     for loader in repeat(data_loader): 136 |         yield from loader 137 | 138 | 139 | class MetricTracker: 140 |     def __init__(self, *keys, writer=None): 141 |         self.writer = writer 142 |         self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 143 |         self.reset() 144 | 145 |     def reset(self): 146 |         for col in self._data.columns: 147 |             self._data[col].values[:] = 0 148 | 149 |     def update(self, key, value, n=1): 150 |         if self.writer is not None: 151 |             self.writer.add_scalar(key, value) 152 |         self._data.total[key] += value * n 153 |         self._data.counts[key] += n 154 |         self._data.average[key] = self._data.total[key] / self._data.counts[key] 155 | 156 |     def avg(self, key): 157 |         return self._data.average[key] 158 | 159 |     def result(self): 160 |         return dict(self._data.average) 161 | 162 | 163 | def read_json(fname): 164 |     fname = Path(fname) 165 |     with fname.open('rt') as handle: 166 |         return json.load(handle, object_hook=OrderedDict) 167 | 168 | 169 | def write_json(content, fname): 170 |     fname = Path(fname) 171 |     with fname.open('wt') as handle: 172 |         json.dump(content, handle, indent=4, sort_keys=False) 173 | 174 | 175 | def multicore_processing(func, parameters: list, processes=None): 176 |     from multiprocessing import Pool 177 |     pool = Pool(processes=processes) 178 |     result = pool.map(func, parameters) 179 |     pool.close() 180 |     pool.join() 181 |     return result 182 | -------------------------------------------------------------------------------- /baseline/avse1/utils/nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.tcn import MultibranchTemporalConvNet, TemporalConvNet 5 | 6 | 7 | def unet_conv(input_nc, output_nc, norm_layer=nn.BatchNorm2d): 8 |     downconv = nn.Conv2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1) 9 |     downrelu = nn.LeakyReLU(0.2, True) 10 |     downnorm = norm_layer(output_nc) 11 |     return nn.Sequential(*[downconv, downnorm, downrelu]) 12 | 13 | 14 | def unet_upconv(input_nc, output_nc, outermost=False, norm_layer=nn.BatchNorm2d, kernel_size=4): 15 |     upconv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=kernel_size, stride=2, padding=1) 16 |     uprelu = nn.ReLU(True) 17 |     upnorm = norm_layer(output_nc) 18 |     if not outermost: 19 |         return nn.Sequential(*[upconv, upnorm, uprelu]) 20 |     else: 21 |
return nn.Sequential(*[upconv]) 22 | 23 | 24 | class conv_block(nn.Module): 25 | def __init__(self, ch_in, ch_out): 26 | super(conv_block, self).__init__() 27 | self.conv = nn.Sequential( 28 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 29 | nn.BatchNorm2d(ch_out), 30 | nn.LeakyReLU(0.2, True), 31 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 32 | nn.BatchNorm2d(ch_out), 33 | nn.LeakyReLU(0.2, True) 34 | ) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | return x 39 | 40 | 41 | class up_conv(nn.Module): 42 | def __init__(self, ch_in, ch_out, outermost=False): 43 | super(up_conv, self).__init__() 44 | if not outermost: 45 | self.up = nn.Sequential( 46 | nn.Upsample(scale_factor=(2., 1.)), 47 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 48 | nn.BatchNorm2d(ch_out), 49 | nn.ReLU(inplace=True) 50 | ) 51 | else: 52 | self.up = nn.Sequential( 53 | nn.Upsample(scale_factor=(2., 1.)), 54 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 55 | nn.Sigmoid() 56 | ) 57 | 58 | def forward(self, x): 59 | x = self.up(x) 60 | return x 61 | 62 | 63 | def weights_init(m): 64 | classname = m.__class__.__name__ 65 | if classname.find('Conv') != -1: 66 | m.weight.data.normal_(0.0, 0.02) 67 | elif classname.find('BatchNorm') != -1: 68 | m.weight.data.normal_(1.0, 0.02) 69 | m.bias.data.fill_(0) 70 | elif classname.find('Linear') != -1: 71 | m.weight.data.normal_(0.0, 0.02) 72 | 73 | 74 | def threeD_to_2D_tensor(x): 75 | n_batch, n_channels, s_time, sx, sy = x.shape 76 | x = x.transpose(1, 2) 77 | return x.reshape(n_batch * s_time, n_channels, sx, sy) 78 | 79 | 80 | def _average_batch(x, lengths, B): 81 | return torch.stack([torch.mean(x[index][:, 0:i], 1) for index, i in enumerate(lengths)], 0) 82 | 83 | 84 | class MultiscaleMultibranchTCN(nn.Module): 85 | def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False): 86 | super(MultiscaleMultibranchTCN, self).__init__() 87 | 88 | self.kernel_sizes = tcn_options['kernel_size'] 89 | self.num_kernels = len(self.kernel_sizes) 90 | 91 | self.mb_ms_tcn = MultibranchTemporalConvNet(input_size, num_channels, tcn_options, dropout=dropout, relu_type=relu_type, dwpw=dwpw) 92 | self.tcn_output = nn.Linear(num_channels[-1], num_classes) 93 | 94 | self.consensus_func = _average_batch 95 | 96 | def forward(self, x, lengths, B, extract_feats=False): 97 | # x needs to have dimension (N, C, L) in order to be passed into CNN 98 | xtrans = x.transpose(1, 2) 99 | if extract_feats: 100 | out = self.mb_ms_tcn(xtrans) 101 | return out.unsqueeze(-2) 102 | else: 103 | return xtrans.unsqueeze(-2) 104 | 105 | 106 | class TCN(nn.Module): 107 | """Implements Temporal Convolutional Network (TCN) 108 | __https://arxiv.org/pdf/1803.01271.pdf 109 | """ 110 | 111 | def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False): 112 | super(TCN, self).__init__() 113 | self.tcn_trunk = TemporalConvNet(input_size, num_channels, dropout=dropout, tcn_options=tcn_options, relu_type=relu_type, dwpw=dwpw) 114 | self.tcn_output = nn.Linear(num_channels[-1], num_classes) 115 | 116 | self.consensus_func = _average_batch 117 | 118 | self.has_aux_losses = False 119 | 120 | def forward(self, x, lengths, B, extract_feats=False): 121 | # x needs to have dimension (N, C, L) in order to be passed into CNN 122 | if extract_feats: 123 | x = self.tcn_trunk(x.transpose(1, 2)) 124 | return x.unsqueeze(-2) 125 | else: 
126 | return x.unsqueeze(-2) 127 | # x = self.consensus_func(x, lengths, B) 128 | # return self.tcn_output(x) 129 | -------------------------------------------------------------------------------- /baseline/avse1/utils/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | def downsample_basic_block(inplanes, outplanes, stride): 12 | return nn.Sequential( 13 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False), 14 | nn.BatchNorm2d(outplanes), 15 | ) 16 | 17 | 18 | def downsample_basic_block_v2(inplanes, outplanes, stride): 19 | return nn.Sequential( 20 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), 21 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), 22 | nn.BatchNorm2d(outplanes), 23 | ) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'): 30 | super(BasicBlock, self).__init__() 31 | 32 | assert relu_type in ['relu', 'prelu'] 33 | 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | 37 | # type of ReLU is an input option 38 | if relu_type == 'relu': 39 | self.relu1 = nn.ReLU(inplace=True) 40 | self.relu2 = nn.ReLU(inplace=True) 41 | elif relu_type == 'prelu': 42 | self.relu1 = nn.PReLU(num_parameters=planes) 43 | self.relu2 = nn.PReLU(num_parameters=planes) 44 | else: 45 | raise Exception('relu type not implemented') 46 | # -------- 47 | 48 | self.conv2 = conv3x3(planes, planes) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | residual = x 56 | out = self.conv1(x) 57 | out = self.bn1(out) 58 | out = self.relu1(out) 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu2(out) 66 | 67 | return out 68 | 69 | 70 | class ResNet(nn.Module): 71 | 72 | def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False): 73 | self.inplanes = 64 74 | self.relu_type = relu_type 75 | self.gamma_zero = gamma_zero 76 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block 77 | 78 | super(ResNet, self).__init__() 79 | self.layer1 = self._make_layer(block, 64, layers[0]) 80 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 81 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 82 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 83 | self.avgpool = nn.AdaptiveAvgPool2d(1) 84 | 85 | # default init 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 89 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 90 | elif isinstance(m, nn.BatchNorm2d): 91 | m.weight.data.fill_(1) 92 | m.bias.data.zero_() 93 | # nn.init.ones_(m.weight) 94 | # nn.init.zeros_(m.bias) 95 | 96 | if self.gamma_zero: 97 | for m in self.modules(): 98 | if isinstance(m, BasicBlock): 99 | m.bn2.weight.data.zero_() 100 | 101 | def _make_layer(self, block, planes, blocks, stride=1): 102 | 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = self.downsample_block(inplanes=self.inplanes, 106 | outplanes=planes * block.expansion, 107 | stride=stride) 108 | 109 | layers = [] 110 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)) 111 | self.inplanes = planes * block.expansion 112 | for i in range(1, blocks): 113 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type)) 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, x): 118 | x = self.layer1(x) 119 | x = self.layer2(x) 120 | x = self.layer3(x) 121 | x = self.layer4(x) 122 | x = self.avgpool(x) 123 | x = x.view(x.size(0), -1) 124 | return x 125 | -------------------------------------------------------------------------------- /baseline/avse1/utils/tcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import weight_norm 4 | 5 | 6 | """Implements Temporal Convolutional Network (TCN) 7 | 8 | __https://arxiv.org/pdf/1803.01271.pdf 9 | """ 10 | 11 | class Chomp1d(nn.Module): 12 | def __init__(self, chomp_size, symm_chomp): 13 | super(Chomp1d, self).__init__() 14 | self.chomp_size = chomp_size 15 | self.symm_chomp = symm_chomp 16 | if self.symm_chomp: 17 | assert self.chomp_size % 2 == 0, "If symmetric chomp, chomp size needs to be even" 18 | def forward(self, x): 19 | if self.chomp_size == 0: 20 | return x 21 | if self.symm_chomp: 22 | return x[:, :, self.chomp_size//2:-self.chomp_size//2].contiguous() 23 | else: 24 | return x[:, :, :-self.chomp_size].contiguous() 25 | 26 | 27 | class ConvBatchChompRelu(nn.Module): 28 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, relu_type, dwpw=False): 29 | super(ConvBatchChompRelu, self).__init__() 30 | self.dwpw = dwpw 31 | if dwpw: 32 | self.conv = nn.Sequential( 33 | # -- dw 34 | nn.Conv1d( n_inputs, n_inputs, kernel_size, stride=stride, 35 | padding=padding, dilation=dilation, groups=n_inputs, bias=False), 36 | nn.BatchNorm1d(n_inputs), 37 | Chomp1d(padding, True), 38 | nn.PReLU(num_parameters=n_inputs) if relu_type == 'prelu' else nn.ReLU(inplace=True), 39 | # -- pw 40 | nn.Conv1d( n_inputs, n_outputs, 1, 1, 0, bias=False), 41 | nn.BatchNorm1d(n_outputs), 42 | nn.PReLU(num_parameters=n_outputs) if relu_type == 'prelu' else nn.ReLU(inplace=True) 43 | ) 44 | else: 45 | self.conv = nn.Conv1d(n_inputs, n_outputs, kernel_size, 46 | stride=stride, padding=padding, dilation=dilation) 47 | self.batchnorm = nn.BatchNorm1d(n_outputs) 48 | self.chomp = Chomp1d(padding,True) 49 | self.non_lin = nn.PReLU(num_parameters=n_outputs) if relu_type == 'prelu' else nn.ReLU() 50 | 51 | def forward(self, x): 52 | if self.dwpw: 53 | return self.conv(x) 54 | else: 55 | out = self.conv( x ) 56 | out = self.batchnorm( out ) 57 | out = self.chomp( out ) 58 | return self.non_lin( out ) 59 | 60 | 61 | 62 | # --------- MULTI-BRANCH VERSION --------------- 63 | class MultibranchTemporalBlock(nn.Module): 64 | def __init__(self, n_inputs, n_outputs, kernel_sizes, stride, dilation, padding, dropout=0.2, 
65 | relu_type = 'relu', dwpw=False): 66 | super(MultibranchTemporalBlock, self).__init__() 67 | 68 | self.kernel_sizes = kernel_sizes 69 | self.num_kernels = len( kernel_sizes ) 70 | self.n_outputs_branch = n_outputs // self.num_kernels 71 | assert n_outputs % self.num_kernels == 0, "Number of output channels needs to be divisible by number of kernels" 72 | 73 | 74 | 75 | for k_idx,k in enumerate( self.kernel_sizes ): 76 | cbcr = ConvBatchChompRelu( n_inputs, self.n_outputs_branch, k, stride, dilation, padding[k_idx], relu_type, dwpw=dwpw) 77 | setattr( self,'cbcr0_{}'.format(k_idx), cbcr ) 78 | self.dropout0 = nn.Dropout(dropout) 79 | 80 | for k_idx,k in enumerate( self.kernel_sizes ): 81 | cbcr = ConvBatchChompRelu( n_outputs, self.n_outputs_branch, k, stride, dilation, padding[k_idx], relu_type, dwpw=dwpw) 82 | setattr( self,'cbcr1_{}'.format(k_idx), cbcr ) 83 | self.dropout1 = nn.Dropout(dropout) 84 | 85 | # downsample? 86 | self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if (n_inputs//self.num_kernels) != n_outputs else None 87 | 88 | # final relu 89 | if relu_type == 'relu': 90 | self.relu_final = nn.ReLU() 91 | elif relu_type == 'prelu': 92 | self.relu_final = nn.PReLU(num_parameters=n_outputs) 93 | 94 | def forward(self, x): 95 | 96 | # first multi-branch set of convolutions 97 | outputs = [] 98 | for k_idx in range( self.num_kernels ): 99 | branch_convs = getattr(self,'cbcr0_{}'.format(k_idx)) 100 | outputs.append( branch_convs(x) ) 101 | out0 = torch.cat(outputs, 1) 102 | out0 = self.dropout0( out0 ) 103 | 104 | # second multi-branch set of convolutions 105 | outputs = [] 106 | for k_idx in range( self.num_kernels ): 107 | branch_convs = getattr(self,'cbcr1_{}'.format(k_idx)) 108 | outputs.append( branch_convs(out0) ) 109 | out1 = torch.cat(outputs, 1) 110 | out1 = self.dropout1( out1 ) 111 | 112 | # downsample? 
113 | res = x if self.downsample is None else self.downsample(x) 114 | 115 | return self.relu_final(out1 + res) 116 | 117 | class MultibranchTemporalConvNet(nn.Module): 118 | def __init__(self, num_inputs, num_channels, tcn_options, dropout=0.2, relu_type='relu', dwpw=False): 119 | super(MultibranchTemporalConvNet, self).__init__() 120 | 121 | self.ksizes = tcn_options['kernel_size'] 122 | 123 | layers = [] 124 | num_levels = len(num_channels) 125 | for i in range(num_levels): 126 | dilation_size = 2 ** i 127 | in_channels = num_inputs if i == 0 else num_channels[i-1] 128 | out_channels = num_channels[i] 129 | 130 | 131 | padding = [ (s-1)*dilation_size for s in self.ksizes] 132 | layers.append( MultibranchTemporalBlock( in_channels, out_channels, self.ksizes, 133 | stride=1, dilation=dilation_size, padding = padding, dropout=dropout, relu_type = relu_type, 134 | dwpw=dwpw) ) 135 | 136 | self.network = nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | return self.network(x) 140 | # -------------------------------- 141 | 142 | 143 | # --------------- STANDARD VERSION (SINGLE BRANCH) ------------------------ 144 | class TemporalBlock(nn.Module): 145 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2, 146 | symm_chomp = False, no_padding = False, relu_type = 'relu', dwpw=False): 147 | super(TemporalBlock, self).__init__() 148 | 149 | self.no_padding = no_padding 150 | if self.no_padding: 151 | downsample_chomp_size = 2*padding-4 152 | padding = 1 # hack-ish thing so that we can use 3 layers 153 | 154 | if dwpw: 155 | self.net = nn.Sequential( 156 | # -- first conv set within block 157 | # -- dw 158 | nn.Conv1d( n_inputs, n_inputs, kernel_size, stride=stride, 159 | padding=padding, dilation=dilation, groups=n_inputs, bias=False), 160 | nn.BatchNorm1d(n_inputs), 161 | Chomp1d(padding, True), 162 | nn.PReLU(num_parameters=n_inputs) if relu_type == 'prelu' else nn.ReLU(inplace=True), 163 | # -- pw 164 | nn.Conv1d( n_inputs, n_outputs, 1, 1, 0, bias=False), 165 | nn.BatchNorm1d(n_outputs), 166 | nn.PReLU(num_parameters=n_outputs) if relu_type == 'prelu' else nn.ReLU(inplace=True), 167 | nn.Dropout(dropout), 168 | # -- second conv set within block 169 | # -- dw 170 | nn.Conv1d( n_outputs, n_outputs, kernel_size, stride=stride, 171 | padding=padding, dilation=dilation, groups=n_outputs, bias=False), 172 | nn.BatchNorm1d(n_outputs), 173 | Chomp1d(padding, True), 174 | nn.PReLU(num_parameters=n_outputs) if relu_type == 'prelu' else nn.ReLU(inplace=True), 175 | # -- pw 176 | nn.Conv1d( n_outputs, n_outputs, 1, 1, 0, bias=False), 177 | nn.BatchNorm1d(n_outputs), 178 | nn.PReLU(num_parameters=n_outputs) if relu_type == 'prelu' else nn.ReLU(inplace=True), 179 | nn.Dropout(dropout), 180 | ) 181 | else: 182 | self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, 183 | stride=stride, padding=padding, dilation=dilation) 184 | self.batchnorm1 = nn.BatchNorm1d(n_outputs) 185 | self.chomp1 = Chomp1d(padding,symm_chomp) if not self.no_padding else None 186 | if relu_type == 'relu': 187 | self.relu1 = nn.ReLU() 188 | elif relu_type == 'prelu': 189 | self.relu1 = nn.PReLU(num_parameters=n_outputs) 190 | self.dropout1 = nn.Dropout(dropout) 191 | 192 | self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, 193 | stride=stride, padding=padding, dilation=dilation) 194 | self.batchnorm2 = nn.BatchNorm1d(n_outputs) 195 | self.chomp2 = Chomp1d(padding,symm_chomp) if not self.no_padding else None 196 | if relu_type == 'relu': 197 | self.relu2 = nn.ReLU() 198 | 
elif relu_type == 'prelu': 199 | self.relu2 = nn.PReLU(num_parameters=n_outputs) 200 | self.dropout2 = nn.Dropout(dropout) 201 | 202 | 203 | if self.no_padding: 204 | self.net = nn.Sequential(self.conv1, self.batchnorm1, self.relu1, self.dropout1, 205 | self.conv2, self.batchnorm2, self.relu2, self.dropout2) 206 | else: 207 | self.net = nn.Sequential(self.conv1, self.batchnorm1, self.chomp1, self.relu1, self.dropout1, 208 | self.conv2, self.batchnorm2, self.chomp2, self.relu2, self.dropout2) 209 | 210 | self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None 211 | if self.no_padding: 212 | self.downsample_chomp = Chomp1d(downsample_chomp_size,True) 213 | if relu_type == 'relu': 214 | self.relu = nn.ReLU() 215 | elif relu_type == 'prelu': 216 | self.relu = nn.PReLU(num_parameters=n_outputs) 217 | 218 | def forward(self, x): 219 | out = self.net(x) 220 | if self.no_padding: 221 | x = self.downsample_chomp(x) 222 | res = x if self.downsample is None else self.downsample(x) 223 | return self.relu(out + res) 224 | 225 | 226 | class TemporalConvNet(nn.Module): 227 | def __init__(self, num_inputs, num_channels, tcn_options, dropout=0.2, relu_type='relu', dwpw=False): 228 | super(TemporalConvNet, self).__init__() 229 | self.ksize = tcn_options['kernel_size'][0] if isinstance(tcn_options['kernel_size'], list) else tcn_options['kernel_size'] 230 | layers = [] 231 | num_levels = len(num_channels) 232 | for i in range(num_levels): 233 | dilation_size = 2 ** i 234 | in_channels = num_inputs if i == 0 else num_channels[i-1] 235 | out_channels = num_channels[i] 236 | layers.append( TemporalBlock(in_channels, out_channels, self.ksize, stride=1, dilation=dilation_size, 237 | padding=(self.ksize-1) * dilation_size, dropout=dropout, symm_chomp = True, 238 | no_padding = False, relu_type=relu_type, dwpw=dwpw) ) 239 | 240 | self.network = nn.Sequential(*layers) 241 | 242 | def forward(self, x): 243 | return self.network(x) 244 | # -------------------------------- 245 | -------------------------------------------------------------------------------- /baseline/avse2/README.md: -------------------------------------------------------------------------------- 1 | ## Baseline model for 2nd COG-MHEAR Audio-Visual Speech Enhancement Challenge 2 | 3 | [Challenge link](https://challenge.cogmhear.org/) 4 | 5 | ## Requirements 6 | * Python >= 3.6 7 | * [PyTorch](https://pytorch.org/) 8 | * [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/) 9 | * [Decord](https://github.com/dmlc/decord) 10 | 11 | ```bash 12 | # You can install all requirements using 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Usage 17 | Update DATA_ROOT in config.py 18 | ```bash 19 | # Expected folder structure 20 | |-- train 21 | | `-- scenes 22 | |-- dev 23 | | `-- scenes 24 | |-- eval 25 | | `-- scenes 26 | ``` 27 | 28 | ### Train 29 | ```bash 30 | python train.py --log_dir ./logs --batch_size 2 --lr 0.001 --gpu 1 --max_epochs 20 31 | 32 | optional arguments: 33 | -h, --help show this help message and exit 34 | --batch_size 4 Batch size for training 35 | --lr 0.001 Learning rate for training 36 | --log_dir LOG_DIR Path to save tensorboard logs 37 | ``` 38 | 39 | ### Test 40 | ```bash 41 | usage: test.py [-h] --ckpt_path ./model.pth --save_root ./enhanced --model_uid avse [--dev_set False] [--eval_set True] [--cpu True] 42 | 43 | optional arguments: 44 | -h, --help show this help message and exit 45 | --ckpt_path CKPT_PATH Path to model checkpoint 46 | --save_root SAVE_ROOT Path to save enhanced 
audio 47 | --model_uid MODEL_UID Folder name to save enhanced audio 48 | --dev_set True Evaluate model on dev set 49 | --eval_set False Evaluate model on eval set 50 | --cpu True Evaluate on CPU (default is GPU) 51 | ``` 52 | 53 | -------------------------------------------------------------------------------- /baseline/avse2/config.py: -------------------------------------------------------------------------------- 1 | from os.path import isfile 2 | 3 | SEED = 1143 # Random seed for reproducibility 4 | sampling_rate = 16000 # Sampling rate for audio 5 | max_frames = 75 # Maximum number of frames per video for training 6 | max_audio_len = sampling_rate * 3 # Maximum number of audio samples per video for training 7 | img_height, img_width = 224, 224 # Image height and width for training 8 | 9 | DATA_ROOT = "./data" 10 | assert not isfile(DATA_ROOT), "Please set DATA_ROOT in config.py to the correct path to the avsec dataset" 11 | -------------------------------------------------------------------------------- /baseline/avse2/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from os.path import join 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from decord import VideoReader 10 | from decord import cpu 11 | from pytorch_lightning import LightningDataModule 12 | from scipy.io import wavfile 13 | from torch.utils.data import Dataset 14 | from tqdm import tqdm 15 | 16 | from config import * 17 | from utils import subsample_list 18 | 19 | 20 | class AVSEDataset(Dataset): 21 | def __init__(self, scenes_root, shuffle=True, seed=SEED, subsample=1, 22 | clipped_batch=True, sample_items=True, test_set=False): 23 | super(AVSEDataset, self).__init__() 24 | self.test_set = test_set 25 | self.clipped_batch = clipped_batch 26 | self.scenes_root = scenes_root 27 | self.files_list = self.build_files_list 28 | if shuffle: 29 | random.seed(SEED) 30 | random.shuffle(self.files_list) 31 | if subsample != 1: 32 | self.files_list = subsample_list(self.files_list, sample_rate=subsample) 33 | logging.info("Found {} utterances".format(len(self.files_list))) 34 | self.data_count = len(self.files_list) 35 | self.batch_index = 0 36 | self.total_batches_seen = 0 37 | self.batch_input = {"noisy": None} 38 | self.index = 0 39 | self.max_len = len(self.files_list) 40 | self.max_cache = 0 41 | self.seed = seed 42 | self.window = "hann" 43 | self.fading = False 44 | self.sample_items = sample_items 45 | 46 | @property 47 | def build_files_list(self): 48 | files_list = [] 49 | for file in os.listdir(self.scenes_root): 50 | if file.endswith("mixed.wav"): 51 | files = (join(self.scenes_root, file.replace("mixed", "target")), 52 | join(self.scenes_root, file.replace("mixed", "interferer")), 53 | join(self.scenes_root, file), 54 | join(self.scenes_root, file.replace("_mixed.wav", "_silent.mp4")), 55 | ) 56 | if not self.test_set: 57 | if all([isfile(f) for f in files]): 58 | files_list.append(files) 59 | else: 60 | files_list.append(files) 61 | return files_list 62 | 63 | def __len__(self): 64 | return len(self.files_list) 65 | 66 | def __getitem__(self, idx): 67 | while True: 68 | try: 69 | data = {} 70 | if self.sample_items: 71 | clean_file, noise_file, noisy_file, mp4_file = random.sample(self.files_list, 1)[0] 72 | else: 73 | clean_file, noise_file, noisy_file, mp4_file = self.files_list[idx] 74 | data["noisy_audio"], data["clean"], data["video_frames"] = self.get_data(clean_file, noise_file, 75 | noisy_file, 
mp4_file) 76 | data['scene'] = clean_file.replace(self.scenes_root, "").replace("_target.wav", "").replace("/", "") 77 | return data 78 | except Exception as e: 79 | logging.error("Error in loading data: {}".format(e)) 80 | 81 | def load_wav(self, wav_path): 82 | return wavfile.read(wav_path)[1].astype(np.float32) / (2 ** 15) 83 | 84 | def get_data(self, clean_file, noise_file, noisy_file, mp4_file): 85 | noisy = self.load_wav(noisy_file) 86 | vr = VideoReader(mp4_file, ctx=cpu(0)) 87 | if isfile(clean_file): 88 | clean = self.load_wav(clean_file) 89 | else: 90 | # clean file for test set is not available 91 | clean = np.zeros(noisy.shape) 92 | if self.clipped_batch: 93 | if clean.shape[0] > 48000: 94 | clip_idx = random.randint(0, clean.shape[0] - 48000) 95 | video_idx = int((clip_idx / 16000) * 25) 96 | clean = clean[clip_idx:clip_idx + 48000] 97 | noisy = noisy[clip_idx:clip_idx + 48000] 98 | else: 99 | video_idx = -1 100 | clean = np.pad(clean, pad_width=[0, 48000 - clean.shape[0]], mode="constant") 101 | noisy = np.pad(noisy, pad_width=[0, 48000 - noisy.shape[0]], mode="constant") 102 | if len(vr) < 75: 103 | frames = vr.get_batch(list(range(len(vr)))).asnumpy() 104 | else: 105 | max_idx = min(video_idx + 75, len(vr)) 106 | frames = vr.get_batch(list(range(video_idx, max_idx))).asnumpy() 107 | bg_frames = np.array( 108 | [cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))]).astype(np.float32) 109 | bg_frames /= 255.0 110 | if len(bg_frames) < 75: 111 | bg_frames = np.concatenate( 112 | (bg_frames, np.zeros((75 - len(bg_frames), img_height, img_width)).astype(bg_frames.dtype)), 113 | axis=0) 114 | else: 115 | frames = vr.get_batch(list(range(len(vr)))).asnumpy() 116 | bg_frames = np.array( 117 | [cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))]).astype(np.float32) 118 | bg_frames /= 255.0 119 | return noisy, clean, bg_frames[np.newaxis, ...] 
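# --- Illustrative note (added by the editor, not part of the original baseline code) ---
# The magic numbers in get_data() above encode a fixed 3-second training window and the
# alignment between the 16 kHz waveform and the 25 fps lip video:
#
#   48000 samples / 16000 Hz = 3.0 s of audio      75 frames / 25 fps = 3.0 s of video
#
# so a randomly chosen audio offset is mapped to the matching video frame index with
#
#   video_idx = int((clip_idx / 16000) * 25)
#
# e.g. clip_idx = 32000 (2.0 s into the utterance) gives video_idx = 50.
# ---------------------------------------------------------------------------------------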
120 | 121 | 122 | class AVSEDataModule(LightningDataModule): 123 | def __init__(self, batch_size=16): 124 | super(AVSEDataModule, self).__init__() 125 | self.train_dataset_batch = AVSEDataset(join(DATA_ROOT, "train/scenes")) 126 | self.dev_dataset_batch = AVSEDataset(join(DATA_ROOT, "dev/scenes")) 127 | self.dev_dataset = AVSEDataset(join(DATA_ROOT, "dev/scenes"), clipped_batch=False, sample_items=False) 128 | self.eval_dataset = AVSEDataset(join(DATA_ROOT, "eval/scenes"), clipped_batch=False, sample_items=False, 129 | test_set=True) 130 | self.batch_size = batch_size 131 | 132 | def train_dataloader(self): 133 | assert len(self.train_dataset_batch) > 0, "No training data found" 134 | return torch.utils.data.DataLoader(self.train_dataset_batch, batch_size=self.batch_size, num_workers=4, 135 | pin_memory=True, persistent_workers=True) 136 | 137 | def val_dataloader(self): 138 | assert len(self.dev_dataset_batch) > 0, "No validation data found" 139 | return torch.utils.data.DataLoader(self.dev_dataset_batch, batch_size=self.batch_size, num_workers=4, 140 | pin_memory=True, 141 | persistent_workers=True) 142 | 143 | def test_dataloader(self): 144 | return torch.utils.data.DataLoader(self.dev_dataset, batch_size=self.batch_size, num_workers=4) 145 | 146 | 147 | if __name__ == '__main__': 148 | 149 | dataset = AVSEDataModule(batch_size=1).train_dataset_batch 150 | for i in tqdm(range(len(dataset)), ascii=True): 151 | data = dataset[i] 152 | for k, v in data.items(): 153 | print(k, v) 154 | break -------------------------------------------------------------------------------- /baseline/avse2/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from pytorch_lightning import LightningModule 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | 11 | from utils.dnn import BasicBlock, ResNet, Swish, cal_si_snr 12 | 13 | 14 | class AudioEncoder(nn.Module): 15 | def __init__(self, kernel_size=2, out_channels=64): 16 | super(AudioEncoder, self).__init__() 17 | self.conv1d = nn.Conv1d(in_channels=1, out_channels=out_channels, 18 | kernel_size=kernel_size, stride=kernel_size // 2, groups=1, bias=False) 19 | 20 | def forward(self, x): 21 | x = torch.unsqueeze(x, dim=1) 22 | x = self.conv1d(x) 23 | x = F.relu(x) 24 | return x 25 | 26 | 27 | class AudioDecoder(nn.ConvTranspose1d): 28 | def __init__(self, *args, **kwargs): 29 | super(AudioDecoder, self).__init__(*args, **kwargs) 30 | 31 | def forward(self, x): 32 | x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) 33 | if torch.squeeze(x).dim() == 1: 34 | x = torch.squeeze(x, dim=1) 35 | else: 36 | x = torch.squeeze(x) 37 | return x 38 | 39 | 40 | class VisualFeatNet(nn.Module): 41 | def __init__(self, relu_type='swish'): 42 | super(VisualFeatNet, self).__init__() 43 | self.frontend_nout = 64 44 | self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type) 45 | if relu_type == 'relu': 46 | frontend_relu = nn.ReLU(True) 47 | elif relu_type == 'prelu': 48 | frontend_relu = nn.PReLU(self.frontend_nout) 49 | elif relu_type == 'swish': 50 | frontend_relu = Swish() 51 | self.frontend3D = nn.Sequential( 52 | nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), 53 | nn.BatchNorm3d(self.frontend_nout), 54 | frontend_relu, 55 | nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))) 56 | 
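        # --- Illustrative shape trace (added by the editor, not part of the original code) ---
        # Assuming the default 224x224 grayscale lip frames from config.py, forward() below
        # processes an input of shape (B, 1, T, 224, 224) as:
        #   frontend3D (Conv3d, spatial stride 2 + MaxPool3d, spatial stride 2) -> (B, 64, T, 56, 56)
        #   reshape to (B*T, 64, 56, 56) and run the 2-D ResNet trunk           -> (B*T, 512)
        #   view back to (B, T, 512) and project with nn_out                    -> (B, T, 256)
        # --------------------------------------------------------------------------------------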
57 | self.nn_out = nn.Linear(512, 256, bias=False) 58 | torch.nn.init.xavier_uniform_(self.nn_out.weight) 59 | self._initialize_weights_randomly() 60 | 61 | def forward(self, x): 62 | B, C, T, H, W = x.size() 63 | x = self.frontend3D(x) 64 | Tnew = x.shape[2] 65 | n_batch, n_channels, s_time, sx, sy = x.shape 66 | x = x.transpose(1, 2).reshape(n_batch * s_time, n_channels, sx, sy) 67 | x = self.trunk(x) 68 | x = x.view(B, Tnew, x.size(1)) 69 | return torch.relu(self.nn_out(x)) 70 | 71 | def _initialize_weights_randomly(self): 72 | f = lambda n: math.sqrt(2.0 / float(n)) 73 | for m in self.modules(): 74 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 75 | n = np.prod(m.kernel_size) * m.out_channels 76 | m.weight.data.normal_(0, f(n)) 77 | if m.bias is not None: 78 | m.bias.data.zero_() 79 | 80 | elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | 84 | elif isinstance(m, nn.Linear): 85 | n = float(m.weight.data[0].nelement()) 86 | m.weight.data = m.weight.data.normal_(0, f(n)) 87 | 88 | 89 | class SeparatorBlock(nn.Module): 90 | def __init__(self, out_channels, hidden_channels, dropout=0, bidirectional=False): 91 | super(SeparatorBlock, self).__init__() 92 | self.intra_rnn = nn.LSTM(out_channels, hidden_channels, 1, batch_first=True, dropout=dropout, 93 | bidirectional=bidirectional) 94 | self.inter_rnn = nn.LSTM(out_channels, hidden_channels, 1, batch_first=True, dropout=dropout, 95 | bidirectional=bidirectional) 96 | self.intra_norm = nn.GroupNorm(1, out_channels, eps=1e-8) 97 | self.inter_norm = nn.GroupNorm(1, out_channels, eps=1e-8) 98 | self.intra_linear = nn.Linear(hidden_channels * 2 if bidirectional else hidden_channels, out_channels) 99 | self.inter_linear = nn.Linear(hidden_channels * 2 if bidirectional else hidden_channels, out_channels) 100 | 101 | def forward(self, x): 102 | B, N, K, S = x.shape 103 | intra_rnn = x.permute(0, 3, 2, 1).contiguous().view(B * S, K, N) 104 | intra_rnn, _ = self.intra_rnn(intra_rnn) 105 | intra_rnn = self.intra_linear(intra_rnn.contiguous().view(B * S * K, -1)).view(B * S, K, -1) 106 | intra_rnn = intra_rnn.view(B, S, K, N) 107 | intra_rnn = intra_rnn.permute(0, 3, 2, 1).contiguous() 108 | intra_rnn = self.intra_norm(intra_rnn) 109 | intra_rnn = intra_rnn + x 110 | inter_rnn = intra_rnn.permute(0, 2, 3, 1).contiguous().view(B * K, S, N) 111 | inter_rnn, _ = self.inter_rnn(inter_rnn) 112 | inter_rnn = self.inter_linear(inter_rnn.contiguous().view(B * S * K, -1)).view(B * K, S, -1) 113 | inter_rnn = inter_rnn.view(B, K, S, N) 114 | inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous() 115 | inter_rnn = self.inter_norm(inter_rnn) 116 | out = inter_rnn + intra_rnn 117 | return out 118 | 119 | 120 | class Separator(nn.Module): 121 | def __init__(self, in_channels, out_channels, hidden_channels, dropout=0, 122 | bidirectional=False, num_layers=4, K=200): 123 | super(Separator, self).__init__() 124 | self.K = K 125 | self.num_layers = num_layers 126 | self.input_conv = nn.Sequential(nn.GroupNorm(1, in_channels, eps=1e-8), 127 | nn.Conv1d(in_channels, out_channels, 1, bias=False)) 128 | self.separator_blocks = nn.Sequential(*[SeparatorBlock(out_channels, hidden_channels, dropout=dropout, 129 | bidirectional=bidirectional) for _ in range(num_layers)]) 130 | self.conv2d = nn.Conv2d(out_channels, out_channels, kernel_size=1) 131 | self.end_conv1x1 = nn.Conv1d(out_channels, 256, 1, bias=False) 132 | self.prelu = 
nn.PReLU() 133 | self.activation = nn.ReLU() 134 | self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()) 135 | self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()) 136 | 137 | def forward(self, x): 138 | x = self.input_conv(x) 139 | x, gap = self._segment(x, self.K) 140 | x = self.separator_blocks(x) 141 | x = self.prelu(x) 142 | x = self.conv2d(x) 143 | B, _, K, S = x.shape 144 | x = x.view(B, -1, K, S) 145 | x = self._over_add(x, gap) 146 | x = self.output(x) * self.output_gate(x) 147 | x = self.end_conv1x1(x) 148 | _, N, L = x.shape 149 | x = x.view(B, -1, N, L) 150 | x = self.activation(x) 151 | return x.transpose(0, 1)[0] 152 | 153 | def _padding(self, input, K): 154 | B, N, L = input.shape 155 | P = K // 2 156 | gap = K - (P + L % K) % K 157 | if gap > 0: 158 | pad = torch.Tensor(torch.zeros(B, N, gap)).type(input.type()) 159 | input = torch.cat([input, pad], dim=2) 160 | 161 | _pad = torch.Tensor(torch.zeros(B, N, P)).type(input.type()) 162 | input = torch.cat([_pad, input, _pad], dim=2) 163 | return input, gap 164 | 165 | def _segment(self, input, K): 166 | B, N, L = input.shape 167 | P = K // 2 168 | input, gap = self._padding(input, K) 169 | input1 = input[:, :, :-P].contiguous().view(B, N, -1, K) 170 | input2 = input[:, :, P:].contiguous().view(B, N, -1, K) 171 | input = torch.cat([input1, input2], dim=3).view( 172 | B, N, -1, K).transpose(2, 3) 173 | return input.contiguous(), gap 174 | 175 | def _over_add(self, input, gap): 176 | B, N, K, S = input.shape 177 | P = K // 2 178 | input = input.transpose(2, 3).contiguous().view(B, N, -1, K * 2) 179 | input1 = input[:, :, :, :K].contiguous().view(B, N, -1)[:, :, P:] 180 | input2 = input[:, :, :, K:].contiguous().view(B, N, -1)[:, :, :-P] 181 | input = input1 + input2 182 | if gap > 0: 183 | input = input[:, :, :-gap] 184 | return input 185 | 186 | 187 | class AVSE(nn.Module): 188 | def __init__(self): 189 | super(AVSE, self).__init__() 190 | self.audio_encoder = AudioEncoder(kernel_size=16, out_channels=256) 191 | self.audio_decoder = AudioDecoder(in_channels=256, out_channels=1, kernel_size=16, stride=8, bias=False) 192 | self.visual_encoder = VisualFeatNet() 193 | self.separator = Separator(512, 64, 128, num_layers=6, bidirectional=True) 194 | 195 | def forward(self, input): 196 | noisy = input["noisy_audio"] 197 | encoded_audio = self.audio_encoder(noisy) 198 | video_frames = input["video_frames"] 199 | encoded_visual = self.visual_encoder(video_frames) 200 | _, _, time_steps = encoded_audio.shape 201 | _, _, vis_feat_size = encoded_visual.shape 202 | upsampled_visual_feat = F.interpolate(encoded_visual.unsqueeze(1), size=(time_steps, vis_feat_size), 203 | mode="bilinear").reshape(-1, time_steps, vis_feat_size).moveaxis(1, 2) 204 | encoded_av = torch.cat((upsampled_visual_feat, encoded_audio), dim=-2) 205 | mask = self.separator(encoded_av) 206 | out = mask * encoded_audio 207 | audio = self.audio_decoder(out) 208 | return audio 209 | 210 | 211 | class AVSEModule(LightningModule): 212 | def __init__(self, lr=0.00015, val_dataset=None): 213 | super(AVSEModule, self).__init__() 214 | self.lr = lr 215 | self.val_dataset = val_dataset 216 | self.loss = cal_si_snr 217 | self.model = AVSE() 218 | 219 | def forward(self, data): 220 | """ Processes the input tensor x and returns an output tensor.""" 221 | est_source = self.model(data) 222 | return est_source 223 | 224 | def training_step(self, batch_inp, batch_idx): 225 | loss = self.cal_loss(batch_inp) 226 | 
self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 227 | return loss 228 | 229 | def validation_step(self, batch_inp, batch_idx): 230 | loss = self.cal_loss(batch_inp) 231 | self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 232 | return loss 233 | 234 | def enhance(self, data): 235 | inputs = dict(noisy_audio=torch.from_numpy(data["noisy_audio"][np.newaxis, ...]).to(self.device), 236 | video_frames=torch.from_numpy(data["video_frames"][np.newaxis, ...]).to(self.device)) 237 | estimated_audio = self(inputs).cpu().numpy() 238 | estimated_audio /= np.max(np.abs(estimated_audio)) 239 | return estimated_audio 240 | 241 | def training_epoch_end(self, outputs): 242 | if self.val_dataset is not None: 243 | with torch.no_grad(): 244 | tensorboard = self.logger.experiment 245 | for index in range(5): 246 | rand_int = random.randint(0, len(self.val_dataset)) 247 | data = self.val_dataset[rand_int] 248 | estimated_audio = self.enhance(data) 249 | tensorboard.add_audio("{}/{}_clean".format(self.current_epoch, index), 250 | data["clean"][np.newaxis, ...], 251 | sample_rate=16000) 252 | tensorboard.add_audio("{}/{}_noisy".format(self.current_epoch, index), 253 | data["noisy_audio"][np.newaxis, ...], 254 | sample_rate=16000) 255 | tensorboard.add_audio("{}/{}_enhanced".format(self.current_epoch, index), 256 | estimated_audio.reshape(-1)[np.newaxis, ...], 257 | sample_rate=16000) 258 | 259 | def cal_loss(self, batch_inp): 260 | mask = batch_inp["clean"].T 261 | pred_mask = self(batch_inp).T.reshape(mask.shape) 262 | loss = self.loss(pred_mask.unsqueeze(2), mask.unsqueeze(2)) 263 | loss[loss < -30] = -30 264 | return torch.mean(loss) 265 | 266 | def configure_optimizers(self): 267 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 268 | return { 269 | "optimizer": optimizer, 270 | "lr_scheduler": { 271 | "scheduler": ReduceLROnPlateau(optimizer, factor=0.8, patience=5), 272 | "monitor": "val_loss_epoch", 273 | }, 274 | } 275 | -------------------------------------------------------------------------------- /baseline/avse2/requirements.txt: -------------------------------------------------------------------------------- 1 | decord==0.6.0 2 | numpy==1.23.5 3 | opencv_contrib_python==4.7.0.68 4 | opencv_python==4.7.0.68 5 | opencv_python_headless==4.7.0.68 6 | pytorch_lightning==1.8.0 7 | scipy==1.7.1 8 | SoundFile==0.10.3.post1 9 | torch==2.0.0 10 | tqdm==4.62.3 11 | -------------------------------------------------------------------------------- /baseline/avse2/test.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os.path import isfile 3 | from os import makedirs 4 | from os.path import join 5 | 6 | import soundfile as sf 7 | import torch 8 | from tqdm import tqdm 9 | 10 | 11 | from config import sampling_rate 12 | 13 | from dataset import AVSEDataModule 14 | from model import AVSEModule 15 | from utils import str2bool 16 | 17 | 18 | def main(args): 19 | enhanced_root = join(args.save_root, args.model_uid) 20 | makedirs(args.save_root, exist_ok=True) 21 | makedirs(enhanced_root, exist_ok=True) 22 | datamodule = AVSEDataModule(batch_size=1) 23 | if args.dev_set and args.eval_set: 24 | raise RuntimeError("Select either dev set or test set") 25 | elif args.dev_set: 26 | dataset = datamodule.dev_dataset 27 | elif args.eval_set: 28 | dataset = datamodule.eval_dataset 29 | else: 30 | raise RuntimeError("Select one of dev set and test set") 31 | try: 32 | 
model = AVSEModule.load_from_checkpoint(args.ckpt_path) 33 | print("Model loaded") 34 | except Exception as e: 35 | raise FileNotFoundError("Cannot load model weights: {}".format(args.ckpt_path)) 36 | if not args.cpu: 37 | model.to("cuda:0") 38 | model.eval() 39 | with torch.no_grad(): 40 | for i in tqdm(range(len(dataset))): 41 | data = dataset[i] 42 | filename = f"{data['scene']}.wav" 43 | enhanced_path = join(enhanced_root, filename) 44 | if not isfile(enhanced_path): 45 | estimated_audio = model.enhance(data).reshape(-1) 46 | sf.write(enhanced_path, estimated_audio, samplerate=sampling_rate) 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = ArgumentParser() 51 | parser.add_argument("--ckpt_path", type=str, required=True, help="Path to model checkpoint") 52 | parser.add_argument("--save_root", type=str, required=True, help="Path to save enhanced audio") 53 | parser.add_argument("--model_uid", type=str, required=True, help="Folder name to save enhanced audio") 54 | parser.add_argument("--dev_set", type=str2bool, default=False, help="Evaluate model on dev set") 55 | parser.add_argument("--eval_set", type=str2bool, default=False, help="Evaluate model on eval set") 56 | parser.add_argument("--cpu", type=str2bool, required=False, help="Evaluate model on CPU") 57 | args = parser.parse_args() 58 | main(args) 59 | -------------------------------------------------------------------------------- /baseline/avse2/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from config import SEED 4 | # fix random seeds for reproducibility 5 | torch.manual_seed(SEED) 6 | torch.backends.cudnn.deterministic = False 7 | torch.backends.cudnn.benchmark = True 8 | np.random.seed(SEED) 9 | 10 | from argparse import ArgumentParser 11 | from pytorch_lightning import Trainer 12 | from pytorch_lightning.callbacks import ModelCheckpoint 13 | 14 | from dataset import AVSEDataModule 15 | from model import AVSEModule 16 | 17 | def main(args): 18 | checkpoint_callback = ModelCheckpoint(monitor="val_loss_epoch") 19 | datamodule = AVSEDataModule(batch_size=args.batch_size) 20 | model = AVSEModule(val_dataset=datamodule.dev_dataset, lr=args.lr) 21 | trainer = Trainer.from_argparse_args(args, default_root_dir=args.log_dir, callbacks=[checkpoint_callback], 22 | accelerator="gpu", devices=1, max_epochs=50) 23 | trainer.fit(model, datamodule) 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = ArgumentParser() 28 | parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training") 29 | parser.add_argument("--lr", type=float, default=0.0003, help="Learning rate for training") 30 | parser.add_argument("--log_dir", type=str, required=True, help="Path to save tensorboard logs") 31 | parser = Trainer.add_argparse_args(parser) 32 | args = parser.parse_args() 33 | main(args) 34 | -------------------------------------------------------------------------------- /baseline/avse2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | 5 | def subsample_list(inp_list: list, sample_rate: float): 6 | random.shuffle(inp_list) 7 | return [inp_list[i] for i in range(int(len(inp_list) * sample_rate))] 8 | 9 | 10 | def str2bool(v: str): 11 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 12 | return True 13 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value 
expected.') 17 | -------------------------------------------------------------------------------- /baseline/avse2/utils/dnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def get_mask(source, source_lengths): 8 | mask = source.new_ones(source.size()[:-1]).unsqueeze(-1).transpose(1, -2) 9 | B = source.size(-2) 10 | for i in range(B): 11 | mask[source_lengths[i]:, i] = 0 12 | return mask.transpose(-2, 1) 13 | 14 | 15 | def cal_si_snr(source, estimate_source): 16 | EPS = 1e-8 17 | assert source.size() == estimate_source.size() 18 | device = estimate_source.device.type 19 | 20 | source_lengths = torch.tensor( 21 | [estimate_source.shape[0]] * estimate_source.shape[-2], device=device 22 | ) 23 | mask = get_mask(source, source_lengths) 24 | estimate_source *= mask 25 | 26 | num_samples = ( 27 | source_lengths.contiguous().reshape(1, -1, 1).float() 28 | ) # [1, B, 1] 29 | mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples 30 | mean_estimate = ( 31 | torch.sum(estimate_source, dim=0, keepdim=True) / num_samples 32 | ) 33 | zero_mean_target = source - mean_target 34 | zero_mean_estimate = estimate_source - mean_estimate 35 | # mask padding position along T 36 | zero_mean_target *= mask 37 | zero_mean_estimate *= mask 38 | 39 | # Step 2. SI-SNR with PIT 40 | # reshape to use broadcast 41 | s_target = zero_mean_target # [T, B, C] 42 | s_estimate = zero_mean_estimate # [T, B, C] 43 | # s_target = s / ||s||^2 44 | dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True) # [1, B, C] 45 | s_target_energy = ( 46 | torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS 47 | ) # [1, B, C] 48 | proj = dot * s_target / s_target_energy # [T, B, C] 49 | # e_noise = s' - s_target 50 | e_noise = s_estimate - proj # [T, B, C] 51 | # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2) 52 | si_snr_beforelog = torch.sum(proj ** 2, dim=0) / ( 53 | torch.sum(e_noise ** 2, dim=0) + EPS 54 | ) 55 | si_snr = 10 * torch.log10(si_snr_beforelog + EPS) # [B, C] 56 | 57 | return -si_snr.unsqueeze(0) 58 | 59 | 60 | 61 | class Swish(nn.Module): 62 | def forward(self, x): 63 | return x * torch.sigmoid(x) 64 | 65 | 66 | def conv3x3(in_planes, out_planes, stride=1): 67 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | 70 | 71 | def downsample_basic_block(inplanes, outplanes, stride): 72 | return nn.Sequential( 73 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False), 74 | nn.BatchNorm2d(outplanes), 75 | ) 76 | 77 | 78 | def downsample_basic_block_v2(inplanes, outplanes, stride): 79 | return nn.Sequential( 80 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), 81 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), 82 | nn.BatchNorm2d(outplanes), 83 | ) 84 | 85 | 86 | class BasicBlock(nn.Module): 87 | expansion = 1 88 | 89 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='prelu'): 90 | super(BasicBlock, self).__init__() 91 | 92 | assert relu_type in ['relu', 'prelu', 'swish'] 93 | 94 | self.conv1 = conv3x3(inplanes, planes, stride) 95 | self.bn1 = nn.BatchNorm2d(planes) 96 | 97 | # type of ReLU is an input option 98 | if relu_type == 'relu': 99 | self.relu1 = nn.ReLU(inplace=True) 100 | self.relu2 = nn.ReLU(inplace=True) 101 | elif relu_type == 'prelu': 102 | self.relu1 = nn.PReLU(num_parameters=planes) 103 | self.relu2 = 
nn.PReLU(num_parameters=planes) 104 | elif relu_type == 'swish': 105 | self.relu1 = Swish() 106 | self.relu2 = Swish() 107 | else: 108 | raise Exception('relu type not implemented') 109 | # -------- 110 | 111 | self.conv2 = conv3x3(planes, planes) 112 | self.bn2 = nn.BatchNorm2d(planes) 113 | 114 | self.downsample = downsample 115 | self.stride = stride 116 | 117 | def forward(self, x): 118 | residual = x 119 | out = self.conv1(x) 120 | out = self.bn1(out) 121 | out = self.relu1(out) 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | if self.downsample is not None: 125 | residual = self.downsample(x) 126 | 127 | out += residual 128 | out = self.relu2(out) 129 | 130 | return out 131 | 132 | 133 | class ResNet(nn.Module): 134 | 135 | def __init__(self, block, layers, relu_type='relu', gamma_zero=False, avg_pool_downsample=False): 136 | self.inplanes = 64 137 | self.relu_type = relu_type 138 | self.gamma_zero = gamma_zero 139 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block 140 | 141 | super(ResNet, self).__init__() 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 146 | self.avgpool = nn.AdaptiveAvgPool2d(1) 147 | 148 | for m in self.modules(): 149 | if isinstance(m, nn.Conv2d): 150 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 151 | m.weight.data.normal_(0, math.sqrt(2. / n)) 152 | elif isinstance(m, nn.BatchNorm2d): 153 | m.weight.data.fill_(1) 154 | m.bias.data.zero_() 155 | 156 | if self.gamma_zero: 157 | for m in self.modules(): 158 | if isinstance(m, BasicBlock): 159 | m.bn2.weight.data.zero_() 160 | 161 | def _make_layer(self, block, planes, blocks, stride=1): 162 | 163 | downsample = None 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = self.downsample_block(inplanes=self.inplanes, 166 | outplanes=planes * block.expansion, 167 | stride=stride) 168 | 169 | layers = [] 170 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)) 171 | self.inplanes = planes * block.expansion 172 | for i in range(1, blocks): 173 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type)) 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def forward(self, x): 178 | x = self.layer1(x) 179 | x = self.layer2(x) 180 | x = self.layer3(x) 181 | x = self.layer4(x) 182 | x = self.avgpool(x) 183 | x = x.view(x.size(0), -1) 184 | return x 185 | -------------------------------------------------------------------------------- /baseline/avse3/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Baseline model for the 3rd COG-MHEAR Audio-Visual Speech Enhancement Challenge 3 | 4 | [![Challenge Registration](https://img.shields.io/badge/Challenge-%20Registration-blue.svg)](https://challenge.cogmhear.org/#/) 5 | [![Try In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17EEK6Q5hbCwf1rNwaZAytdAiC5aK32vI?usp=sharing) 6 | 7 | 8 | ## Requirements 9 | * Python 3.6+ 10 | * [PyTorch 2.0+](https://pytorch.org/get-started/locally/) or [Tensorflow 2.0+](https://www.tensorflow.org/install) 11 | * [Keras 3.0+](https://keras.io/getting_started/) 12 | * [Decord](https://github.com/dmlc/decord) 13 | * [Librosa](https://librosa.org/doc/main/install.html) 
14 | * [OpenCV](https://pypi.org/project/opencv-python/) 15 | * [Numpy](https://numpy.org/install/) 16 | * [Soundfile](https://pypi.org/project/SoundFile/) 17 | * [TQDM](https://pypi.org/project/tqdm/) 18 | 19 | ## Usage 20 | 21 | ```text 22 | # Expected directory structure 23 | avsec3_data_root 24 | ├── dev 25 | │ ├── lips 26 | │ │ └── S37890_silent.mp4 27 | │ └── scenes 28 | │ ├── S37890_interferer.wav 29 | │ ├── S37890_mixed.wav 30 | │ ├── S37890_silent.mp4 31 | │ └── S37890_target.wav 32 | └── train 33 | ├── lips 34 | │ └── S34526_silent.mp4 35 | └── scenes 36 | ├── S34526_interferer.wav 37 | ├── S34526_mixed.wav 38 | ├── S34526_silent.mp4 39 | └── S34526_target.wav 40 | ``` 41 | - Change KERAS_BACKEND to 'torch' or 'tensorflow' in `config.py` based on the backend you are using. 42 | ### Train 43 | To train the model, run the following command: 44 | ```bash 45 | python train.py --data_root 46 | ``` 47 | where `` is the name of the model to be trained, `` is the path to the training data, `` is the path to the validation data, and `` is the path to save the trained model. 48 | 49 | ### Test 50 | To test the model, run the following command: 51 | ```bash 52 | python test.py --data_root --weight_path --save_root 53 | ``` 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /baseline/avse3/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KERAS_BACKEND"] = "tensorflow" # "torch" 3 | from scipy import signal 4 | SEED = 42 5 | stft_size = 512 6 | window_size = 400 7 | window_shift = 160 8 | sampling_rate = 16000 9 | windows = signal.windows.hann 10 | max_audio_length = 40800 11 | max_video_length = 64 12 | video_frame_size = (88, 88) -------------------------------------------------------------------------------- /baseline/avse3/dataset.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | import logging 3 | import random 4 | from os.path import join, isfile 5 | 6 | import cv2 7 | import librosa 8 | import numpy as np 9 | import torch 10 | from decord import VideoReader 11 | from decord import cpu 12 | from scipy.io import wavfile 13 | from torch.utils.data import Dataset 14 | from tqdm import tqdm 15 | 16 | 17 | def subsample_list(inp_list: list, sample_rate: float): 18 | random.shuffle(inp_list) 19 | return [inp_list[i] for i in range(int(len(inp_list) * sample_rate))] 20 | 21 | 22 | class AVSEDataset(Dataset): 23 | def __init__(self, files_list, shuffle=True, seed=SEED, subsample=1, 24 | clipped_batch=True, sample_items=True, time_domain=False): 25 | super(AVSEDataset, self).__init__() 26 | self.time_domain = time_domain 27 | self.clipped_batch = clipped_batch 28 | self.files_list = files_list 29 | if shuffle: 30 | random.seed(SEED) 31 | random.shuffle(self.files_list) 32 | if subsample != 1: 33 | self.files_list = subsample_list(self.files_list, sample_rate=subsample) 34 | logging.info("Found {} utterances".format(len(self.files_list))) 35 | self.data_count = len(self.files_list) 36 | self.batch_index = 0 37 | self.total_batches_seen = 0 38 | self.batch_input = {"noisy": None} 39 | self.index = 0 40 | self.max_len = len(self.files_list) 41 | self.max_cache = 0 42 | self.seed = seed 43 | self.window = "hann" 44 | self.fading = False 45 | self.sample_items = sample_items 46 | 47 | def __len__(self): 48 | return len(self.files_list) 49 | 50 | def __getitem__(self, idx): 51 | while True: 52 | try: 53 | data = {} 54 
| if self.sample_items: 55 | clean_file, noise_file, noisy_file, mp4_file, scene_id = random.sample(self.files_list, 1)[0] 56 | else: 57 | clean_file, noise_file, noisy_file, mp4_file, scene_id = self.files_list[idx] 58 | data["noisy_stft"] = self.get_stft(self.load_wav(noisy_file)).T 59 | data["clean"] = self.load_wav(clean_file) 60 | data["scene"] = scene_id 61 | 62 | data["noisy_audio"], clean_audio, data["video_frames"] = self.get_data(clean_file, noise_file, 63 | noisy_file, mp4_file) 64 | return data, clean_audio 65 | except Exception as e: 66 | logging.error("Error in loading data: {}".format(e)) 67 | 68 | def load_wav(self, wav_path): 69 | return wavfile.read(wav_path)[1].astype(np.float32) / (2 ** 15) 70 | 71 | def get_stft(self, audio): 72 | return librosa.stft(audio, win_length=window_size, n_fft=stft_size, hop_length=window_shift, window=self.window, 73 | center=True) 74 | 75 | def get_audio_features(self, audio): 76 | return np.abs(self.get_stft(audio)).transpose(1, 0).astype(np.float32) 77 | 78 | def get_data(self, clean_file, noise_file, noisy_file, mp4_file): 79 | noisy = self.load_wav(noisy_file) 80 | vr = VideoReader(mp4_file, ctx=cpu(0)) 81 | if isfile(clean_file): 82 | clean = self.load_wav(clean_file) 83 | else: 84 | # clean file for test set is not available 85 | clean = np.zeros(noisy.shape) 86 | if self.clipped_batch: 87 | if clean.shape[0] > max_audio_length: 88 | clip_idx = random.randint(0, clean.shape[0] - max_audio_length) 89 | video_idx = int((clip_idx / 16000) * 25) 90 | clean = clean[clip_idx:clip_idx + max_audio_length] 91 | noisy = noisy[clip_idx:clip_idx + max_audio_length] 92 | else: 93 | video_idx = -1 94 | clean = np.pad(clean, pad_width=[0, max_audio_length - clean.shape[0]], mode="constant") 95 | noisy = np.pad(noisy, pad_width=[0, max_audio_length - noisy.shape[0]], mode="constant") 96 | if len(vr) < max_video_length: 97 | frames = vr.get_batch(list(range(len(vr)))).asnumpy() 98 | else: 99 | max_idx = min(video_idx + max_video_length, len(vr)) 100 | frames = vr.get_batch(list(range(video_idx, max_idx))).asnumpy() 101 | bg_frames = [cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))] 102 | bg_frames = np.array([cv2.resize(bg_frames[i], video_frame_size) for i in range(len(bg_frames))]).astype( 103 | np.float32) 104 | bg_frames /= 255.0 105 | if len(bg_frames) < max_video_length: 106 | bg_frames = np.concatenate( 107 | (bg_frames, 108 | np.zeros((max_video_length - len(bg_frames), video_frame_size[0], video_frame_size[1])).astype(bg_frames.dtype)), 109 | axis=0) 110 | else: 111 | frames = vr.get_batch(list(range(len(vr)))).asnumpy() 112 | bg_frames = np.array( 113 | [cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))]).astype(np.float32) 114 | bg_frames = np.array([cv2.resize(bg_frames[i], video_frame_size) for i in range(len(bg_frames))]).astype( 115 | np.float32) 116 | 117 | bg_frames /= 255.0 118 | if self.time_domain: 119 | return noisy, clean, bg_frames[..., np.newaxis] 120 | return self.get_audio_features(noisy)[..., np.newaxis], self.get_audio_features(clean), bg_frames[..., np.newaxis] 121 | 122 | 123 | class AVSEChallengeDataModule: 124 | def __init__(self, data_root, batch_size=4, time_domain=False): 125 | super(AVSEChallengeDataModule, self).__init__() 126 | self.train_dataset_batch = AVSEDataset(self.get_files_list(join(data_root, "dev")), time_domain=time_domain) 127 | self.dev_dataset_batch = AVSEDataset(self.get_files_list(join(data_root, "dev")), time_domain=time_domain) 128 | self.dev_dataset 
= AVSEDataset(self.get_files_list(join(data_root, "dev")), 129 | clipped_batch=False, sample_items=False, time_domain=time_domain) 130 | # !TODO Uncomment this for test set 131 | # self.test_dataset = AVSEDataset(self.get_files_list(join(data_root, "eval"), test_set=True), sample_items=False, 132 | # clipped_batch=False, time_domain=time_domain) 133 | self.batch_size = batch_size 134 | 135 | @staticmethod 136 | def get_files_list(data_root, test_set=False): 137 | files_list = [] 138 | for file in os.listdir(join(data_root, "scenes")): 139 | if file.endswith("mixed.wav"): 140 | files = (join(data_root, "scenes", file.replace("mixed", "target")), 141 | join(data_root, "scenes", file.replace("mixed", "interferer")), 142 | join(data_root, "scenes", file), 143 | join(data_root, "lips", file.replace("_mixed.wav", "_silent.mp4")), 144 | file.replace("_mixed.wav", "") 145 | ) 146 | if not test_set: 147 | if all([isfile(f) for f in files[:-1]]): 148 | files_list.append(files) 149 | else: 150 | files_list.append(files) 151 | return files_list 152 | 153 | def train_dataloader(self): 154 | assert len(self.train_dataset_batch) > 0, "No training data found" 155 | return torch.utils.data.DataLoader(self.train_dataset_batch, batch_size=self.batch_size, num_workers=4, 156 | pin_memory=True, persistent_workers=True) 157 | 158 | def val_dataloader(self): 159 | assert len(self.dev_dataset_batch) > 0, "No validation data found" 160 | return torch.utils.data.DataLoader(self.dev_dataset_batch, batch_size=self.batch_size, num_workers=4, 161 | pin_memory=True, 162 | persistent_workers=True) 163 | -------------------------------------------------------------------------------- /baseline/avse3/loss.py: -------------------------------------------------------------------------------- 1 | import keras.ops as ops 2 | 3 | 4 | def l2_norm(s1, s2): 5 | norm = ops.sum(s1 * s2, -1, keepdims=True) 6 | return norm 7 | 8 | 9 | def si_snr_loss(s1, s2, eps=1e-8): 10 | s1_s2_norm = l2_norm(s1, s2) 11 | s2_s2_norm = l2_norm(s2, s2) 12 | s_target = s1_s2_norm / (s2_s2_norm + eps) * s2 13 | e_nosie = ops.convert_to_tensor(s1) - s_target 14 | target_norm = l2_norm(s_target, s_target) 15 | noise_norm = l2_norm(e_nosie, e_nosie) 16 | snr = 10 * ops.log10((target_norm) / (noise_norm + eps) + eps) 17 | return -1 * ops.mean(snr) -------------------------------------------------------------------------------- /baseline/avse3/model.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | from keras import Input 3 | from model_utils import * 4 | 5 | 6 | @keras.saving.register_keras_serializable(name="VisualFeatNet") 7 | class VisualFeatNet(Layer): 8 | def __init__(self, tcn_options=None, hidden_dim=256): 9 | super().__init__(name="visual_feat_extract") 10 | if tcn_options is None: 11 | self.tcn_options = dict(num_layers=4, kernel_size=[3], dropout=0.2, width_mult=2) 12 | self.frontend_nout = 64 13 | self.backend_out = 512 14 | self.hidden_dim = hidden_dim 15 | self.trunk = ResNet18() 16 | self.frontend3D = Sequential([ 17 | nn.Conv3D(self.frontend_nout, kernel_size=(5, 7, 7), strides=(1, 2, 2), padding="same", use_bias=False), 18 | nn.BatchNormalization(), 19 | nn.ReLU(), 20 | nn.MaxPool3D(pool_size=(1, 3, 3), strides=(1, 2, 2), padding="valid")]) 21 | self.tcn = TCN([self.hidden_dim * len(self.tcn_options['kernel_size']) * self.tcn_options['width_mult']] * 22 | self.tcn_options['num_layers'], 23 | self.tcn_options["kernel_size"], 24 | self.tcn_options["num_layers"], 25 | dilations=[1, 
2, 4, 8], return_sequences=True, activation="relu", use_batch_norm=True, 26 | padding="same", 27 | dropout_rate=self.tcn_options["dropout"]) 28 | 29 | def call(self, x): 30 | x = self.frontend3D(x) 31 | B, T, H, W, C = x.shape 32 | if B is None: 33 | B = 1 34 | x = ops.reshape(x, (-1, H, W, C)) 35 | x = self.trunk(x) 36 | x = ops.reshape(x, (B, T, -1)) 37 | x = self.tcn(x) 38 | return ops.reshape(x, (B, 1, T, -1)) 39 | 40 | def compute_output_shape(self, input_shape): 41 | return input_shape[0], 1, input_shape[1], 512 42 | 43 | 44 | @keras.saving.register_keras_serializable(name="UNet") 45 | class UNet(Layer): 46 | def __init__(self, filters=64, output_nc=2, av_embedding=1024, a_only=True, activation='sigmoid'): 47 | super().__init__(name="audio_separator") 48 | self.a_only = a_only 49 | self.filters = filters 50 | self.output_nc = output_nc 51 | self.av_embedding = av_embedding 52 | self.activation = activation 53 | self.conv1 = unet_conv(self.filters) 54 | self.conv2 = unet_conv(self.filters * 2) 55 | self.conv3 = conv_block(self.filters * 4) 56 | self.conv4 = conv_block(self.filters * 8) 57 | self.conv5 = conv_block(self.filters * 8) 58 | self.conv6 = conv_block(self.filters * 8) 59 | self.conv7 = conv_block(self.filters * 8) 60 | self.conv8 = conv_block(self.filters * 8) 61 | self.frequency_pool = nn.MaxPool2D([2, 1]) 62 | if not self.a_only: 63 | self.upconv1 = up_conv(self.filters, self.filters * 8) 64 | else: 65 | self.upconv1 = up_conv(self.filters * 8) 66 | self.upconv2 = up_conv(self.filters * 8) 67 | self.upconv3 = up_conv(self.filters * 8) 68 | self.upconv4 = up_conv(self.filters * 8) 69 | self.upconv5 = up_conv(self.filters * 4) 70 | self.upconv6 = up_conv(self.filters * 2) 71 | self.upconv7 = unet_upconv(self.filters) 72 | self.upconv8 = unet_upconv(self.output_nc, True) 73 | self.activation = nn.Activation(self.activation.lower()) 74 | 75 | def call(self, mix_spec, visual_feat=None): 76 | noisy_stft_real, noisy_stft_imag = ops.stft(mix_spec, sequence_length=window_size, sequence_stride=window_shift, 77 | fft_length=stft_size) 78 | noisy_stft_real = ops.expand_dims(noisy_stft_real, axis=-1) 79 | noisy_stft_imag = ops.expand_dims(noisy_stft_imag, axis=-1) 80 | noisy_stft = ops.concatenate((noisy_stft_real, noisy_stft_imag), axis=-1)#** 0.3 81 | feat, pads = pad(noisy_stft, 32) 82 | conv1feat = self.conv1(feat) 83 | conv2feat = self.conv2(conv1feat) 84 | conv3feat = self.conv3(conv2feat) 85 | conv3feat = self.frequency_pool(conv3feat) 86 | conv4feat = self.conv4(conv3feat) 87 | conv4feat = self.frequency_pool(conv4feat) 88 | conv5feat = self.conv5(conv4feat) 89 | conv5feat = self.frequency_pool(conv5feat) 90 | conv6feat = self.conv6(conv5feat) 91 | conv6feat = self.frequency_pool(conv6feat) 92 | conv7feat = self.conv7(conv6feat) 93 | conv7feat = self.frequency_pool(conv7feat) 94 | conv8feat = self.conv8(conv7feat) 95 | conv8feat = self.frequency_pool(conv8feat) 96 | if self.a_only: 97 | av_feat = conv8feat 98 | else: 99 | B, H, W, C = conv8feat.shape 100 | upsample_visuals = ops.image.resize(visual_feat, (H, W)) 101 | av_feat = ops.concatenate((conv8feat, upsample_visuals), axis=-1) 102 | upconv1feat = self.upconv1(av_feat) 103 | upconv2feat = self.upconv2(ops.concatenate((upconv1feat, conv7feat), axis=-1)) 104 | upconv3feat = self.upconv3(ops.concatenate((upconv2feat, conv6feat), axis=-1)) 105 | upconv4feat = self.upconv4(ops.concatenate((upconv3feat, conv5feat), axis=-1)) 106 | upconv5feat = self.upconv5(ops.concatenate((upconv4feat, conv4feat), axis=-1)) 107 | upconv6feat = 
self.upconv6(ops.concatenate((upconv5feat, conv3feat), axis=-1)) 108 | upconv7feat = self.upconv7(ops.concatenate((upconv6feat, conv2feat), axis=-1)) 109 | predicted_mask = self.upconv8(ops.concatenate((upconv7feat, conv1feat), axis=-1)) 110 | pred_mask = self.activation(predicted_mask) 111 | pred_mask = unpad(pred_mask, pads) 112 | enhanced_stft = ops.multiply(pred_mask, noisy_stft) #** (1/0.3) 113 | enhanced_audio = ops.istft((enhanced_stft[:, :, :, 0], enhanced_stft[:, :, :, 1]), 114 | sequence_length=window_size, sequence_stride=window_shift, 115 | fft_length=stft_size) 116 | return enhanced_audio 117 | 118 | def compute_output_shape(self, input_shape): 119 | return input_shape 120 | 121 | 122 | def AVSE(video_frames=64, audio_frames=40800, batch_size=1): 123 | visual_input = Input(batch_shape=(batch_size, video_frames, video_frame_size[0], video_frame_size[1], 1), 124 | name="video_frames") 125 | audio_input = Input(batch_shape=(batch_size, audio_frames), name="noisy_audio") 126 | visual_feat = VisualFeatNet()(visual_input) 127 | output = UNet(a_only=False)(audio_input, visual_feat) 128 | return Model(inputs=[audio_input, visual_input], outputs=output) 129 | 130 | 131 | if __name__ == '__main__': 132 | model = AVSE(batch_size=1) 133 | print(model.predict({"noisy_audio": ops.ones((1, 40800)), "video_frames": ops.ones((1, 64, 88, 88, 1))}).shape) 134 | model.summary() 135 | -------------------------------------------------------------------------------- /baseline/avse3/model_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from model_utils.nn import * 2 | from model_utils.visual import * 3 | from model_utils.generic import * -------------------------------------------------------------------------------- /baseline/avse3/model_utils/generic.py: -------------------------------------------------------------------------------- 1 | import keras.ops as ops 2 | 3 | 4 | def pad(x, stride): 5 | h, w = x.shape[1:3] 6 | if h % stride > 0: 7 | new_h = h + stride - h % stride 8 | else: 9 | new_h = h 10 | if w % stride > 0: 11 | new_w = w + stride - w % stride 12 | else: 13 | new_w = w 14 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) 15 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) 16 | pads = ([0, 0], [lh, uh], [lw, uw], [0, 0]) 17 | out = ops.pad(x, pads, "constant", 0) 18 | return out, pads 19 | 20 | 21 | def unpad(x, pad): 22 | [_, _], [lw, uw], [lh, uh], [_, _] = pad 23 | if lh + uh > 0: 24 | x = x[:, :, lh:-uh, :] 25 | if lw + uw > 0: 26 | x = x[:, lw:-uw, :, :] 27 | return x 28 | -------------------------------------------------------------------------------- /baseline/avse3/model_utils/nn.py: -------------------------------------------------------------------------------- 1 | import keras.ops as ops 2 | import keras.layers as nn 3 | from keras import Sequential 4 | 5 | 6 | def unet_conv(output_nc, norm_layer=nn.BatchNormalization): 7 | unet_conv = Sequential( 8 | [nn.Conv2D(output_nc, kernel_size=4, strides=(2, 2), padding="same"), 9 | norm_layer(), 10 | nn.LeakyReLU(0.2)] 11 | ) 12 | return unet_conv 13 | 14 | 15 | def unet_upconv(output_nc, outermost=False, norm_layer=nn.BatchNormalization, kernel_size=4): 16 | upconv = nn.Conv2DTranspose(output_nc, kernel_size=kernel_size, strides=2, padding="same") 17 | uprelu = nn.ReLU(True) 18 | upnorm = norm_layer() 19 | if not outermost: 20 | return Sequential([upconv, upnorm, uprelu]) 21 | else: 22 | return Sequential([upconv]) 23 | 24 | 25 | def 
conv_block(ch_out): 26 | block = Sequential( 27 | [nn.Conv2D(ch_out, kernel_size=3, strides=1, padding="same", use_bias=True), 28 | nn.BatchNormalization(), 29 | nn.LeakyReLU(0.2), 30 | nn.Conv2D(ch_out, kernel_size=3, strides=1, padding="same", use_bias=True), 31 | nn.BatchNormalization(), 32 | nn.LeakyReLU(0.2)] 33 | ) 34 | return block 35 | 36 | 37 | def up_conv(ch_out, outermost=False): 38 | if not outermost: 39 | up = Sequential( 40 | [nn.UpSampling2D(size=(2, 1)), 41 | nn.Conv2D(ch_out, kernel_size=3, strides=1, padding="same", use_bias=True), 42 | nn.BatchNormalization(), 43 | nn.ReLU()] 44 | ) 45 | else: 46 | up = Sequential( 47 | [nn.UpSampling2D(size=(2, 1)), 48 | nn.Conv2D(ch_out, kernel_size=3, strides=1, padding='same', use_bias=True), 49 | nn.Activation(activation="sigmoid")] 50 | ) 51 | return up 52 | 53 | 54 | if __name__ == '__main__': 55 | model = up_conv(1) 56 | print(model(ops.ones((1, 128, 128, 1))).shape) -------------------------------------------------------------------------------- /baseline/avse3/test.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | from argparse import ArgumentParser 3 | from os import makedirs 4 | from os.path import isfile, join 5 | 6 | import soundfile as sf 7 | from tqdm import tqdm 8 | 9 | from config import sampling_rate 10 | from dataset import AVSEChallengeDataModule 11 | from model import AVSE 12 | from utils import * 13 | 14 | 15 | def main(args): 16 | datamodule = AVSEChallengeDataModule(data_root=args.data_root, batch_size=1, time_domain=True) 17 | # can be changed to test_dataset 18 | test_dataset = datamodule.dev_dataset 19 | # test_dataset = datamodule.test_dataset 20 | 21 | makedirs(args.save_root, exist_ok=True) 22 | 23 | model = AVSE(64, 40800, batch_size=1) 24 | model.load_weights(args.weight_path) 25 | for i in tqdm(range(len(test_dataset))): 26 | data = test_dataset[i][0] 27 | filename = data["scene"] + ".wav" 28 | enhanced_path = join(args.save_root, filename) 29 | if not isfile(enhanced_path): 30 | estimated_audio = get_enhanced(model, data) 31 | estimated_audio /= np.max(np.abs(estimated_audio)) 32 | sf.write(enhanced_path, estimated_audio, samplerate=sampling_rate) 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = ArgumentParser() 37 | parser.add_argument("--weight_path", type=str, required=True, help="Path to model weights") 38 | parser.add_argument("--save_root", type=str, default="./enhanced", help="Root directory to save enhanced audio") 39 | parser.add_argument("--data_root", type=str, required=True, help="Root directory of dataset") 40 | args = parser.parse_args() 41 | main(args) 42 | -------------------------------------------------------------------------------- /baseline/avse3/train.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | import argparse 3 | from loss import si_snr_loss 4 | from model import AVSE 5 | from config import * 6 | import keras 7 | import time 8 | from os.path import join 9 | 10 | from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping, TensorBoard 11 | 12 | 13 | from argparse import ArgumentParser 14 | from dataset import AVSEChallengeDataModule 15 | 16 | 17 | def str2bool(v: str): 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | 26 | def main(args): 27 | 
dataset = AVSEChallengeDataModule(data_root=args.data_root, batch_size=args.batch_size, time_domain=True) 28 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, 29 | patience=3, min_lr=10 ** (-10), cooldown=1) 30 | early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, 31 | patience=5, mode='auto') 32 | checkpointer = ModelCheckpoint(join(args.log_dir, '{epoch:03d}_{val_loss:04f}.weights.h5'), 33 | monitor='val_loss', save_best_only=False, save_weights_only=True, 34 | mode='auto', save_freq='epoch') 35 | tensorboard = TensorBoard(log_dir=args.log_dir) 36 | 37 | optimizer = keras.optimizers.Adam(learning_rate=args.lr) 38 | model = AVSE(batch_size=args.batch_size, video_frames=max_video_length, audio_frames=max_audio_length) 39 | model.summary() 40 | model.compile(optimizer=optimizer, loss=si_snr_loss) 41 | if args.checkpoint is not None: 42 | model.load_weights(args.checkpoint) 43 | start = time.time() 44 | model.fit(dataset.train_dataloader(), epochs=args.max_epochs, 45 | validation_data=dataset.val_dataloader(), 46 | callbacks=[checkpointer, reduce_lr, early_stopping, tensorboard]) 47 | print(f"Time taken {time.time() - start} sec") 48 | 49 | 50 | if __name__ == '__main__': 51 | parser = ArgumentParser() 52 | parser.add_argument("--max_epochs", type=int, default=100) 53 | parser.add_argument("--data_root", type=str, required=True, help="Path to data root") 54 | parser.add_argument("--checkpoint", type=str, default=None, help="Path to checkpoint") 55 | parser.add_argument("--batch_size", type=int, default=16, help="Batch size") 56 | parser.add_argument("--lr", type=float, default=0.001, help="Learning rate") 57 | parser.add_argument("--log_dir", type=str, default="./logs", help="Path to log directory") 58 | args = parser.parse_args() 59 | main(args) 60 | -------------------------------------------------------------------------------- /baseline/avse3/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def pad_audio(audio, length): 5 | if len(audio) < length: 6 | audio = np.pad(audio, (0, length - len(audio))) 7 | return audio 8 | 9 | 10 | def pad_video(video, length): 11 | if len(video) < length: 12 | video = np.pad(video, ((0, length - len(video)), (0, 0), (0, 0), (0, 0))) 13 | return video 14 | 15 | 16 | def get_enhanced(model, data): 17 | enhanced_audio = np.zeros(len(data["noisy_audio"])) 18 | for i in range(0, len(data["noisy_audio"]), 40800): 19 | video_idx = (i // 40800) * 64 20 | noisy_audio = data["noisy_audio"][i:i + 40800] 21 | inputs = dict(noisy_audio=pad_audio(noisy_audio, 40800)[np.newaxis, ...], 22 | video_frames=pad_video(data["video_frames"][video_idx:video_idx + 64], 64)[np.newaxis, ...]) 23 | estimated_audio = model.predict(inputs, verbose=0)[0, :] 24 | if len(enhanced_audio) < 40800: 25 | return estimated_audio[:len(enhanced_audio)] 26 | if len(noisy_audio) < 40800: 27 | enhanced_audio[i:i + len(noisy_audio)] = estimated_audio[:len(noisy_audio)] 28 | else: 29 | enhanced_audio[i:i + 40800] = estimated_audio 30 | return enhanced_audio 31 | -------------------------------------------------------------------------------- /baseline/avse4/README.md: -------------------------------------------------------------------------------- 1 | ## Baseline model for 4th COG-MHEAR Audio-Visual Speech Enhancement Challenge 2 | 3 | [Challenge link](https://challenge.cogmhear.org/) 4 | 5 | ## Requirements 6 | * [Python >= 3.6](https://www.anaconda.com/docs/getting-started/miniconda/install) 7 | * 
[PyTorch](https://pytorch.org/) 8 | * [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/) 9 | * [Decord](https://github.com/dmlc/decord) 10 | * [Hydra](https://hydra.cc) 11 | * [SpeechBrain](https://github.com/speechbrain/speechbrain) 12 | * [TQDM](https://github.com/tqdm/tqdm) 13 | 14 | ## Usage 15 | 16 | ```bash 17 | # Expected folder structure for the dataset 18 | data_root 19 | |-- train 20 | | `-- scenes 21 | |-- dev 22 | | `-- scenes 23 | |-- eval 24 | | `-- scenes 25 | ``` 26 | 27 | ### Train 28 | ```bash 29 | python train.py data.root="./avsec4" data.num_channels=2 trainer.log_dir="./logs" data.batch_size=8 trainer.accelerator=gpu trainer.gpus=1 30 | 31 | more arguments in conf/train.yaml 32 | ``` 33 | 34 | ### Download pretrained checkpoints 35 | ``` 36 | git lfs install 37 | git clone https://huggingface.co/cogmhear/avse4_baseline 38 | ``` 39 | 40 | ### Test 41 | ```bash 42 | 43 | # evaluating binaural avse4 baseline 44 | python test.py data.root=./avsec4 data.num_channels=2 ckpt_path=avse4_baseline/pretrained.ckpt save_dir="./eval" model_uid="avse4_binaural" 45 | 46 | # evaluating single channel avse4 baseline 47 | python test.py data.root=./avsec4 data.num_channels=1 ckpt_path=avse4_baseline/pretrained_mono.ckpt save_dir="./eval" model_uid="avse4_mono" 48 | 49 | more arguments in conf/eval.yaml 50 | ``` 51 | 52 | -------------------------------------------------------------------------------- /baseline/avse4/conf/eval.yaml: -------------------------------------------------------------------------------- 1 | ckpt_path: ??? # Path to checkpoint for evaluation (null for fresh start) 2 | save_dir: ??? # Directory where evaluation results will be saved 3 | model_uid: ??? # Unique identifier for the model 4 | cpu: False # Whether to use CPU for evaluation (True) or GPU (False) 5 | 6 | data: 7 | root: ??? # Path to the root directory containing train, dev and eval scenes data 8 | num_channels: 2 # Number of audio channels (1 for mono, 2 for stereo) 9 | audio_norm: False # Whether to normalize audio data 10 | rgb: False # Whether to use RGB images (True) or grayscale (False) 11 | dev_set: True # Whether to use dev set for evaluation (True) or not (False) 12 | eval_set: False # Whether to use eval set for evaluation (True) or not (False) -------------------------------------------------------------------------------- /baseline/avse4/conf/train.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration for AVSE4 baseline model 2 | data: 3 | root: ??? # Path to the root directory containing training data 4 | num_channels: 2 # Number of audio channels (1 for mono, 2 for stereo) 5 | batch_size: 8 # Number of samples per batch 6 | audio_norm: False # Whether to normalize audio data 7 | rgb: False # Whether to use RGB images (True) or grayscale (False) 8 | 9 | trainer: 10 | log_dir: ??? # Directory where training logs and checkpoints will be saved 11 | ckpt_path: null # Path to checkpoint for resuming training (null for fresh start) 12 | max_epochs: 100 # Maximum number of training epochs 13 | lr: 0.0001 # Learning rate for optimizer 14 | deterministic: False # Whether to use deterministic algorithms (may affect performance) 15 | fast_dev_run: False # If True, runs a single batch for train/val/test for debugging 16 | gpus: 1 # Number of GPUs to use for training 17 | accelerator: gpu # Hardware accelerator type ('cpu', 'gpu', etc.) 18 | strategy: auto # Distributed training strategy ('ddp', 'dp', 'auto', etc.) 
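  # Editor's note (illustrative): as in the README above, any data.* / trainer.* key in this
  # file can be overridden on the command line via Hydra, e.g.
  #   python train.py data.root=./avsec4 trainer.log_dir=./logs trainer.precision=16-mixed trainer.max_epochs=50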
19 | precision: 32 # Numerical precision for training (16, 32, or 16-mixed) 20 | accumulate_grad_batches: 1 # Number of batches to accumulate gradients over 21 | gradient_clip_val: null # Maximum gradient norm value (null for no clipping) 22 | log_every_n_steps: 50 # How often to log metrics (in steps) 23 | num_sanity_val_steps: 0 # Number of validation steps to run before training 24 | detect_anomaly: False # Whether to enable PyTorch anomaly detection 25 | limit_train_batches: null # Limit training to a percentage of the dataset (null for full dataset) 26 | limit_val_batches: null # Limit validation to a percentage of the dataset (null for full dataset) 27 | profiler: null # Profiler type to use (null for no profiling) -------------------------------------------------------------------------------- /baseline/avse4/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import isfile, join 3 | import logging 4 | import random 5 | from typing import List, Tuple 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | from decord import VideoReader, cpu 11 | from pytorch_lightning import LightningDataModule 12 | from scipy.io import wavfile 13 | from torch.utils.data import Dataset 14 | from tqdm import tqdm 15 | # Constants 16 | MAX_FRAMES = 75 17 | MAX_AUDIO_LEN = 48000 18 | SEED = 1143 19 | SAMPLING_RATE = 16000 20 | FRAMES_PER_SECOND = 25 21 | 22 | 23 | def subsample_list(inp_list: List, sample_rate: float) -> List: 24 | random.shuffle(inp_list) 25 | return [inp_list[i] for i in range(int(len(inp_list) * sample_rate))] 26 | 27 | 28 | class AVSE4Dataset(Dataset): 29 | def __init__(self, scenes_root, shuffle=False, seed=SEED, subsample=1, 30 | clipped_batch=False, sample_items=False, test_set=False, rgb=False, 31 | audio_norm=False, num_channels=1): 32 | super().__init__() 33 | assert num_channels in [1, 2], "Number of channels must be 1 or 2" 34 | assert os.path.isdir(scenes_root), f"Scenes root {scenes_root} not found" 35 | self.num_channels = num_channels 36 | self.mono = num_channels == 1 37 | self.img_size = 112 38 | self.audio_norm = audio_norm 39 | self.test_set = test_set 40 | self.clipped_batch = clipped_batch 41 | self.scenes_root = scenes_root 42 | self.files_list = self.build_files_list() 43 | if shuffle: 44 | random.seed(seed) 45 | random.shuffle(self.files_list) 46 | if subsample != 1: 47 | self.files_list = subsample_list(self.files_list, sample_rate=subsample) 48 | logging.info(f"Found {len(self.files_list)} utterances") 49 | self.rgb = rgb 50 | self.sample_items = sample_items 51 | 52 | def build_files_list(self) -> List[Tuple[str, str, str, str]]: 53 | if isinstance(self.scenes_root, list): 54 | return [file for root in self.scenes_root for file in self.get_files_list(root)] 55 | return self.get_files_list(self.scenes_root) 56 | 57 | def get_files_list(self, scenes_root: str) -> List[Tuple[str, str, str, str]]: 58 | files_list = [] 59 | for file in os.listdir(scenes_root): 60 | if file.endswith("_target_anechoic.wav"): 61 | files = ( 62 | join(scenes_root, file), 63 | join(scenes_root, file.replace("target_anechoic", "interferer")), 64 | join(scenes_root, file.replace("target_anechoic", "mono_mix")), 65 | join(scenes_root, file.replace("target_anechoic.wav", "silent.mp4")), 66 | join(scenes_root, file.replace("target_anechoic", "mix")), 67 | 68 | ) 69 | if not self.test_set and all(isfile(f) for f in files if not f.endswith("_interferer.wav")): 70 | files_list.append(files) 71 | elif 
self.test_set: 72 | files_list.append(files) 73 | return files_list 74 | 75 | def __len__(self) -> int: 76 | return len(self.files_list) 77 | 78 | def __getitem__(self, idx: int) -> dict: 79 | while True: 80 | try: 81 | if self.sample_items: 82 | clean_file, noise_file, noisy_file, mp4_file, noisy_binaural_file = random.choice(self.files_list) 83 | else: 84 | clean_file, noise_file, noisy_file, mp4_file, noisy_binaural_file = self.files_list[idx] 85 | if self.num_channels == 2: 86 | noisy_file = noisy_binaural_file 87 | noisy_audio, clean, vis_feat = self.get_data(clean_file, noise_file, noisy_file, mp4_file) 88 | data = dict(noisy_audio=noisy_audio, clean=clean, vis_feat=vis_feat) 89 | if not isinstance(self.scenes_root, list): 90 | data['scene'] = clean_file.replace(self.scenes_root, "").replace("_target_anechoic.wav", "").replace("/", "") 91 | return data 92 | except Exception as e: 93 | logging.error(f"Error in loading data: {e}, {mp4_file}, {noisy_file}") 94 | 95 | @staticmethod 96 | def load_wav(wav_path: str, mono=False) -> np.ndarray: 97 | data = wavfile.read(wav_path)[1].astype(np.float32) / 32768.0 98 | if mono and len(data.shape) > 1: 99 | data = np.mean(data, axis=1) 100 | return data 101 | def get_data(self, clean_file: str, noise_file: str, noisy_file: str, mp4_file: str) -> Tuple[ 102 | np.ndarray, np.ndarray, np.ndarray]: 103 | noisy = self.load_wav(noisy_file, self.mono) 104 | vr = VideoReader(mp4_file, ctx=cpu(0)) 105 | clean = self.load_wav(clean_file, self.mono) if isfile(clean_file) else np.zeros_like(noisy) 106 | 107 | if self.clipped_batch: 108 | noisy, clean, bg_frames = self.process_clipped_batch(noisy, clean, vr) 109 | else: 110 | bg_frames = self.process_full_batch(vr) 111 | 112 | if self.audio_norm: 113 | clean = clean / np.abs(clean).max() 114 | noisy = noisy / np.abs(noisy).max() 115 | if self.mono: 116 | clean = clean[np.newaxis, :] 117 | noisy = noisy[np.newaxis, :] 118 | else: 119 | clean = clean.T 120 | noisy = noisy.T 121 | return (noisy, clean, 122 | bg_frames if not self.rgb else bg_frames.transpose(0, 3, 1, 2)) 123 | 124 | def process_clipped_batch(self, noisy: np.ndarray, clean: np.ndarray, vr: VideoReader) -> Tuple[ 125 | np.ndarray, np.ndarray, np.ndarray]: 126 | if clean.shape[0] > MAX_AUDIO_LEN: 127 | clip_idx = random.randint(0, clean.shape[0] - MAX_AUDIO_LEN) 128 | video_idx = int((clip_idx / SAMPLING_RATE) * FRAMES_PER_SECOND) 129 | clean = clean[clip_idx:clip_idx + MAX_AUDIO_LEN] 130 | noisy = noisy[clip_idx:clip_idx + MAX_AUDIO_LEN] 131 | else: 132 | video_idx = -1 133 | if self.num_channels == 2: 134 | clean_pad = np.zeros((MAX_AUDIO_LEN, 2)) 135 | noisy_pad = np.zeros((MAX_AUDIO_LEN, 2)) 136 | else: 137 | clean_pad = np.zeros(MAX_AUDIO_LEN) 138 | noisy_pad = np.zeros(MAX_AUDIO_LEN) 139 | clean_pad[:clean.shape[0]] = clean 140 | noisy_pad[:noisy.shape[0]] = noisy 141 | clean = clean_pad 142 | noisy = noisy_pad 143 | frames = self.get_video_frames(vr, video_idx) 144 | bg_frames = self.process_frames(frames) 145 | return noisy, clean, bg_frames 146 | 147 | def process_full_batch(self, vr: VideoReader) -> np.ndarray: 148 | frames = vr.get_batch(list(range(len(vr)))).asnumpy() 149 | return self.process_frames(frames) 150 | 151 | def get_video_frames(self, vr: VideoReader, video_idx: int) -> np.ndarray: 152 | if len(vr) < MAX_FRAMES: 153 | return vr.get_batch(list(range(len(vr)))).asnumpy() 154 | max_idx = min(video_idx + MAX_FRAMES, len(vr)) 155 | return vr.get_batch(list(range(video_idx, max_idx))).asnumpy() 156 | 157 | def process_frames(self, 
frames: np.ndarray) -> np.ndarray: 158 | frames = np.array([frame[56:-56,56:-56,:] for frame in frames]) 159 | 160 | if not self.rgb: 161 | bg_frames = np.array([cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames]).astype(np.float32) 162 | else: 163 | bg_frames = frames.astype(np.float32) 164 | bg_frames /= 255.0 165 | 166 | if len(bg_frames) < MAX_FRAMES: 167 | pad_shape = (MAX_FRAMES - len(bg_frames), self.img_size, self.img_size, 3) if self.rgb else ( 168 | MAX_FRAMES - len(bg_frames), self.img_size, self.img_size) 169 | bg_frames = np.concatenate((bg_frames, np.zeros(pad_shape, dtype=bg_frames.dtype)), axis=0) 170 | 171 | return bg_frames[np.newaxis, ...] if not self.rgb else bg_frames 172 | 173 | 174 | class AVSE4DataModule(LightningDataModule): 175 | def __init__(self, data_root, batch_size=16, audio_norm=False, rgb=True, num_channels=1): 176 | super().__init__() 177 | self.train_dataset_batch = AVSE4Dataset(join(data_root, "train/scenes"), rgb=rgb, shuffle=True, 178 | num_channels=num_channels,clipped_batch=True, sample_items=True, 179 | audio_norm=audio_norm) 180 | self.dev_dataset_batch = AVSE4Dataset(join(data_root, "dev/scenes"), rgb=rgb, 181 | num_channels=num_channels,clipped_batch=True, 182 | audio_norm=audio_norm) 183 | self.dev_dataset = AVSE4Dataset(join(data_root, "dev/scenes"), clipped_batch=True, rgb=rgb, 184 | num_channels=num_channels, sample_items=False, 185 | audio_norm=audio_norm) 186 | self.eval_dataset = AVSE4Dataset(join(data_root, "dev/scenes"), clipped_batch=False, rgb=rgb, 187 | num_channels=num_channels, 188 | audio_norm=audio_norm, sample_items=False, 189 | test_set=True) 190 | self.batch_size = batch_size 191 | 192 | def train_dataloader(self): 193 | assert len(self.train_dataset_batch) > 0, "No training data found" 194 | return torch.utils.data.DataLoader(self.train_dataset_batch, batch_size=self.batch_size, num_workers=4, 195 | pin_memory=True, persistent_workers=True) 196 | 197 | def val_dataloader(self): 198 | assert len(self.dev_dataset_batch) > 0, "No validation data found" 199 | return torch.utils.data.DataLoader(self.dev_dataset_batch, batch_size=self.batch_size, num_workers=4, 200 | pin_memory=True, persistent_workers=True) 201 | 202 | def test_dataloader(self): 203 | return torch.utils.data.DataLoader(self.eval_dataset, batch_size=self.batch_size, num_workers=4) 204 | 205 | 206 | if __name__ == '__main__': 207 | dataset = AVSE4DataModule(data_root="/home/m_gogate/data/avsec4", batch_size=1, 208 | audio_norm=False, rgb=True, 209 | num_channels=2).train_dataset_batch 210 | for i in tqdm(range(len(dataset)), ascii=True): 211 | data = dataset[i] 212 | for k, v in data.items(): 213 | if isinstance(v, np.ndarray): 214 | print(k, v.shape, v.dtype) 215 | break -------------------------------------------------------------------------------- /baseline/avse4/test.py: -------------------------------------------------------------------------------- 1 | from os.path import isfile 2 | from os import makedirs 3 | from os.path import join 4 | 5 | import soundfile as sf 6 | import torch 7 | from tqdm import tqdm 8 | from omegaconf import DictConfig 9 | import hydra 10 | 11 | from dataset import AVSE4DataModule 12 | from model import AVSE4BaselineModule 13 | 14 | SAMPLE_RATE = 16000 15 | 16 | @hydra.main(config_path="conf", config_name="eval", version_base="1.2") 17 | def main(cfg: DictConfig): 18 | enhanced_root = join(cfg.save_dir, cfg.model_uid) 19 | makedirs(cfg.save_dir, exist_ok=True) 20 | makedirs(enhanced_root, exist_ok=True) 21 | datamodule = 
AVSE4DataModule(data_root=cfg.data.root,batch_size=1,rgb=cfg.data.rgb, 22 | num_channels=cfg.data.num_channels, audio_norm=cfg.data.audio_norm) 23 | if cfg.data.dev_set and cfg.data.eval_set: 24 | raise RuntimeError("Select either dev set or test set") 25 | elif cfg.data.dev_set: 26 | dataset = datamodule.dev_dataset 27 | elif cfg.data.eval_set: 28 | dataset = datamodule.eval_dataset 29 | else: 30 | raise RuntimeError("Select one of dev set and test set") 31 | try: 32 | model = AVSE4BaselineModule.load_from_checkpoint(cfg.ckpt_path) 33 | print("Model loaded") 34 | except Exception as e: 35 | raise FileNotFoundError("Cannot load model weights: {}".format(cfg.ckpt_path)) 36 | if not cfg.cpu: 37 | model.to("cuda:0") 38 | model.eval() 39 | with torch.no_grad(): 40 | for i in tqdm(range(len(dataset))): 41 | data = dataset[i] 42 | filename = f"{data['scene']}.wav" 43 | enhanced_path = join(enhanced_root, filename) 44 | if not isfile(enhanced_path): 45 | clean, noisy, estimated_audio = model.enhance(data) 46 | sf.write(enhanced_path, estimated_audio.T, samplerate=SAMPLE_RATE) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /baseline/avse4/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | from utils import seed_everything 5 | seed_everything(1143) 6 | import torch 7 | torch.set_float32_matmul_precision('medium') 8 | from omegaconf import DictConfig 9 | import hydra 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 12 | from model import AVSE4BaselineModule 13 | from dataset import AVSE4DataModule 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | @hydra.main(config_path="conf", config_name="train", version_base="1.2") 19 | def main(cfg: DictConfig): 20 | checkpoint_callback = ModelCheckpoint(monitor="val_loss_epoch", 21 | filename="model-{epoch:02d}-{val_loss:.3f}", save_top_k=2, save_last=True) 22 | callbacks = [checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=6)] 23 | datamodule = AVSE4DataModule(data_root=cfg.data.root, batch_size=cfg.data.batch_size, 24 | audio_norm=cfg.data.audio_norm, rgb=cfg.data.rgb, 25 | num_channels=cfg.data.num_channels) 26 | model = AVSE4BaselineModule(num_channels=cfg.data.num_channels) 27 | 28 | trainer = Trainer(default_root_dir=cfg.trainer.log_dir, 29 | callbacks=callbacks, deterministic=cfg.trainer.deterministic, 30 | log_every_n_steps=cfg.trainer.log_every_n_steps, 31 | fast_dev_run=cfg.trainer.fast_dev_run, devices=cfg.trainer.gpus, 32 | accelerator=cfg.trainer.accelerator, 33 | precision=cfg.trainer.precision, strategy=cfg.trainer.strategy, 34 | max_epochs=cfg.trainer.max_epochs, 35 | accumulate_grad_batches=cfg.trainer.accumulate_grad_batches, 36 | detect_anomaly=cfg.trainer.detect_anomaly, 37 | limit_train_batches=cfg.trainer.limit_train_batches, 38 | limit_val_batches=cfg.trainer.limit_val_batches, 39 | num_sanity_val_steps=cfg.trainer.num_sanity_val_steps, 40 | gradient_clip_val=cfg.trainer.gradient_clip_val, 41 | profiler=cfg.trainer.profiler 42 | ) 43 | start = time.time() 44 | trainer.fit(model, datamodule, ckpt_path=cfg.trainer.ckpt_path) 45 | log.info(f"Time taken {time.time() - start} sec") 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /baseline/avse4/utils.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from pytorch_lightning import LightningModule 8 | 9 | EPS = np.finfo(float).eps 10 | normMean = 0.4161 11 | normStd = 0.1688 12 | 13 | 14 | 15 | class VisualConv1D(nn.Module): 16 | def __init__(self): 17 | super(VisualConv1D, self).__init__() 18 | relu = nn.ReLU() 19 | norm_1 = nn.BatchNorm1d(512) 20 | dsconv = nn.Conv1d(512, 21 | 512, 22 | 3, 23 | stride=1, 24 | padding=1, 25 | dilation=1, 26 | groups=512, 27 | bias=False) 28 | prelu = nn.PReLU() 29 | norm_2 = nn.BatchNorm1d(512) 30 | pw_conv = nn.Conv1d(512, 512, 1, bias=False) 31 | 32 | self.net = nn.Sequential(relu, norm_1, dsconv, prelu, norm_2, pw_conv) 33 | 34 | def forward(self, x): 35 | out = self.net(x) 36 | return out + x 37 | 38 | 39 | 40 | class ResNetLayer(nn.Module): 41 | def __init__(self, inplanes, outplanes, stride): 42 | super(ResNetLayer, self).__init__() 43 | self.conv1a = nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) 44 | self.bn1a = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 45 | self.conv2a = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False) 46 | self.stride = stride 47 | self.downsample = nn.Conv2d(inplanes, outplanes, kernel_size=(1,1), stride=stride, bias=False) 48 | self.outbna = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 49 | 50 | self.conv1b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False) 51 | self.bn1b = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 52 | self.conv2b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False) 53 | self.outbnb = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 54 | 55 | 56 | def forward(self, inputBatch): 57 | batch = F.relu(self.bn1a(self.conv1a(inputBatch))) 58 | batch = self.conv2a(batch) 59 | if self.stride == 1: 60 | residualBatch = inputBatch 61 | else: 62 | residualBatch = self.downsample(inputBatch) 63 | batch = batch + residualBatch 64 | intermediateBatch = batch 65 | batch = F.relu(self.outbna(batch)) 66 | 67 | batch = F.relu(self.bn1b(self.conv1b(batch))) 68 | batch = self.conv2b(batch) 69 | residualBatch = intermediateBatch 70 | batch = batch + residualBatch 71 | outputBatch = F.relu(self.outbnb(batch)) 72 | return outputBatch 73 | 74 | 75 | 76 | class ResNet(nn.Module): 77 | def __init__(self): 78 | super(ResNet, self).__init__() 79 | self.layer1 = ResNetLayer(64, 64, stride=1) 80 | self.layer2 = ResNetLayer(64, 128, stride=2) 81 | self.layer3 = ResNetLayer(128, 256, stride=2) 82 | self.layer4 = ResNetLayer(256, 512, stride=2) 83 | self.avgpool = nn.AvgPool2d(kernel_size=(4,4), stride=(1,1)) 84 | return 85 | 86 | 87 | def forward(self, inputBatch): 88 | batch = self.layer1(inputBatch) 89 | batch = self.layer2(batch) 90 | batch = self.layer3(batch) 91 | batch = self.layer4(batch) 92 | outputBatch = self.avgpool(batch) 93 | return outputBatch 94 | 95 | 96 | 97 | class VisualFrontend(LightningModule): 98 | def __init__(self): 99 | super(VisualFrontend, self).__init__() 100 | self.frontend3D = nn.Sequential( 101 | nn.Conv3d(1, 64, kernel_size=(5,7,7), stride=(1,2,2), padding=(2,3,3), bias=False), 102 | nn.BatchNorm3d(64, momentum=0.01, eps=0.001), 103 | nn.ReLU(), 104 | nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) 105 | ) 106 | self.resnet = ResNet() 107 | return 108 | 109 | 110 | def 
forward(self, inputBatch): 111 | batchsize = inputBatch.shape[0] 112 | inputBatch = (inputBatch - normMean) / normStd 113 | batch = self.frontend3D(inputBatch) 114 | batch = batch.transpose(1, 2) 115 | batch = batch.reshape(batch.shape[0]*batch.shape[1], batch.shape[2], batch.shape[3], batch.shape[4]) 116 | outputBatch = self.resnet(batch) 117 | outputBatch = outputBatch.reshape(batchsize, -1, 512) 118 | return outputBatch 119 | 120 | 121 | 122 | def subsample_list(inp_list: list, sample_rate: float): 123 | random.shuffle(inp_list) 124 | return [inp_list[i] for i in range(int(len(inp_list) * sample_rate))] 125 | 126 | def seed_everything(seed: int): 127 | random.seed(seed) 128 | np.random.seed(seed) 129 | torch.manual_seed(seed) 130 | torch.backends.cudnn.deterministic = False 131 | torch.backends.cudnn.benchmark = True 132 | 133 | 134 | def is_clipped(audio, clipping_threshold=0.99): 135 | return any(abs(audio) > clipping_threshold) 136 | 137 | 138 | def rms_normalize(audio, target_level=-25): 139 | rms = (audio ** 2).mean() ** 0.5 140 | scalar = 10 ** (target_level / 20) / (rms + EPS) 141 | audio = audio * scalar 142 | return audio -------------------------------------------------------------------------------- /data_preparation/avse1/README.md: -------------------------------------------------------------------------------- 1 | # Audio-Visual Speech Enhancement Challenge (AVSE) 2 | 3 | Human performance in everyday noisy situations is known to be dependent upon both aural and visual senses that are contextually combined by the brain’s multi-level integration strategies. The multimodal nature of speech is well established, with listeners known to unconsciously lip read to improve the intelligibility of speech in a real noisy environment. Studies in neuroscience have shown that the visual aspect of speech has a potentially strong impact on the ability of humans to focus their auditory attention on a particular stimulus. 4 | 5 | Over the last few decades, there have been major advances in machine learning applied to speech technology made possible by Machine Learning related Challenges including CHiME, REVERB, Blizzard, Clarity and Hurricane. However, the aforementioned challenges are based on single and multi-channel audio-only processing and have not exploited the multimodal nature of speech. The aim of this first audio visual (AV) speech enhancement challenge is to bring together the wider computer vision, hearing and speech research communities to explore novel approaches to multimodal speech-in-noise processing. 6 | 7 | In this repository, you will find code to support the AVSE Challenge, including the baseline and scripts for preparing the necessary data. 8 | 9 | More details can be found on the challenge website: 10 | https://challenge.cogmhear.org 11 | 12 | ## Announcements 13 | 14 | Any announcements about the challenge will be made in our mailing list (avse-challenge@mlist.is.ed.ac.uk). 15 | See [here](https://challenge.cogmhear.org/#/docs?id=announcements) on how to subscribe to it. 
16 | 17 | ## Installation 18 | 19 | ```bash 20 | # Clone repository 21 | git clone https://github.com/cogmhear/avse-challenge.git 22 | cd avse-challenge 23 | 24 | # Create & activate environment with conda, see https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html 25 | conda create --name avse python=3.8 26 | conda activate avse 27 | 28 | # Install ffmpeg 2.8 29 | conda install -c rmg ffmpeg 30 | 31 | # Install requirements 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Data preparation 36 | 37 | These scripts should be run in a unix environment and require an installed version of the [ffmpeg](https://www.ffmpeg.org) tool (required version 2.8; see Installation for the correct installation command). 38 | 39 | 1) Download necessary data: 40 | - target videos: 41 | Lip Reading Sentences 3 (LRS3) Dataset 42 | https://mm.kaist.ac.kr/datasets/lip_reading/ 43 | 44 | Follow the instructions on the website to obtain credentials to download the videos. 45 | - noise maskers and metadata (AVSEC-3): 46 | https://data.cstr.ed.ac.uk/cogmhear/protected/avsec3_data.tar 47 | Please register for the AVSE challenge to obtain the download credentials: [registration form](https://challenge.cogmhear.org/#/getting-started/register) 48 | 49 | Noise maskers and metadata (AVSEC-1 and AVSEC-2): https://data.cstr.ed.ac.uk/cogmhear/protected/avse2_data.tar 50 | 51 | **Note that the AVSEC-2 dataset is identical to that used in the 1st edition of the Challenge, ** 52 | 53 | 2) Set up data structure and create speech maskers (see EDIT_THIS to change local paths): 54 | ```bash 55 | cd data_preparation/avse1 56 | ./setup_avse1_data.sh 57 | ``` 58 | 59 | 3) Change root path defined in [data_preparation/avse1/data_config.yaml](data_preparation/avse1/data_config.yaml) to the location of the data. 60 | 61 | 4) Prepare noisy data: 62 | ```bash 63 | cd data_preparation/avse1 64 | python prepare_avse1_data.py 65 | ``` 66 | 67 | ## Baseline 68 | 69 | [code](./baseline/avse1/) 70 | 71 | [pretrained_model](https://data.cstr.ed.ac.uk/cogmhear/protected/avse1_baseline.ckpt) 72 | 73 | The credentials to download the pretrained model are the same as the ones used to download the noise maskers and the metadata. 74 | 75 | ## Evaluation 76 | 77 | We provide a script to extract STOI and PESQ for the devset. 78 | 79 | Note: before running this script please edit the paths and file name formats defined in evaluation/avse1/config.yaml (see EDIT_THIS). 80 | 81 | ``` 82 | cd evaluation/avse1/ 83 | python objective_evaluation.py 84 | ``` 85 | 86 | that require the following libraries: 87 | ``` 88 | pip install pystoi==0.3.3 89 | pip install pesq==0.0.4 90 | ``` 91 | 92 | ## Challenges 93 | 94 | Current challenge 95 | 96 | - The 1st Audio-Visual Speech Enhancement Challenge (AVSE1) 97 | [data_preparation](./data_preparation/avse1/) 98 | [baseline](./baseline/avse1/) 99 | [evaluation](./evaluation/avse1/) 100 | 101 | ## License 102 | 103 | Videos are derived from: 104 | - [LRS3 dataset](https://mm.kaist.ac.kr/datasets/lip_reading/) 105 | Creative Commons BY-NC-ND 4.0 license 106 | 107 | Interferers are derived from: 108 | - [Clarity Enhancement Challenge (CEC1)](https://github.com/claritychallenge/clarity/tree/main/recipes/cec1) 109 | Creative Commons Attribution Share Alike 4.0 International 110 | 111 | - [DEMAND](https://zenodo.org/record/1227121#.YpZHLRPMLPY): 112 | Creative Commons Attribution 4.0 International 113 | 114 | - [DNS Challenge second edition](https://github.com/microsoft/DNS-Challenge). 
115 | Only Freesound clips were selected 116 | Creative Commons 0 License 117 | 118 | - [LRS3 dataset](https://mm.kaist.ac.kr/datasets/lip_reading/) 119 | Creative Commons BY-NC-ND 4.0 license 120 | 121 | - [MedleyDB audio](https://medleydb.weebly.com/) 122 | 123 | The dataset is licensed under CC BY-NC-SA 4.0. 124 | 125 | - [ESC-50 Dataset for Environmental Sound Classification](https://github.com/karolpiczak/ESC-50) 126 | 127 | Creative Commons Attribution-NonCommercial license 128 | 129 | Data preparation scripts were adapted from original code by [Clarity Challenge](https://github.com/claritychallenge/clarity). Modifications include: extracting target target audio from video and different settings for sampling rate (16kHz), number of channels (one channel) and scenes simulation (additive noise only, no room impulse responses and no room simulation). 130 | 131 | 132 | -------------------------------------------------------------------------------- /data_preparation/avse1/build_scenes.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Adapted from original code by Clarity Challenge 4 | https://github.com/claritychallenge/clarity 5 | ''' 6 | 7 | import os 8 | import logging 9 | 10 | import hydra 11 | from omegaconf import DictConfig 12 | from scene_builder_avse1 import SceneBuilder, set_random_seed 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | def instantiate_scenes(cfg): 17 | set_random_seed(cfg.random_seed) 18 | for dataset in cfg.scene_datasets: 19 | scene_file = os.path.join(cfg.metadata_dir, f"scenes.{dataset}.json") 20 | if not os.path.exists(scene_file): 21 | logger.info(f"instantiate scenes for {dataset} set") 22 | sb = SceneBuilder( 23 | scene_datasets=cfg.scene_datasets[dataset], 24 | target=cfg.target, 25 | interferer=cfg.interferer, 26 | snr_range=cfg.snr_range[dataset], 27 | ) 28 | sb.instantiate_scenes(dataset=dataset) 29 | sb.save_scenes(scene_file) 30 | else: 31 | logger.info(f"scenes.{dataset}.json exists, skip") 32 | 33 | 34 | @hydra.main(config_path=".", config_name="data_config") 35 | def run(cfg: DictConfig) -> None: 36 | logger.info("Instantiating scenes") 37 | instantiate_scenes(cfg) 38 | 39 | 40 | if __name__ == "__main__": 41 | run() 42 | -------------------------------------------------------------------------------- /data_preparation/avse1/create_speech_maskers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Create speech maskers data 4 | ''' 5 | import sys 6 | import numpy as np 7 | import os 8 | import glob 9 | import json 10 | import soundfile as sf 11 | from tqdm import tqdm 12 | from concurrent.futures import ProcessPoolExecutor 13 | 14 | fs = 16000 15 | 16 | def create_dir(directory): 17 | if not os.path.exists(directory): 18 | os.makedirs(directory) 19 | 20 | def create_speech_maskers(datadir, metafile, wavdir): 21 | 22 | with open(metafile, "r") as f: 23 | maskers = json.load(f) 24 | 25 | futures = [] 26 | ncores = 20 27 | with ProcessPoolExecutor(max_workers=ncores) as executor: 28 | for masker in maskers: 29 | futures.append(executor.submit(create_masker_for_spk, datadir, wavdir, masker['speaker'])) 30 | proc_list = [future.result() for future in tqdm(futures)] 31 | 32 | def create_masker_for_spk(datadir, wavdir, spk): 33 | 34 | create_dir(f"{wavdir}/{spk}/") 35 | 36 | # Extract audio from videos and join them into one long masker file 37 | y = np.array([]) 38 | for file in 
glob.iglob(f'{datadir}/*train*/{spk}/*.mp4'): 39 | basename = os.path.basename(file).split('.')[0] 40 | target_fn = f"{wavdir}/{spk}/{basename}.wav" 41 | command = ("ffmpeg -v 8 -y -i %s -vn -acodec pcm_s16le -ar %s -ac 1 %s < /dev/null" % (file, str(fs), target_fn)) 42 | os.system(command) 43 | x = sf.read(target_fn)[0] 44 | y = np.concatenate((y, x), axis=-1) 45 | sf.write(f"{wavdir}/{spk}.wav", y, fs) 46 | 47 | command = ("rm -r %s/%s" % (wavdir,spk)) 48 | os.system(command) 49 | 50 | if __name__ == '__main__': 51 | 52 | datadir = sys.argv[1] # '/group/corpora/public/lipreading/LRS3/' 53 | metafile = sys.argv[2] # '/disk/scratch6/cvbotinh/av2022/av2022_data/metadata/masker_speech_list.json' 54 | wavdir = sys.argv[3] # '/disk/scratch6/cvbotinh/av2022/av2022_data/maskers_speech/' 55 | 56 | # Create speech masker files 57 | create_speech_maskers(datadir, metafile, wavdir) 58 | 59 | -------------------------------------------------------------------------------- /data_preparation/avse1/data_config.yaml: -------------------------------------------------------------------------------- 1 | root: /tmp/avse1_data/ 2 | input_path: ${root}/ 3 | metadata_dir: ${input_path}/metadata/ 4 | 5 | random_seed: 0 6 | 7 | target: 8 | target_speakers: ${metadata_dir}/target_speech_list.json 9 | target_selection: SEQUENTIAL 10 | 11 | snr_range: 12 | train: 13 | speech: [-15, 5] 14 | noise: [-10, 10] 15 | dev: 16 | speech: [-15, 5] 17 | noise: [-10, 10] 18 | 19 | interferer: 20 | speech_interferers: ${metadata_dir}/masker_speech_list.json 21 | noise_interferers: ${metadata_dir}/masker_noise_list.json 22 | number: [1] 23 | start_time_range: [0, 0] 24 | end_early_time_range: [0, 0] 25 | 26 | # Instantiate_scenes 27 | scene_datasets: 28 | train: 29 | n_scenes: 34525 30 | scene_start_index: 1 31 | dev: 32 | n_scenes: 3365 33 | scene_start_index: 34526 34 | 35 | datasets: 36 | train: 37 | metafile_path: ${metadata_dir}/scenes.train.json 38 | scene_folder: ${input_path}/train/scenes/ 39 | dev: 40 | metafile_path: ${metadata_dir}/scenes.dev.json 41 | scene_folder: ${input_path}/dev/scenes/ 42 | 43 | num_channels: 1 44 | 45 | fs: 16000 46 | 47 | scene_renderer: 48 | train: 49 | metadata: 50 | scene_definitions: ${path.metadata_dir}/scenes.train.json 51 | chunk_size: 12 # TOCHECK 52 | dev: 53 | metadata: 54 | scene_definitions: ${path.metadata_dir}/scenes.dev.json 55 | chunk_size: 5 # TOCHECK 56 | 57 | # disable hydra loggings 58 | defaults: 59 | - override hydra/job_logging: disabled 60 | 61 | hydra: 62 | output_subdir: Null 63 | run: 64 | dir: . 
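The data_config.yaml above drives both build_scenes.py and prepare_avse1_data.py through Hydra/OmegaConf interpolation: `${root}` feeds `${input_path}`, which in turn feeds `${metadata_dir}` and the per-dataset scene folders. A quick way to confirm that these resolve to your local data layout before rendering anything is to load the file with OmegaConf and print a few keys. The snippet below is a minimal sketch, not part of the challenge scripts; it assumes it is run from `data_preparation/avse1/` (where `data_config.yaml` lives) and that `omegaconf` from the avse1 requirements is installed.

```python
# Sketch: inspect the resolved paths in data_config.yaml before running
# build_scenes.py / prepare_avse1_data.py. Interpolations such as ${root},
# ${input_path} and ${metadata_dir} are resolved lazily when a key is accessed.
from omegaconf import OmegaConf

cfg = OmegaConf.load("data_config.yaml")

print(cfg.metadata_dir)                  # built from ${input_path}, i.e. from ${root}
print(cfg.target.target_speakers)        # target_speech_list.json under the metadata dir
print(cfg.datasets.train.scene_folder)   # where rendered training scenes will be written
print(cfg.snr_range.train.speech)        # SNR range drawn for speech interferers, [-15, 5]
```

If a printed path does not point at the data extracted by setup_avse1_data.sh, edit `root` in data_config.yaml (or override it on the command line, e.g. `python build_scenes.py root=/my/data/avse1_data`) before generating scenes.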
-------------------------------------------------------------------------------- /data_preparation/avse1/prepare_avse1_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Adapted from original code by Clarity Challenge 4 | https://github.com/claritychallenge/clarity 5 | ''' 6 | 7 | import os 8 | import json 9 | import logging 10 | from tqdm import tqdm 11 | import hydra 12 | from omegaconf import DictConfig, OmegaConf 13 | from concurrent.futures import ProcessPoolExecutor 14 | 15 | from scene_renderer_avse1 import Renderer, check_scene_exists 16 | 17 | def run_renderer(renderer, scene, scene_folder): 18 | 19 | if check_scene_exists(scene, scene_folder): 20 | logging.info(f"Skipping processed scene {scene['scene']}.") 21 | else: 22 | renderer.render( 23 | dataset=scene["dataset"], 24 | target=scene["target"]["name"], 25 | noise_type=scene["interferer"]["type"], 26 | interferer=scene["interferer"]["name"], 27 | scene=scene["scene"], 28 | offset=scene["interferer"]["offset"], 29 | snr_dB=scene["SNR"], 30 | ) 31 | 32 | def prepare_data( 33 | root_path, metafile_path, scene_folder, num_channels, fs, 34 | ): 35 | """ 36 | Generate scene data given dataset (train or dev) 37 | Args: 38 | root_path: AVSE root path 39 | metafile_path: scene metafile path 40 | scene_folder: folder containing generated scenes 41 | num_channels: number of channels 42 | fs: sampling frequency (Hz) 43 | """ 44 | with open(metafile_path, "r") as f: 45 | scenes = json.load(f) 46 | 47 | os.makedirs(scene_folder, exist_ok=True) 48 | 49 | renderer = Renderer(input_path=root_path, output_path=scene_folder, num_channels=num_channels,fs=fs) 50 | 51 | # for scene in scenes: 52 | # run_renderer(renderer, scene, scene_folder) 53 | 54 | futures = [] 55 | ncores = 20 56 | with ProcessPoolExecutor(max_workers=ncores) as executor: 57 | for scene in scenes: 58 | futures.append(executor.submit(run_renderer,renderer, scene, scene_folder)) 59 | proc_list = [future.result() for future in tqdm(futures)] 60 | 61 | @hydra.main(config_path=".", config_name="data_config") 62 | def run(cfg: DictConfig) -> None: 63 | for dataset in cfg["datasets"]: 64 | prepare_data( 65 | cfg["input_path"], 66 | cfg["datasets"][dataset]["metafile_path"], 67 | cfg["datasets"][dataset]["scene_folder"], 68 | cfg["num_channels"], 69 | cfg["fs"], 70 | ) 71 | 72 | if __name__ == "__main__": 73 | run() 74 | -------------------------------------------------------------------------------- /data_preparation/avse1/requirements.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | hydra-core==1.1.1 3 | hydra-submitit-launcher==1.1.6 4 | librosa==0.8.1 5 | matplotlib==3.5.1 6 | numpy==1.20.3 7 | omegaconf==2.1.1 8 | pandas==1.3.5 9 | pyloudnorm==0.1.0 10 | scikit-learn==1.0.2 11 | scipy==1.7.3 12 | SoundFile==0.10.3.post1 13 | tqdm==4.62.3 14 | pesq==0.0.4 15 | pystoi==0.3.3 -------------------------------------------------------------------------------- /data_preparation/avse1/scene_builder_avse1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Adapted from original code by Clarity Challenge 4 | https://github.com/claritychallenge/clarity 5 | ''' 6 | 7 | """Code for building the scenes.json files.""" 8 | import itertools 9 | import json 10 | import logging 11 | import math 12 | import random 13 | import re 14 | from enum import Enum 15 | 16 | import numpy as np 17 | 
from tqdm import tqdm 18 | 19 | # A logger for this file 20 | log = logging.getLogger(__name__) 21 | 22 | # Get json output to round to 4 dp 23 | json.encoder.c_make_encoder = None 24 | 25 | class RoundingFloat(float): 26 | """Round a float to 4 decimal places.""" 27 | __repr__ = staticmethod(lambda x: format(x, ".4f")) 28 | 29 | json.encoder.float = RoundingFloat 30 | 31 | def set_random_seed(random_seed): 32 | if random_seed is not None: 33 | random.seed(random_seed) 34 | np.random.seed(random_seed) 35 | 36 | def add_this_target_to_scene(target, scene): 37 | """Add the target details to the scene dict. 38 | Adds given target to given scene. Target details will be taken 39 | from the target dict but the start time will be 40 | according to the AVSE1 target start time specification. 41 | Args: 42 | target (dict): target dict read from target metadata file 43 | scene (dict): complete scene dict 44 | """ 45 | scene_target = {} 46 | scene_target["name"] = target["wavfile"] 47 | scene["target"] = scene_target 48 | scene["duration"] = target["nsamples"] 49 | 50 | # SNR handling 51 | def generate_snr(snr_range): 52 | """Generate a random SNR.""" 53 | return random.uniform(*snr_range) 54 | 55 | # Interferer handling 56 | class InterfererType(Enum): 57 | """Enum for interferer types.""" 58 | 59 | SPEECH = "speech" 60 | NOISE = "noise" 61 | 62 | def select_interferer_types(allowed_n_interferers): 63 | """Select the interferer types to use. 64 | The number of interferer is drawn randomly 65 | Args: 66 | allowed_n_interferers (list): list of allowed number of interferers 67 | Returns: 68 | list(InterfererType): list of interferer types to use 69 | """ 70 | 71 | n_interferers = random.choice(allowed_n_interferers) 72 | selection = None 73 | while selection is None: 74 | selection = random.choices(list(InterfererType), k=n_interferers) 75 | return selection 76 | 77 | 78 | def select_random_interferer(interferers, dataset, required_samples): 79 | """Randomly select an interferer. 80 | Interferers stored as list of list. First randomly select a sublist 81 | then randomly select an item from sublist matching constraints. 82 | Args: 83 | interferers (list(list)): interferers as list of lists 84 | dataset (str): desired data [train, dev, eval] 85 | required_samples (int): required number of samples 86 | Raises: 87 | ValueError: if no suitable interferer is found 88 | Returns: 89 | dict: the interferer dict 90 | """ 91 | interferer_not_found = True 92 | while interferer_not_found: 93 | interferer_group = random.choice(interferers) 94 | filtered_interferer_group = [ 95 | i 96 | for i in interferer_group 97 | if i["dataset"] == dataset and i["nsamples"] >= required_samples 98 | ] 99 | 100 | if filtered_interferer_group: 101 | interferer = random.choice(filtered_interferer_group) 102 | interferer_not_found = False 103 | # else: 104 | # if interferer_group[0]['type']=="noise": 105 | # print(f"No suitable interferer found in class {interferer_group[0]['class']} for required samples {required_samples}") 106 | # else: 107 | # print(f"No suitable interferer found in class {interferer_group[0]['speaker']} for required samples {required_samples}") 108 | 109 | return interferer 110 | 111 | 112 | def get_random_interferer_offset(interferer, required_samples): 113 | """Generate a random offset sample for interferer. 114 | The offset sample is the point within the masker signal at which the interferer 115 | segment will be extracted. Randomly selected but with care for it not to start 116 | too late, i.e. 
such that the required samples would overrun the end of the masker 117 | signal. 118 | Args: 119 | interferer (dict): the interferer metadata 120 | required_samples (int): number of samples that is going to be required 121 | Returns: 122 | int: a valid randomly selected offset 123 | """ 124 | masker_nsamples = interferer["nsamples"] 125 | latest_start = masker_nsamples - required_samples 126 | if latest_start < 0: 127 | log.error(f"Interferer {interferer['ID']} does not have enough samples.") 128 | 129 | assert ( 130 | latest_start >= 0 131 | ) # This should never happen - means the masker was too short for the scene 132 | return random.randint(0, latest_start) 133 | 134 | 135 | def add_interferer_to_scene_inner( 136 | scene, interferers, number, start_time_range, end_early_time_range 137 | ): 138 | """Randomly select interferers and add them to the given scene. 139 | A random number of interferers is chosen, then each is given a random type 140 | selected from the possible speech, nonspeech types. 141 | Interferers are then chosen from the available lists according to the type 142 | and also taking care to match the scene's 'dataset' field, i.e. train, dev, test. 143 | The interferer data is supplied as a dictionary of lists of lists. The key 144 | being "speech", "nonspeech", and the list of lists being a partitioned 145 | list of interferers for that type. 146 | The idea of using a list of lists is that interferers can be split by 147 | subcondition and then the randomization draws equally from each subcondition, 148 | e.g. for nonspeech there is "washing machine", "microwave" etc. This ensures that 149 | each subcondition is equally represented even if the number of exemplars of 150 | each subcondition is different. 151 | Note, there is no return. The scene is modified in place.
152 | Args: 153 | scene (dict): the scene description 154 | interferers (dict): the interferer metadata 155 | number: number of interferers 156 | start_time_range: when to start 157 | end_early_time_range: when to end 158 | """ 159 | dataset = scene["dataset"] 160 | selected_interferer_types = select_interferer_types(number) 161 | n_interferers = len(selected_interferer_types) 162 | 163 | 164 | scene["interferer"] = [{"type": scene_type.value} for scene_type in selected_interferer_types] 165 | 166 | # Randomly instantiate each interferer in the scene 167 | for scene_interferer, scene_type in zip( 168 | scene["interferer"], selected_interferer_types 169 | ): 170 | desired_start_time = random.randint(*start_time_range) 171 | 172 | scene_interferer["time_start"] = min(scene["duration"], desired_start_time) 173 | desired_end_time = scene["duration"] - random.randint(*end_early_time_range) 174 | 175 | scene_interferer["time_end"] = max( 176 | scene_interferer["time_start"], desired_end_time 177 | ) 178 | 179 | required_samples = scene_interferer["time_end"] - scene_interferer["time_start"] 180 | interferer = select_random_interferer( 181 | interferers[scene_type], dataset, required_samples 182 | ) 183 | # scene_interferer["type"] = scene_type.value 184 | scene_interferer["name"] = interferer["ID"] 185 | scene_interferer["offset"] = get_random_interferer_offset( 186 | interferer, required_samples 187 | ) 188 | 189 | scene["interferer"] = scene["interferer"][0] 190 | 191 | class SceneBuilder: 192 | """Functions for building a list scenes.""" 193 | 194 | def __init__( 195 | self, 196 | scene_datasets, 197 | target, 198 | interferer, 199 | snr_range, 200 | ): 201 | self.scenes = [] 202 | self.scene_datasets = scene_datasets 203 | self.target = target 204 | self.interferer = interferer 205 | self.snr_range = snr_range 206 | 207 | def save_scenes(self, filename): 208 | """Save the list of scenes to a json file.""" 209 | scenes = [s for s in self.scenes] 210 | # Replace the room structure with the room ID 211 | # for scene in scenes: 212 | # scene["room"] = scene["room"]["name"] 213 | json.dump(self.scenes, open(filename, "w"), indent=2) 214 | 215 | def instantiate_scenes(self, dataset): 216 | print(f"Initialise {dataset} scenes") 217 | self.initialise_scenes(dataset, **self.scene_datasets) 218 | print("adding targets to scenes") 219 | self.add_target_to_scene(dataset, **self.target) 220 | print("adding interferers to scenes") 221 | self.add_interferer_to_scene(**self.interferer) 222 | print("assigning an SNR to each scene") 223 | self.add_SNR_to_scene(self.snr_range) 224 | 225 | def initialise_scenes(self, dataset, n_scenes, scene_start_index): 226 | """ 227 | Initialise the scenes for a given dataset. 228 | Args: 229 | dataset: train, dev, or eval set 230 | n_scenes: number of scenes to generate 231 | scene_start_index: index to start for scene IDs 232 | """ 233 | 234 | # Construct the scenes adding the room and dataset label 235 | self.scenes = [] 236 | scenes = [{"dataset": dataset} for _ in range(n_scenes)] 237 | 238 | # Set the scene ID 239 | for index, scene in enumerate(scenes, scene_start_index): 240 | scene["scene"] = f"S{index:05d}" 241 | self.scenes.extend(scenes) 242 | 243 | def add_target_to_scene( 244 | self, 245 | dataset, 246 | target_speakers, 247 | target_selection, 248 | ): 249 | """Add target info to the scenes. 250 | Target speaker file set via config. 
251 | Raises: 252 | Exception: _description_ 253 | """ 254 | targets = json.load(open(target_speakers, "r")) 255 | 256 | targets_dataset = [t for t in targets if t["dataset"] == dataset] 257 | scenes_dataset = [s for s in self.scenes if s["dataset"] == dataset] 258 | 259 | random.shuffle(targets_dataset) 260 | 261 | if target_selection == "SEQUENTIAL": 262 | # Sequential mode: Cycle through targets sequentially 263 | for scene, target in zip(scenes_dataset, itertools.cycle(targets_dataset)): 264 | add_this_target_to_scene( 265 | target, scene 266 | ) 267 | elif target_selection == "RANDOM": 268 | # Random mode: randomly select target with replacement 269 | for scene in scenes_dataset: 270 | add_this_target_to_scene( 271 | random.choice(targets_dataset), 272 | scene, 273 | ) 274 | else: 275 | assert False, "Unknown target selection mode" 276 | 277 | def add_SNR_to_scene(self, snr_range): 278 | """Add the SNR info to the scenes.""" 279 | for scene in tqdm(self.scenes): 280 | scene["SNR"] = generate_snr(snr_range[scene["interferer"]["type"]]) 281 | scene["pre_samples"] = 0 282 | scene["post_samples"] = 0 283 | 284 | def add_interferer_to_scene( 285 | self, 286 | speech_interferers, 287 | noise_interferers, 288 | number, 289 | start_time_range, 290 | end_early_time_range, 291 | ): 292 | """Add interferer to the scene description file.""" 293 | # Load and prepare speech interferer metadata 294 | interferers_speech = json.load(open(speech_interferers, "r")) 295 | for interferer in interferers_speech: 296 | interferer["ID"] = ( 297 | interferer["speaker"] # + ".wav" 298 | ) # selection require a unique "ID" field 299 | # Selection process requires list of lists 300 | interferers_speech = [interferers_speech] 301 | 302 | # Load and prepare noise (i.e. noise) interferer metadata 303 | interferers_noise = json.load(open(noise_interferers, "r")) 304 | # for interferer in interferers_noise: 305 | # interferer["ID"] += ".wav" 306 | interferer_by_type = dict() 307 | for interferer in interferers_noise: 308 | interferer_by_type.setdefault(interferer["class"], []).append(interferer) 309 | interferers_noise = list(interferer_by_type.values()) 310 | 311 | interferers = { 312 | InterfererType.SPEECH: interferers_speech, 313 | InterfererType.NOISE: interferers_noise, 314 | } 315 | 316 | for scene in tqdm(self.scenes): 317 | add_interferer_to_scene_inner( 318 | scene, interferers, number, start_time_range, end_early_time_range 319 | ) 320 | -------------------------------------------------------------------------------- /data_preparation/avse1/scene_renderer_avse1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Adapted from original code by Clarity Challenge 4 | https://github.com/claritychallenge/clarity 5 | ''' 6 | 7 | import os 8 | import math 9 | import logging 10 | import numpy as np 11 | import soundfile 12 | 13 | from soundfile import SoundFile 14 | from scipy.signal import convolve 15 | 16 | from utils import speechweighted_snr, sum_signals, pad 17 | 18 | def create_dir(directory): 19 | if not os.path.exists(directory): 20 | os.makedirs(directory) 21 | 22 | class Renderer: 23 | """ 24 | SceneGenerator of AVSE1 training and development sets. The render() function generates all simulated signals for each 25 | scene given the parameters specified in the metadata/scenes.train.json or metadata/scenes.dev.json file. 
26 | """ 27 | 28 | def __init__( 29 | self, 30 | input_path, 31 | output_path, 32 | num_channels=1, 33 | fs=44100, 34 | ramp_duration=0.5, 35 | tail_duration=0.2, 36 | test_nbits=16, 37 | ): 38 | 39 | self.input_path = input_path 40 | self.output_path = output_path 41 | self.fs = fs 42 | self.ramp_duration = ramp_duration 43 | self.n_tail = int(tail_duration * fs) 44 | self.test_nbits = test_nbits 45 | self.floating_point = False 46 | 47 | self.channels = list(range(num_channels)) 48 | 49 | def read_signal( 50 | self, filename, offset=0, nsamples=-1, nchannels=0, offset_is_samples=False 51 | ): 52 | """Read a wavefile and return as numpy array of floats. 53 | Args: 54 | filename (string): Name of file to read 55 | offset (int, optional): Offset in samples or seconds (from start). Defaults to 0. 56 | nchannels: expected number of channel (default: 0 = any number OK) 57 | offset_is_samples (bool): measurement units for offset (default: False) 58 | Returns: 59 | ndarray: audio signal 60 | """ 61 | try: 62 | wave_file = SoundFile(filename) 63 | except: 64 | # Ensure incorrect error (24 bit) is not generated 65 | raise Exception(f"Unable to read {filename}.") 66 | 67 | if nchannels != 0 and wave_file.channels != nchannels: 68 | raise Exception( 69 | f"Wav file ({filename}) was expected to have {nchannels} channels." 70 | ) 71 | 72 | if wave_file.samplerate != self.fs: 73 | raise Exception(f"Sampling rate is not {self.fs} for filename {filename}.") 74 | 75 | if not offset_is_samples: # Default behaviour 76 | offset = int(offset * wave_file.samplerate) 77 | 78 | if offset != 0: 79 | wave_file.seek(offset) 80 | 81 | x = wave_file.read(frames=nsamples) 82 | return x 83 | 84 | def write_signal(self, filename, x, fs, floating_point=True): 85 | """Write a signal as fixed or floating point wav file.""" 86 | 87 | if fs != self.fs: 88 | logging.warning(f"Sampling rate mismatch: {filename} with sr={fs}.") 89 | # raise ValueError("Sampling rate mismatch") 90 | 91 | if floating_point is False: 92 | if self.test_nbits == 16: 93 | subtype = "PCM_16" 94 | # If signal is float and we want int16 95 | x *= 32768 96 | x = x.astype(np.dtype("int16")) 97 | assert np.max(x) <= 32767 and np.min(x) >= -32768 98 | elif self.test_nbits == 24: 99 | subtype = "PCM_24" 100 | else: 101 | subtype = "FLOAT" 102 | 103 | soundfile.write(filename, x, fs, subtype=subtype) 104 | 105 | def save_signal_16bit(self, filename, signal, fs, norm=1.0): 106 | """Saves a signal to a 16 bit wav file. 107 | Args: 108 | filename (string): filename 109 | signal (np.array): signal 110 | norm (float): normalisation factor 111 | """ 112 | signal /= norm 113 | n_clipped = np.sum(np.abs(signal) > 1.0) 114 | if n_clipped > 0: 115 | print("CLIPPED {} {} {}".format(norm,np.max(signal),np.min(signal))) 116 | logging.warning(f"Writing {filename}: {n_clipped} samples clipped") 117 | np.clip(signal, -1.0, 1.0, out=signal) 118 | signal_16 = (32767 * signal).astype(np.int16) 119 | 120 | # wavfile.write(filename, FS, signal_16) 121 | soundfile.write(filename, signal_16, fs, subtype="PCM_16") 122 | 123 | def apply_ramp(self, x, dur): 124 | """Apply half cosine ramp into and out of signal 125 | 126 | dur - ramp duration in seconds 127 | """ 128 | ramp = np.cos(np.linspace(math.pi, 2 * math.pi, int(self.fs * dur))) 129 | ramp = (ramp + 1) / 2 130 | y = np.array(x) 131 | y[0 : len(ramp)] *= ramp 132 | y[-len(ramp) :] *= ramp[::-1] 133 | return y 134 | 135 | def compute_snr(self, target, noise): 136 | """Return the SNR. 
137 | Take the overlapping segment of the noise and get the speech-weighted 138 | better ear SNR. (Note, SNR is a ratio -- not in dB.) 139 | """ 140 | segment_target = target 141 | segment_noise = noise 142 | assert len(segment_target) == len(segment_noise) 143 | 144 | snr = speechweighted_snr(segment_target, segment_noise) 145 | 146 | return snr 147 | 148 | def render( 149 | self, 150 | target, 151 | noise_type, 152 | interferer, 153 | scene, 154 | offset, 155 | snr_dB, 156 | dataset, 157 | ): 158 | 159 | target_video_fn = f"{self.input_path}/{dataset}/targets_video/{target}.mp4" 160 | target_fn = f"{self.input_path}/{dataset}/targets/{target}.wav" 161 | 162 | target_fn_dir = os.path.dirname(target_fn) 163 | create_dir(target_fn_dir) 164 | 165 | command = ("ffmpeg -v 8 -y -i %s -vn -acodec pcm_s16le -ar %s -ac 1 %s < /dev/null" % (target_video_fn, str(self.fs), target_fn)) 166 | os.system(command) 167 | 168 | interferer_fn = ( 169 | f"{self.input_path}/{dataset}/interferers/{noise_type}/{interferer}.wav" 170 | ) 171 | 172 | target = self.read_signal(target_fn) 173 | 174 | interferer = self.read_signal( 175 | interferer_fn, offset=offset, nsamples=len(target), offset_is_samples=True 176 | ) 177 | 178 | if len(target) != len(interferer): 179 | logging.debug("Target and interferer have different lengths") 180 | 181 | # Apply 500ms half-cosine ramp 182 | interferer = self.apply_ramp(interferer, dur=self.ramp_duration) 183 | 184 | prefix = f"{self.output_path}/{scene}" 185 | 186 | snr_ref = None 187 | 188 | target_at_ear = target 189 | interferer_at_ear = interferer 190 | 191 | # Scale interferer to obtain SNR specified in scene description 192 | logging.info(f"Scaling interferer to obtain mixture SNR = {snr_dB} dB.") 193 | 194 | if snr_ref is None: 195 | # snr_ref computed for first channel in the list and then 196 | # same scaling applied to all 197 | snr_ref = self.compute_snr( 198 | target_at_ear, 199 | interferer_at_ear, 200 | ) 201 | 202 | if snr_ref == np.Inf: 203 | print(f"Scene {scene} was skipped") 204 | return 205 | 206 | # Apply snr_ref reference scaling to get 0 dB and then scale to target snr_dB 207 | interferer_at_ear = interferer_at_ear * snr_ref 208 | interferer_at_ear = interferer_at_ear * 10 ** ((-snr_dB) / 20) 209 | 210 | # Sum target and scaled and ramped interferer 211 | signal_at_ear = sum_signals([target_at_ear, interferer_at_ear]) 212 | outputs = [ 213 | (f"{prefix}_mixed.wav", signal_at_ear), 214 | (f"{prefix}_target.wav", target_at_ear), 215 | (f"{prefix}_interferer.wav", interferer_at_ear), 216 | ] 217 | all_signals = np.concatenate((signal_at_ear,target_at_ear,interferer_at_ear)) 218 | norm = np.max(np.abs(all_signals)) 219 | 220 | # Write all audio output files 221 | for (filename, signal) in outputs: 222 | self.save_signal_16bit(filename, signal, self.fs, norm=norm) 223 | 224 | # Write video file without audio stream 225 | output_video_fn = f"{prefix}_silent.mp4" 226 | command = f"ffmpeg -v 8 -i {target_video_fn} -c:v copy -an {output_video_fn} < /dev/null" 227 | os.system(command) 228 | 229 | def check_scene_exists(scene, output_path): 230 | """Checks correct dataset directory for full set of pre-existing files. 231 | 232 | Args: 233 | scene (dict): dictionary defining the scene to be generated. 234 | 235 | Returns: 236 | status: boolean value indicating whether scene signals exist 237 | or do not exist. 
238 | 239 | """ 240 | 241 | pattern = f"{output_path}/{scene['scene']}" 242 | files_to_check = [ 243 | f"{pattern}_mixed.wav", 244 | f"{pattern}_target.wav", 245 | f"{pattern}_interferer.wav", 246 | ] 247 | 248 | scene_exists = True 249 | for filename in files_to_check: 250 | scene_exists = scene_exists and os.path.exists(filename) 251 | return scene_exists 252 | 253 | 254 | def main(): 255 | import json 256 | 257 | scene = json.load( 258 | open( 259 | "/tmp/avse1_data/metadata/scenes.train.json", 260 | "r", 261 | ) 262 | )[0] 263 | 264 | renderer = Renderer( 265 | input_path="/tmp/avse1_data/", 266 | output_path=".", 267 | num_channels=1, 268 | ) 269 | renderer.render( 270 | dataset=scene["dataset"], 271 | target=scene["target"]["name"], 272 | noise_type=scene["interferer"]["type"], 273 | interferer=scene["interferer"]["name"], 274 | scene=scene["scene"], 275 | offset=scene["interferer"]["offset"], 276 | snr_dB=scene["SNR"], 277 | ) 278 | 279 | -------------------------------------------------------------------------------- /data_preparation/avse1/setup_avse1_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # working directory where data will be stored 4 | root=/tmp/avse1_data/ # EDIT_THIS 5 | 6 | # path to LRS3 data (pretrain and trainval directories should be located there) 7 | LRS3=/tmp/LRS3/ # EDIT_THIS 8 | 9 | # path to AVSE1 data 10 | # wget https://data.cstr.ed.ac.uk/cogmhear/protected/avse1_data.tar 11 | avse1data=/tmp/avse1_data.tar # EDIT_THIS 12 | 13 | ########################################################### 14 | # Set up working directory structure and data 15 | ########################################################### 16 | 17 | mkdir -p ${root} 18 | tar -xvf ${avse1data} --directory ${root}/ 19 | masker_noise=${root}/maskers_noise/ 20 | masker_speech=${root}/maskers_speech/ 21 | 22 | mkdir -p ${root}/{train,dev}/{targets,interferers,scenes} 23 | 24 | ln -s ${LRS3} ${root}/train/targets_video 25 | ln -s ${LRS3} ${root}/dev/targets_video 26 | 27 | ln -s ${masker_noise} ${root}/train/interferers/noise 28 | ln -s ${masker_noise} ${root}/dev/interferers/noise 29 | 30 | # Create speech masker data from LRS3 videos 31 | python create_speech_maskers.py ${LRS3} ${root}/metadata/masker_speech_list.json ${masker_speech} 32 | 33 | ln -s ${masker_speech} ${root}/train/interferers/speech 34 | ln -s ${masker_speech} ${root}/dev/interferers/speech 35 | 36 | -------------------------------------------------------------------------------- /data_preparation/avse1/speech_weight.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cogmhear/avse_challenge/0aa1577e738f45da502ef5db5f59495f3bb5c313/data_preparation/avse1/speech_weight.mat -------------------------------------------------------------------------------- /data_preparation/avse1/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Adapted from original code by Clarity Challenge 4 | https://github.com/claritychallenge/clarity 5 | ''' 6 | 7 | import os 8 | import scipy 9 | import scipy.io 10 | import numpy as np 11 | 12 | SPEECH_FILTER = scipy.io.loadmat( 13 | os.path.join( 14 | os.path.dirname(os.path.abspath(__file__)), "speech_weight.mat" 15 | ), 16 | squeeze_me=True, 17 | ) 18 | SPEECH_FILTER = np.array(SPEECH_FILTER["filt"]) 19 | 20 | def speechweighted_snr(target, noise): 21 | """Apply speech weighting filter to signals 
and get SNR.""" 22 | target_filt = scipy.signal.convolve( 23 | target, SPEECH_FILTER, mode="full", method="fft" 24 | ) 25 | noise_filt = scipy.signal.convolve(noise, SPEECH_FILTER, mode="full", method="fft") 26 | 27 | # rms of the target after speech weighted filter 28 | targ_rms = np.sqrt(np.mean(target_filt ** 2)) 29 | 30 | # rms of the noise after speech weighted filter 31 | noise_rms = np.sqrt(np.mean(noise_filt ** 2)) 32 | 33 | if noise_rms==0: 34 | return np.Inf 35 | 36 | sw_snr = np.divide(targ_rms, noise_rms) 37 | return sw_snr 38 | 39 | 40 | def sum_signals(signals): 41 | """Return sum of a list of signals. 42 | 43 | Signals are stored as a list of ndarrays whose size can vary in the first 44 | dimension, i.e., so can sum mono or stereo signals etc. 45 | Shorter signals are zero padded to the length of the longest. 46 | 47 | Args: 48 | signals (list): List of signals stored as ndarrays 49 | 50 | Returns: 51 | ndarray: The sum of the signals 52 | 53 | """ 54 | max_length = max(x.shape[0] for x in signals) 55 | return sum(pad(x, max_length) for x in signals) 56 | 57 | 58 | def pad(signal, length): 59 | """Zero pad signal to required length. 60 | 61 | Assumes required length is not less than input length. 62 | """ 63 | assert length >= signal.shape[0] 64 | return np.pad( 65 | signal, [(0, length - signal.shape[0])] + [(0, 0)] * (len(signal.shape) - 1) 66 | ) 67 | -------------------------------------------------------------------------------- /data_preparation/avse4/build_scenes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from original code by Clarity Enhancement Challenge 2 3 | https://github.com/claritychallenge/clarity/tree/main/recipes/cec2 4 | ''' 5 | 6 | import logging 7 | from pathlib import Path 8 | 9 | import hydra 10 | from omegaconf import DictConfig 11 | 12 | from clarity.data.scene_builder_cec2 import RoomBuilder, SceneBuilder, set_random_seed 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def build_rooms_from_rpf(cfg): 18 | room_builder = RoomBuilder() 19 | for dataset in cfg.room_datasets: 20 | room_file = Path(cfg.path.metadata_dir) / f"rooms.{dataset}.json" 21 | if not room_file.exists(): 22 | room_builder.build_from_rpf(**cfg.room_datasets[dataset]) 23 | room_builder.save_rooms(str(room_file)) 24 | else: 25 | logger.info(f"rooms.{dataset}.json exists, skip") 26 | 27 | 28 | def instantiate_scenes(cfg): 29 | room_builder = RoomBuilder() 30 | set_random_seed(cfg.random_seed) 31 | for dataset in cfg.scene_datasets: 32 | scene_file = Path(cfg.path.metadata_dir) / f"scenes.{dataset}.json" 33 | if not scene_file.exists(): 34 | logger.info(f"instantiate scenes for {dataset} set") 35 | room_file = Path(cfg.path.metadata_dir) / f"rooms.{dataset}.json" 36 | room_builder.load(str(room_file)) 37 | scene_builder = SceneBuilder( 38 | rb=room_builder, 39 | scene_datasets=cfg.scene_datasets[dataset], 40 | target=cfg.target, 41 | interferer=cfg.interferer, 42 | snr_range=cfg.snr_range[dataset], 43 | listener=cfg.listener, 44 | shuffle_rooms=cfg.shuffle_rooms, 45 | ) 46 | scene_builder.instantiate_scenes(dataset=dataset) 47 | scene_builder.save_scenes(str(scene_file)) 48 | else: 49 | logger.info(f"scenes.{dataset}.json exists, skip") 50 | 51 | 52 | @hydra.main(config_path=".", config_name="config", version_base=None) 53 | def run(cfg: DictConfig) -> None: 54 | logger.info("Building rooms") 55 | build_rooms_from_rpf(cfg) 56 | logger.info("Instantiating scenes") 57 | instantiate_scenes(cfg) 58 | 59 | 60 | # 
pylint: disable=no-value-for-parameter 61 | if __name__ == "__main__": 62 | run() 63 | -------------------------------------------------------------------------------- /data_preparation/avse4/clarity/data/params/speech_weight.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cogmhear/avse_challenge/0aa1577e738f45da502ef5db5f59495f3bb5c313/data_preparation/avse4/clarity/data/params/speech_weight.mat -------------------------------------------------------------------------------- /data_preparation/avse4/clarity/data/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for data generation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | from typing import Final, Literal 7 | 8 | import numpy as np 9 | import scipy 10 | import scipy.io 11 | 12 | SPEECH_FILTER: Final = np.array( 13 | scipy.io.loadmat( 14 | Path(__file__).parent / "params/speech_weight.mat", 15 | squeeze_me=True, 16 | )["filt"] 17 | ) 18 | 19 | 20 | def better_ear_speechweighted_snr(target: np.ndarray, noise: np.ndarray) -> float: 21 | """Calculate effective better ear SNR. 22 | 23 | Args: 24 | target (np.ndarray): 25 | noise (np.ndarray): 26 | 27 | Returns: 28 | (float) 29 | Maximum Signal Noise Ratio between left and right channel. 30 | """ 31 | if np.ndim(target) == 1: 32 | # analysis left ear and right ear for single channel target 33 | left_snr = speechweighted_snr(target, noise[:, 0]) 34 | right_snr = speechweighted_snr(target, noise[:, 1]) 35 | else: 36 | # analysis left ear and right ear for two channel target 37 | left_snr = speechweighted_snr(target[:, 0], noise[:, 0]) 38 | right_snr = speechweighted_snr(target[:, 1], noise[:, 1]) 39 | # snr is the max of left and right 40 | be_snr = max(left_snr, right_snr) 41 | return be_snr 42 | 43 | 44 | def speechweighted_snr(target: np.ndarray, noise: np.ndarray) -> float: 45 | """Apply speech weighting filter to signals and get SNR. 46 | 47 | Args: 48 | target (np.ndarray): 49 | noise (np.ndarray): 50 | 51 | Returns: 52 | (float): 53 | Signal Noise Ratio 54 | """ 55 | 56 | target_filt = scipy.signal.convolve( 57 | target, SPEECH_FILTER, mode="full", method="fft" 58 | ) 59 | noise_filt = scipy.signal.convolve(noise, SPEECH_FILTER, mode="full", method="fft") 60 | 61 | # rms of the target after speech weighted filter 62 | targ_rms = np.sqrt(np.mean(target_filt**2)) 63 | 64 | # rms of the noise after speech weighted filter 65 | noise_rms = np.sqrt(np.mean(noise_filt**2)) 66 | sw_snr = np.divide(targ_rms, noise_rms) 67 | return sw_snr 68 | 69 | 70 | def sum_signals(signals: list) -> np.ndarray | Literal[0]: 71 | """Return sum of a list of signals. 72 | 73 | Signals are stored as a list of ndarrays whose size can vary in the first 74 | dimension, i.e., so can sum mono or stereo signals etc. 75 | Shorter signals are zero padded to the length of the longest. 76 | 77 | Args: 78 | signals (list): List of signals stored as ndarrays 79 | 80 | Returns: 81 | np.ndarray: The sum of the signals 82 | 83 | """ 84 | max_length = max(x.shape[0] for x in signals) 85 | return sum(pad(x, max_length) for x in signals) 86 | 87 | 88 | def pad(signal: np.ndarray, length: int) -> np.ndarray: 89 | """Zero pad signal to required length. 90 | 91 | Assumes required length is not less than input length. 
92 | 93 | Args: 94 | signal (np.array): 95 | length (int): 96 | 97 | Returns: 98 | np.array: 99 | """ 100 | 101 | if length < signal.shape[0]: 102 | raise ValueError("Length must be greater than signal length") 103 | 104 | return np.pad( 105 | signal, ([(0, length - signal.shape[0])] + ([(0, 0)] * (len(signal.shape) - 1))) 106 | ) 107 | -------------------------------------------------------------------------------- /data_preparation/avse4/config.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | root: /disk/data1/aaldana/avsec4/avsec4_data/ 3 | metadata_dir: ${path.root}/metadata/ 4 | 5 | 6 | random_seed: 0 7 | shuffle_rooms: False 8 | 9 | # Build rooms 10 | room_datasets: 11 | train: 12 | rpf_location: ${path.root}/train/rooms/rpf 13 | n_interferers: 3 14 | n_rooms: 6000 15 | start_room: 1 16 | dev: 17 | rpf_location: ${path.root}/dev/rooms/rpf 18 | n_interferers: 3 19 | n_rooms: 2500 20 | start_room: 6001 21 | 22 | # Instantiate_scenes 23 | scene_datasets: 24 | train: 25 | n_scenes: 34525 26 | room_selection: SEQUENTIAL 27 | scene_start_index: 1 28 | dev: 29 | n_scenes: 3365 30 | room_selection: SEQUENTIAL 31 | scene_start_index: 34526 32 | 33 | target: 34 | target_speakers: ${path.metadata_dir}/target_speech_list.json 35 | target_selection: SEQUENTIAL 36 | pre_samples_range: [0, 0] 37 | post_samples_range: [0, 0] 38 | 39 | snr_range: 40 | train: [-10, 10] 41 | dev: [-10, 10] 42 | 43 | interferer: 44 | speech_interferers: ${path.metadata_dir}/masker_speech_list.json 45 | noise_interferers: ${path.metadata_dir}/masker_nonspeech_list.json 46 | music_interferers: ${path.metadata_dir}/masker_music_list.json 47 | number: [1, 2, 3] 48 | start_time_range: [0, 0] 49 | end_early_time_range: [0, 0] 50 | 51 | listener: 52 | heads: ["BuK", "DADEC", "KEMAR", "VP_E1", "VP_E2", "VP_E4", "VP_E5", "VP_E6", "VP_E7", "VP_E8", "VP_E9", "VP_E11", "VP_E12", "VP_E13", "VP_N1", "VP_N3", "VP_N4", "VP_N5", "VP_N6"] 53 | channels: ["ED"] 54 | #head rotation parameters are not used in AVSEC 55 | # parameters all in units of samples or degrees 56 | # The time at which the head turn starts relative to target speaker onset 57 | relative_start_time_range: [0, 0] #[-28004, 38147] # -0.635 s to 0.865 s 58 | # Rotations have a normally distributed duration 59 | duration_mean: 0 #8820 60 | duration_sd: 0 #441 # 100 ms 61 | # Head is initially pointing away from talker, uniform within a range 62 | # Note it can be either offset to left or right - drawn at random 63 | angle_initial_mean: 0 #25 64 | angle_initial_sd: 0 #5 65 | # Head turns to point at the speaker within some margin 66 | angle_final_range: [0, 10] 67 | 68 | # Render scenes 69 | render_starting_chunk: 0 # there are (6000 train + 2500 dev) / (12 + 5) = 500 trunks in total. If multi_run, should be 0, 10, 20, 30, ..., 490 if render_n_chunk_to_process=10 70 | render_n_chunk_to_process: 500 #50 # i.e. (12 train + 5 dev) * 10 scenes to render. 
If not multi_run, set 50 71 | 72 | scene_renderer: 73 | train: 74 | paths: 75 | hoairs: ${path.root}/train/rooms/HOA_IRs_16k 76 | hrirs: ${path.root}/hrir/HRIRs_MAT 77 | scenes: ${path.root}/train/scenes 78 | targets: ${path.root}/train/targets 79 | videos: ${path.root}/train/targets_video 80 | interferers: ${path.root}/train/interferers/{type} 81 | metadata: 82 | room_definitions: ${path.metadata_dir}/rooms.train.json 83 | scene_definitions: ${path.metadata_dir}/scenes.train.json 84 | hrir_metadata: ${path.metadata_dir}/hrir_data.json 85 | chunk_size: 70 #12 86 | dev: 87 | paths: 88 | hoairs: ${path.root}/dev/rooms/HOA_IRs_16k 89 | hrirs: ${path.root}/hrir/HRIRs_MAT 90 | scenes: ${path.root}/dev/scenes 91 | targets: ${path.root}/dev/targets 92 | videos: ${path.root}/dev/targets_video 93 | interferers: ${path.root}/dev/interferers/{type} 94 | metadata: 95 | room_definitions: ${path.metadata_dir}/rooms.dev.json 96 | scene_definitions: ${path.metadata_dir}/scenes.dev.json 97 | hrir_metadata: ${path.metadata_dir}/hrir_data.json 98 | chunk_size: 7 #5 99 | 100 | 101 | render_params: 102 | ambisonic_order: 6 103 | equalise_loudness: False 104 | reference_channel: 0 105 | channel_norms: [6.0] #not used in AVSEC 106 | binaural_render: True 107 | monoaural_render: True 108 | hydra: 109 | run: 110 | dir: . 111 | job: 112 | chdir: True 113 | 114 | defaults: 115 | - override hydra/launcher: cec2_submitit_local 116 | -------------------------------------------------------------------------------- /data_preparation/avse4/create_speech_maskers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Create speech maskers data 4 | ''' 5 | import sys 6 | import numpy as np 7 | import os 8 | import glob 9 | import json 10 | import soundfile as sf 11 | from tqdm import tqdm 12 | from concurrent.futures import ProcessPoolExecutor 13 | 14 | fs = 16000 15 | 16 | def create_dir(directory): 17 | if not os.path.exists(directory): 18 | os.makedirs(directory) 19 | 20 | def create_speech_maskers(datadir, metafile, wavdir): 21 | 22 | with open(metafile, "r") as f: 23 | maskers = json.load(f) 24 | 25 | futures = [] 26 | ncores = 20 27 | with ProcessPoolExecutor(max_workers=ncores) as executor: 28 | for masker in maskers: 29 | futures.append(executor.submit(create_masker_for_spk, datadir, wavdir, masker['speaker'])) 30 | proc_list = [future.result() for future in tqdm(futures)] 31 | 32 | def create_masker_for_spk(datadir, wavdir, spk): 33 | 34 | create_dir(f"{wavdir}/{spk}/") 35 | 36 | # Extract audio from videos and join them into one long masker file 37 | y = np.array([]) 38 | for file in glob.iglob(f'{datadir}/*train*/{spk}/*.mp4'): 39 | basename = os.path.basename(file).split('.')[0] 40 | target_fn = f"{wavdir}/{spk}/{basename}.wav" 41 | command = ("ffmpeg -v 8 -y -i %s -vn -acodec pcm_s16le -ar %s -ac 1 %s < /dev/null" % (file, str(fs), target_fn)) 42 | os.system(command) 43 | x = sf.read(target_fn)[0] 44 | y = np.concatenate((y, x), axis=-1) 45 | sf.write(f"{wavdir}/{spk}.wav", y, fs) 46 | 47 | command = ("rm -r %s/%s" % (wavdir,spk)) 48 | os.system(command) 49 | 50 | if __name__ == '__main__': 51 | 52 | datadir = sys.argv[1] # LRS3 53 | metafile = sys.argv[2] # masker_speech_list.json 54 | wavdir = sys.argv[3] # maskers_speech folder 55 | 56 | # Create speech masker files 57 | create_speech_maskers(datadir, metafile, wavdir) 58 | 59 | -------------------------------------------------------------------------------- 
/data_preparation/avse4/hydra/launcher/cec2_submitit_local.yaml: -------------------------------------------------------------------------------- 1 | # Submitit configuration for running data preparation stage locally 2 | 3 | defaults: 4 | - submitit_local 5 | 6 | cpus_per_task: 1 7 | tasks_per_node: 2 8 | mem_gb: 4 9 | nodes: 1 10 | -------------------------------------------------------------------------------- /data_preparation/avse4/hydra/launcher/cec2_submitit_slurm.yaml: -------------------------------------------------------------------------------- 1 | # Submitit configuration for running data preparation stage on slurm cluster 2 | 3 | defaults: 4 | - submitit_slurm 5 | 6 | mem_per_cpu: 4GB 7 | tasks_per_node: 1 8 | timeout_min: 180 9 | additional_parameters: 10 | account: clarity 11 | partition: clarity 12 | setup: ['module load Anaconda3/5.3.0', 'source activate clarity', 'export SLURM_EXPORT_ENV=ALL'] -------------------------------------------------------------------------------- /data_preparation/avse4/render_scenes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from original code by Clarity Enhancement Challenge 2 3 | https://github.com/claritychallenge/clarity/tree/main/recipes/cec2 4 | ''' 5 | 6 | import json 7 | import logging 8 | 9 | import hydra 10 | from omegaconf import DictConfig 11 | 12 | from clarity.data.scene_renderer_cec2 import SceneRenderer 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def render_scenes(cfg): 18 | for dataset in cfg.scene_renderer: 19 | logger.info(f"Beginning scene generation for {dataset} set...") 20 | with open( 21 | cfg.scene_renderer[dataset].metadata.scene_definitions, 22 | encoding="utf-8", 23 | ) as fp: 24 | scenes = json.load(fp) 25 | 26 | starting_scene = ( 27 | cfg.scene_renderer[dataset].chunk_size * cfg.render_starting_chunk 28 | ) 29 | 30 | n_scenes = ( 31 | cfg.scene_renderer[dataset].chunk_size * cfg.render_n_chunk_to_process 32 | ) 33 | 34 | scenes = scenes[starting_scene : starting_scene + n_scenes] 35 | 36 | scene_renderer = SceneRenderer( 37 | cfg.scene_renderer[dataset].paths, 38 | cfg.scene_renderer[dataset].metadata, 39 | **cfg.render_params, 40 | ) 41 | scene_renderer.render_scenes(scenes) 42 | 43 | 44 | @hydra.main(config_path=".", config_name="config", version_base=None) 45 | def run(cfg: DictConfig) -> None: 46 | logger.info("Rendering scenes") 47 | render_scenes(cfg) 48 | 49 | 50 | # pylint: disable=no-value-for-parameter 51 | if __name__ == "__main__": 52 | run() 53 | -------------------------------------------------------------------------------- /data_preparation/avse4/setup_avsec4_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # working directory where data will be stored 4 | root=/tmp/avsec4/ # EDIT_THIS 5 | 6 | # path to LRS3 data (pretrain and trainval directories should be located there) 7 | LRS3=/tmp/LRS3/ # EDIT_THIS 8 | #path to Clarity data and OlHeaD-HRTF Database 9 | # wget https://data.cstr.ed.ac.uk/cogmhear/protected/clarity_cec2_data.tar 10 | clarity_data=/tmp/clarity_cec2_data.tar # EDIT_THIS 11 | # path to AVSE4 data 12 | # wget https://data.cstr.ed.ac.uk/cogmhear/protected/avsec4_data.tar 13 | avsec4data=/tmp/avsec4_data.tar # EDIT_THIS 14 | 15 | ########################################################### 16 | # Set up working directory structure and data 17 | ########################################################### 18 | 19 | mkdir -p ${root} 20 | 21 | tar 
-xvf ${avsec4data} --directory ${root}/
22 | 
23 | masker_music=${root}/maskers_music/
24 | masker_noise=${root}/maskers_noise/
25 | masker_speech=${root}/maskers_speech/
26 | 
27 | mkdir -p ${root}/{train,dev}/{targets,interferers,scenes}
28 | 
29 | # Extract impulse responses and room simulation info to the dev/train folders and extract the hrir folder
30 | tar -xvf ${clarity_data} --directory ${root}/
31 | 
32 | ln -s ${LRS3} ${root}/train/targets_video
33 | ln -s ${LRS3} ${root}/dev/targets_video
34 | 
35 | ln -s ${masker_music} ${root}/train/interferers/music
36 | ln -s ${masker_music} ${root}/dev/interferers/music
37 | ln -s ${masker_noise} ${root}/train/interferers/noise
38 | ln -s ${masker_noise} ${root}/dev/interferers/noise
39 | 
40 | # Create speech masker data from LRS3 videos
41 | python create_speech_maskers.py ${LRS3} ${root}/metadata/masker_speech_list.json ${masker_speech}
42 | 
43 | ln -s ${masker_speech} ${root}/train/interferers/speech
44 | ln -s ${masker_speech} ${root}/dev/interferers/speech
45 | 
46 | 
--------------------------------------------------------------------------------
/evaluation/avse1/config.yaml:
--------------------------------------------------------------------------------
1 | 
2 | root: /tmp/avse1_data/ # EDIT_THIS - this should match the path defined in data_preparation/avse1/data_config.yaml
3 | dataset: dev
4 | target: ${root}/${dataset}/scenes # path containing the clean target files
5 | target_suffix: '_target' # format of target files: SXXXXX_target.wav
6 | enhanced: ${root}/${dataset}/scenes # path containing your enhanced files -- EDIT_THIS
7 | enhanced_suffix: '_mixed' # format of enhanced files: SXXXXX_mixed.wav -- EDIT_THIS
8 | scenes_names: ${root}/metadata/scenes.${dataset}.json
9 | metrics_results: ${root}/objective_evaluation/ # output directory where results will be written
10 | 
11 | fs: 16000
12 | 
13 | objective_metrics:
14 |   fs: ${fs}
15 |   mode: 'wb'
16 | 
17 | defaults:
18 |   - override hydra/job_logging: disabled
19 | 
20 | hydra:
21 |   output_subdir: Null
22 |   run:
23 |     dir: .
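The config above expects enhanced files to sit next to the noisy scenes and to follow the `SXXXXX<suffix>.wav` naming at 16 kHz. As a hedged illustration (the paths, the `_enhanced` suffix and the zero-signal placeholder below are assumptions, not part of the challenge code), this is one way to write model output so that the evaluation script that follows can find it:

```python
# Minimal sketch (assumed paths/suffix): write enhanced signals so they match
# the naming and 16 kHz sample-rate conventions expected by the config above.
import json
from pathlib import Path

import numpy as np
import soundfile as sf

ROOT = Path("/tmp/avse1_data")         # should match `root` in config.yaml
DATASET = "dev"
OUT_SUFFIX = "_enhanced"               # then set enhanced_suffix: '_enhanced'
FS = 16000

scenes = json.loads((ROOT / "metadata" / f"scenes.{DATASET}.json").read_text())
out_dir = ROOT / DATASET / "enhanced"  # then set enhanced: ${root}/${dataset}/enhanced
out_dir.mkdir(parents=True, exist_ok=True)

for scene in scenes:
    name = scene["scene"]                          # e.g. "S00001"
    # Placeholder signal; substitute the output of your enhancement model here.
    enhanced = np.zeros(FS, dtype=np.float32)
    sf.write(str(out_dir / f"{name}{OUT_SUFFIX}.wav"), enhanced, FS, subtype="PCM_16")
```

With the files written this way, only `enhanced` and `enhanced_suffix` need to be overridden in the config before running the objective evaluation below.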
-------------------------------------------------------------------------------- /evaluation/avse1/objective_evaluation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from original code by Clarity Challenge 3 | https://github.com/claritychallenge/clarity 4 | ''' 5 | 6 | import hydra 7 | from omegaconf import DictConfig 8 | import os 9 | from tqdm import tqdm 10 | import csv 11 | import json 12 | from soundfile import SoundFile 13 | from pesq import pesq 14 | from pystoi import stoi 15 | from concurrent.futures import ProcessPoolExecutor 16 | 17 | def create_dir(directory): 18 | if not os.path.exists(directory): 19 | os.makedirs(directory) 20 | 21 | def run_pesq(target, enhanced, sr, mode): 22 | """Compute PESQ from: https://github.com/ludlows/python-pesq/blob/master/README.md 23 | Args: 24 | target (string): Name of file to read 25 | enhanced (string): Name of file to read 26 | sr (int): sample rate of files 27 | mode (string): 'wb' = wide-band (16KHz); 'nb' narrow-band (8KHz) 28 | Returns: 29 | PESQ metric (float) 30 | """ 31 | return pesq(sr, target, enhanced, mode) 32 | 33 | def run_stoi(target, enhanced, sr): 34 | """Compute STOI from: https://github.com/mpariente/pystoi 35 | Args: 36 | target (string): Name of file to read 37 | enhanced (string): Name of file to read 38 | sr (int): sample rate of files 39 | Returns: 40 | STOI metric (float) 41 | """ 42 | return stoi(target, enhanced, sr) 43 | 44 | def read_audio(filename): 45 | """Read a wavefile and return as numpy array of floats. 46 | Args: 47 | filename (string): Name of file to read 48 | Returns: 49 | ndarray: audio signal 50 | """ 51 | try: 52 | wave_file = SoundFile(filename) 53 | except: 54 | # Ensure incorrect error (24 bit) is not generated 55 | raise Exception(f"Unable to read {filename}.") 56 | return wave_file.read() 57 | 58 | def run_metrics(scene, enhanced, target, cfg): 59 | 60 | # Retrieve the scene name 61 | scene_name = scene["scene"] 62 | 63 | enh_file = os.path.join(enhanced, f"{scene_name}{cfg['enhanced_suffix']}.wav") 64 | tgt_file = os.path.join(target, f"{scene_name}{cfg['target_suffix']}.wav") 65 | scene_metrics_file = os.path.join(cfg["metrics_results"], f"{scene_name}.csv") 66 | 67 | # Skip processing with files dont exist or metrics have already been computed 68 | if ( not os.path.isfile(enh_file) ) or ( not os.path.isfile(tgt_file) ) or ( os.path.isfile(scene_metrics_file)) : 69 | return 70 | 71 | # Read enhanced signal 72 | enh = read_audio(enh_file) 73 | # Read clean/target signal 74 | clean = read_audio(tgt_file) 75 | 76 | # Check that both files are the same length, otherwise computing the metrics results in an error 77 | if len(clean) != len(enh): 78 | raise Exception( 79 | f"Wav files {enh_file} and {tgt_file} should have the same length" 80 | ) 81 | 82 | # Compute metrics 83 | m_stoi = run_stoi(clean, enh, cfg["objective_metrics"]["fs"]) 84 | m_pesq = run_pesq(clean, enh, cfg["objective_metrics"]["fs"], cfg["objective_metrics"]["mode"]) 85 | 86 | # Store scene metrics in a tmp file 87 | with open(scene_metrics_file, "w") as csv_f: 88 | csv_writer = csv.writer(csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) 89 | csv_writer.writerow([scene_name, m_stoi, m_pesq]) 90 | 91 | @hydra.main(config_path=".", config_name="config") 92 | def compute_metrics(cfg: DictConfig) -> None: 93 | # paths to data 94 | enhanced = os.path.join(cfg["enhanced"]) 95 | target = os.path.join(cfg["target"]) 96 | # json file with info about scenes 97 | 
scenes_eval = json.load(open(cfg["scenes_names"]))
98 |     # csv file to store metrics
99 |     create_dir(cfg["metrics_results"])
100 |     metrics_file = os.path.join(cfg["metrics_results"], "metrics.csv")
101 |     csv_lines = [["scene", "stoi", "pesq"]]
102 | 
103 |     futures = []
104 |     ncores = 20
105 |     with ProcessPoolExecutor(max_workers=ncores) as executor:
106 |         for scene in scenes_eval:
107 |             futures.append(executor.submit(run_metrics, scene, enhanced, target, cfg))
108 |         proc_list = [future.result() for future in tqdm(futures)]
109 | 
110 |     # Store results in one file
111 |     with open(metrics_file, "w") as csv_f:
112 |         csv_writer = csv.writer(csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
113 |         for scene in tqdm(scenes_eval):
114 |             scene_name = scene["scene"]
115 |             scene_metrics_file = os.path.join(cfg["metrics_results"], f"{scene_name}.csv")
116 |             with open(scene_metrics_file, newline='') as csv_f:
117 |                 scene_metrics = csv.reader(csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
118 |                 for row in scene_metrics:
119 |                     csv_writer.writerow(row)
120 |             # remove tmp file
121 |             os.system(f"rm {scene_metrics_file}")
122 | 
123 | if __name__ == "__main__":
124 | 
125 |     compute_metrics()
126 | 
--------------------------------------------------------------------------------
/evaluation/avse4/config.yaml:
--------------------------------------------------------------------------------
1 | root: /disk/data1/aaldana/avsec4/avsec4_data/ #/tmp/avse1_data/ # EDIT_THIS - this should match the path defined in data_preparation/avse4/config.yaml
2 | dataset: dev
3 | target: ${root}/${dataset}/scenes # path containing the target anechoic files
4 | target_suffix: '_target_anechoic' # format of target files: SXXXXX_target_anechoic.wav
5 | enhanced: ${root}/${dataset}/scenes # path containing your enhanced files -- EDIT_THIS
6 | enhanced_suffix: '_mix' # format of enhanced files: SXXXXX_mix.wav -- EDIT_THIS
7 | scenes_names: ${root}/metadata/scenes.${dataset}.json
8 | metrics_results: ${root}/objective_evaluation/ # output directory where results will be written
9 | 
10 | fs: 16000
11 | 
12 | objective_metrics:
13 |   fs: ${fs}
14 |   # Add parameters for other objective metrics
15 | 
16 | defaults:
17 |   - override hydra/job_logging: disabled
18 | 
19 | hydra:
20 |   output_subdir: Null
21 |   run:
22 |     dir: .
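Before running the full AVSEC-4 evaluation script further below, it can help to sanity-check a single scene by calling the MBSTOI implementation that follows directly. The sketch below is only an illustration (the scene id and working directory are assumptions); it uses the `_target_anechoic` / `_mix` suffix convention from this config:

```python
# Sketch (assumed scene id and file locations): compute MBSTOI for one binaural
# scene, following the suffix convention in evaluation/avse4/config.yaml.
import soundfile as sf

from mbstoi.mbstoi import mbstoi

scene = "S34526"                                     # hypothetical scene id
clean, fs = sf.read(f"{scene}_target_anechoic.wav")  # stereo: shape (n_samples, 2)
enhanced, _ = sf.read(f"{scene}_mix.wav")            # replace with your enhanced output

score = mbstoi(
    clean[:, 0], clean[:, 1],        # left / right clean (anechoic) channels
    enhanced[:, 0], enhanced[:, 1],  # left / right processed channels
    sr_signal=fs,                    # signals are resampled internally if needed
)
print(f"{scene}: MBSTOI = {score:.3f}")
```

This is the same call that `compute_mbstoi` in `evaluation/avse4/objective_evaluation.py` wraps for every scene.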
23 | -------------------------------------------------------------------------------- /evaluation/avse4/mbstoi/__init__.py: -------------------------------------------------------------------------------- 1 | """Modified Binaural Short-Time Objective Intelligibility Evaluator""" 2 | 3 | from mbstoi import mbstoi 4 | 5 | __all__ = ["mbstoi"] 6 | -------------------------------------------------------------------------------- /evaluation/avse4/mbstoi/mbstoi.py: -------------------------------------------------------------------------------- 1 | """Modified Binaural Short-Time Objective Intelligibility (MBSTOI) Measure 2 | 3 | Script by Clarity Challenge 4 | https://github.com/claritychallenge/clarity 5 | """ 6 | 7 | import importlib.resources as pkg_resources 8 | import logging 9 | import math 10 | 11 | import numpy as np 12 | import yaml # type: ignore 13 | from numpy import ndarray 14 | from scipy.signal import resample 15 | 16 | from mbstoi.mbstoi_utils import ( 17 | equalisation_cancellation, 18 | remove_silent_frames, 19 | stft, 20 | thirdoct, 21 | ) 22 | 23 | # pylint: disable=too-many-locals 24 | 25 | 26 | # basic stoi parameters from file 27 | params_file = pkg_resources.open_text(__package__, "parameters.yaml") 28 | basic_stoi_parameters = yaml.safe_load(params_file.read()) 29 | 30 | 31 | def mbstoi( 32 | left_ear_clean: ndarray, 33 | right_ear_clean: ndarray, 34 | left_ear_noisy: ndarray, 35 | right_ear_noisy: ndarray, 36 | sr_signal: float, 37 | gridcoarseness: int = 1, 38 | sample_rate: float = 10000.0, 39 | n_frame: int = 256, 40 | fft_size_in_samples: int = 512, 41 | n_third_octave_bands: int = 15, 42 | centre_freq_first_third_octave_hz: int = 150, 43 | n_frames: int = 30, 44 | dyn_range: int = 40, 45 | tau_min: float = -0.001, 46 | tau_max: float = 0.001, 47 | gamma_min: int = -20, 48 | gamma_max: int = 20, 49 | sigma_delta_0: float = 65e-6, 50 | sigma_epsilon_0: float = 1.5, 51 | alpha_0_db: int = 13, 52 | tau_0: float = 1.6e-3, 53 | level_shift_deviation: float = 1.6, 54 | ) -> float: 55 | """The Modified Binaural Short-Time Objective Intelligibility (mbstoi) measure. 56 | 57 | Args: 58 | left_ear_clean (ndarray): Clean speech signal from left ear. 59 | right_ear_clean (ndarray): Clean speech signal from right ear. 60 | left_ear_noisy (ndarray) : Noisy/processed speech signal from left ear. 61 | right_ear_noisy (ndarray) : Noisy/processed speech signal from right ear. 62 | fs_signal (int) : Frequency sample rate of signal. 63 | gridcoarseness (int) : Grid coarseness as denominator of ntaus and ngammas. 64 | Defaults to 1. 65 | sample_rate (int) : Sample Rate. 66 | n_frame (int) : Number of Frames. 67 | fft_size_in_samples (int) : ??? size in samples. 68 | n_third_octave_bands (int) : Number of third octave bands. 69 | centre_freq_first_third_octave_hz (int) : 150, 70 | n_frames (int) : Number of Frames. 71 | dyn_range (int) : Dynamic Range. 72 | tau_min (float) : Min Tau the ??? 73 | tau_max (float) : Max Tau the ??? 74 | gamma_min (int) : Minimum gamma the ??? 75 | gamma_max (int) : Maximum gamma the ??? 76 | sigma_delta_0 (float) : ??? 77 | sigma_epsilon_0 (float) : ??? 78 | alpha_0_db (int) : ??? 79 | tau_0 (float) : ??? 80 | level_shift_deviation (float) : ??? 81 | 82 | Returns: 83 | float : mbstoi index d. 84 | 85 | Notes: 86 | All title, copyrights and pending patents pertaining to mbtsoi[1]_ in and to the 87 | original Matlab software are owned by oticon a/s and/or Aalborg University. 88 | Please see `http://ah-andersen.net/code/` 89 | 90 | 91 | .. [1] A. H. 
Andersen, J. M. de Haan, Z.-H. Tan, and J. Jensen (2018) Refinement and 92 | validation of the binaural short time objective intelligibility measure for 93 | spatially diverse conditions. Speech Communication vol. 102, pp. 1-13 94 | doi:10.1016/j.specom.2018.06.001 95 | """ 96 | 97 | n_taus = math.ceil(100 / gridcoarseness) # number of tau values to try out 98 | n_gammas = math.ceil(40 / gridcoarseness) # number of gamma values to try out 99 | 100 | # prepare signals, ensuring that inputs are column vectors 101 | left_ear_clean = left_ear_clean.flatten() 102 | right_ear_clean = right_ear_clean.flatten() 103 | left_ear_noisy = left_ear_noisy.flatten() 104 | right_ear_noisy = right_ear_noisy.flatten() 105 | 106 | # Resample signals to 10 kHz 107 | if sr_signal != sample_rate: 108 | logging.debug( 109 | "Resampling signals with sr=%s for MBSTOI calculation.", sample_rate 110 | ) 111 | # Assumes fs_signal is 44.1 kHz 112 | length_left_ear_clean = len(left_ear_clean) 113 | left_ear_clean = resample( 114 | left_ear_clean, int(length_left_ear_clean * (sample_rate / sr_signal) + 1) 115 | ) 116 | right_ear_clean = resample( 117 | right_ear_clean, int(length_left_ear_clean * (sample_rate / sr_signal) + 1) 118 | ) 119 | left_ear_noisy = resample( 120 | left_ear_noisy, int(length_left_ear_clean * (sample_rate / sr_signal) + 1) 121 | ) 122 | right_ear_noisy = resample( 123 | right_ear_noisy, int(length_left_ear_clean * (sample_rate / sr_signal) + 1) 124 | ) 125 | 126 | # Remove silent frames 127 | ( 128 | left_ear_clean, 129 | right_ear_clean, 130 | left_ear_noisy, 131 | right_ear_noisy, 132 | ) = remove_silent_frames( 133 | left_ear_clean, 134 | right_ear_clean, 135 | left_ear_noisy, 136 | right_ear_noisy, 137 | dyn_range, 138 | n_frame, 139 | n_frame / 2, 140 | ) 141 | 142 | # Handle case when signals are zeros 143 | if ( 144 | abs(np.log10(np.linalg.norm(left_ear_clean) / np.linalg.norm(left_ear_noisy))) 145 | > 5.0 146 | or abs( 147 | np.log10(np.linalg.norm(right_ear_clean) / np.linalg.norm(right_ear_noisy)) 148 | ) 149 | > 5.0 150 | ): 151 | sii = 0 152 | 153 | # STDFT and filtering 154 | # Get 1/3 octave band matrix 155 | [ 156 | octave_band_matrix, 157 | centre_frequencies, 158 | frequency_band_edges_indices, 159 | _freq_low, 160 | _freq_high, 161 | ] = thirdoct( 162 | sample_rate, 163 | fft_size_in_samples, 164 | n_third_octave_bands, 165 | centre_freq_first_third_octave_hz, 166 | ) 167 | 168 | # This is now the angular frequency in radians per sec 169 | centre_frequencies = 2 * math.pi * centre_frequencies 170 | 171 | # Apply short time DFT to signals and transpose 172 | left_ear_clean_hat = stft(left_ear_clean, n_frame, fft_size_in_samples).transpose() 173 | right_ear_clean_hat = stft( 174 | right_ear_clean, n_frame, fft_size_in_samples 175 | ).transpose() 176 | left_ear_noisy_hat = stft(left_ear_noisy, n_frame, fft_size_in_samples).transpose() 177 | right_ear_noisy_hat = stft( 178 | right_ear_noisy, n_frame, fft_size_in_samples 179 | ).transpose() 180 | 181 | # Take single sided spectrum of signals 182 | idx_upper = int(fft_size_in_samples / 2 + 1) 183 | left_ear_clean_hat = left_ear_clean_hat[0:idx_upper, :] 184 | right_ear_clean_hat = right_ear_clean_hat[0:idx_upper, :] 185 | left_ear_noisy_hat = left_ear_noisy_hat[0:idx_upper, :] 186 | right_ear_noisy_hat = right_ear_noisy_hat[0:idx_upper, :] 187 | 188 | # Compute intermediate correlation via EC search 189 | logging.info("Starting EC evaluation") 190 | # Here intermediate correlation coefficients are evaluated for a discrete set of 191 | 
# gamma and tau values (a "grid") and the highest value is chosen. 192 | intermediate_intelligibility_measure_grid = np.zeros( 193 | (n_third_octave_bands, np.shape(left_ear_clean_hat)[1] - n_frames + 1) 194 | ) 195 | p_ec_max = np.zeros( 196 | (n_third_octave_bands, np.shape(left_ear_clean_hat)[1] - n_frames + 1) 197 | ) 198 | 199 | # Interaural compensation time and level values 200 | taus = np.linspace(tau_min, tau_max, n_taus) 201 | gammas = np.linspace(gamma_min, gamma_max, n_gammas) 202 | 203 | # Jitter incorporated below - Equations 5 and 6 in Andersen et al. 2018 204 | sigma_epsilon = ( 205 | np.sqrt(2) 206 | * sigma_epsilon_0 207 | * (1 + (abs(gammas) / alpha_0_db) ** level_shift_deviation) 208 | / 20 209 | ) 210 | gammas = gammas / 20 211 | 212 | sigma_delta = np.sqrt(2) * sigma_delta_0 * (1 + (abs(taus) / tau_0)) 213 | 214 | logging.info("Processing Equalisation Cancellation stage") 215 | updated_intermediate_intelligibility_measure, p_ec_max = equalisation_cancellation( 216 | left_ear_clean_hat, 217 | right_ear_clean_hat, 218 | left_ear_noisy_hat, 219 | right_ear_noisy_hat, 220 | n_third_octave_bands, 221 | n_frames, 222 | frequency_band_edges_indices, 223 | centre_frequencies.flatten(), 224 | taus, 225 | n_taus, 226 | gammas, 227 | n_gammas, 228 | intermediate_intelligibility_measure_grid, 229 | p_ec_max, 230 | sigma_epsilon, 231 | sigma_delta, 232 | ) 233 | 234 | # Compute the better ear STOI 235 | logging.info("Computing better ear intermediate correlation coefficients") 236 | # Arrays for the 1/3 octave envelope 237 | left_ear_clean_third_octave_band = np.zeros( 238 | (n_third_octave_bands, np.shape(left_ear_clean_hat)[1]) 239 | ) 240 | right_ear_clean_third_octave_band = np.zeros( 241 | (n_third_octave_bands, np.shape(left_ear_clean_hat)[1]) 242 | ) 243 | left_ear_noisy_third_octave_band = np.zeros( 244 | (n_third_octave_bands, np.shape(left_ear_clean_hat)[1]) 245 | ) 246 | right_ear_noisy_third_octave_band = np.zeros( 247 | (n_third_octave_bands, np.shape(left_ear_clean_hat)[1]) 248 | ) 249 | 250 | # Apply 1/3 octave bands as described in Eq.(1) of the STOI article 251 | for k in range(np.shape(left_ear_clean_hat)[1]): 252 | left_ear_clean_third_octave_band[:, k] = np.dot( 253 | octave_band_matrix, abs(left_ear_clean_hat[:, k]) ** 2 254 | ) 255 | right_ear_clean_third_octave_band[:, k] = np.dot( 256 | octave_band_matrix, abs(right_ear_clean_hat[:, k]) ** 2 257 | ) 258 | left_ear_noisy_third_octave_band[:, k] = np.dot( 259 | octave_band_matrix, abs(left_ear_noisy_hat[:, k]) ** 2 260 | ) 261 | right_ear_noisy_third_octave_band[:, k] = np.dot( 262 | octave_band_matrix, abs(right_ear_noisy_hat[:, k]) ** 2 263 | ) 264 | 265 | # Arrays for better-ear correlations 266 | dl_interm = np.zeros( 267 | (n_third_octave_bands, len(range(n_frames, len(left_ear_clean_hat[1]) + 1))) 268 | ) 269 | dr_interm = np.zeros( 270 | (n_third_octave_bands, len(range(n_frames, len(left_ear_clean_hat[1]) + 1))) 271 | ) 272 | left_improved = np.zeros( 273 | (n_third_octave_bands, len(range(n_frames, len(left_ear_clean_hat[1]) + 1))) 274 | ) 275 | right_improved = np.zeros( 276 | (n_third_octave_bands, len(range(n_frames, len(left_ear_clean_hat[1]) + 1))) 277 | ) 278 | 279 | # Compute temporary better-ear correlations 280 | for m in range( 281 | n_frames, np.shape(left_ear_clean_hat)[1] 282 | ): # pylint: disable=invalid-name 283 | left_ear_clean_seg = left_ear_clean_third_octave_band[:, (m - n_frames) : m] 284 | right_ear_clean_seg = right_ear_clean_third_octave_band[:, (m - n_frames) : m] 285 | 
left_ear_noisy_seg = left_ear_noisy_third_octave_band[:, (m - n_frames) : m] 286 | right_ear_noisy_seg = right_ear_noisy_third_octave_band[:, (m - n_frames) : m] 287 | 288 | for n in range(n_third_octave_bands): # pylint: disable=invalid-name 289 | left_ear_clean_n = ( 290 | left_ear_clean_seg[n, :] - np.sum(left_ear_clean_seg[n, :]) / n_frames 291 | ) 292 | right_ear_clean_n = ( 293 | right_ear_clean_seg[n, :] - np.sum(right_ear_clean_seg[n, :]) / n_frames 294 | ) 295 | left_ear_noisy_n = ( 296 | left_ear_noisy_seg[n, :] - np.sum(left_ear_noisy_seg[n, :]) / n_frames 297 | ) 298 | right_ear_noisy_n = ( 299 | right_ear_noisy_seg[n, :] - np.sum(right_ear_noisy_seg[n, :]) / n_frames 300 | ) 301 | np.sum(left_ear_clean_n * left_ear_clean_n) 302 | left_improved[n, m - n_frames] = np.sum( 303 | left_ear_clean_n * left_ear_clean_n 304 | ) / np.sum(left_ear_noisy_n * left_ear_noisy_n) 305 | right_improved[n, m - n_frames] = np.sum( 306 | right_ear_clean_n * right_ear_clean_n 307 | ) / np.sum(right_ear_noisy_n * right_ear_noisy_n) 308 | dl_interm[n, m - n_frames] = np.sum(left_ear_clean_n * left_ear_noisy_n) / ( 309 | np.linalg.norm(left_ear_clean_n) * np.linalg.norm(left_ear_noisy_n) 310 | ) 311 | dr_interm[n, m - n_frames] = np.sum( 312 | right_ear_clean_n * right_ear_noisy_n 313 | ) / (np.linalg.norm(right_ear_clean_n) * np.linalg.norm(right_ear_noisy_n)) 314 | 315 | # Get the better ear intermediate coefficients 316 | dl_interm[~np.isfinite(dl_interm)] = 0 317 | dr_interm[~np.isfinite(dr_interm)] = 0 318 | p_be_max = np.maximum(left_improved, right_improved) 319 | dbe_interm = np.zeros(np.shape(dl_interm)) 320 | 321 | idx_left_better = left_improved > right_improved 322 | dbe_interm[idx_left_better] = dl_interm[idx_left_better] 323 | dbe_interm[~idx_left_better] = dr_interm[~idx_left_better] 324 | 325 | # Compute STOI measure 326 | # Whenever a single ear provides a higher correlation than the corresponding EC 327 | # processed alternative,the better-ear correlation is used. 328 | idx_use_be = p_be_max > p_ec_max 329 | updated_intermediate_intelligibility_measure[idx_use_be] = dbe_interm[idx_use_be] 330 | sii = np.mean(updated_intermediate_intelligibility_measure) 331 | 332 | logging.info("MBSTOI processing complete") 333 | 334 | return sii 335 | -------------------------------------------------------------------------------- /evaluation/avse4/mbstoi/parameters.yaml: -------------------------------------------------------------------------------- 1 | #mbstoi 2 | sample_rate: 16000 # sample rate of proposed intelligibility measure in hz #This is 10000 Hz in original MATLAB implementation 3 | n_frame: 256 # window support in samples 4 | fft_size_in_samples: 512 # fft size in samples 5 | n_third_octave_bands: 15 # number of one-third octave bands 6 | centre_freq_first_third_octave_hz: 150 # centre frequency of first 1/3 octave band in hz 7 | n_frames: 30 # number of frames for intermediate intelligibility measure (length analysis window) 8 | dyn_range: 40 # speech dynamic range in db 9 | # values to define ec grid 10 | tau_min: -0.001 # minimum interaural delay compensation in seconds. b: -0.01. 11 | tau_max: 0.001 # maximum interaural delay compensation in seconds. b: 0.01. 12 | gamma_min: -20 # minimum interaural level compensation in db 13 | gamma_max: 20 # maximum interaural level compensation in db 14 | # constants for jitter 15 | # itd compensation standard deviation in seconds. equation 6 andersen et al. 
2018 refinement 16 | sigma_delta_0: 0.000065 17 | # ild compensation standard deviation. equation 5 andersen et al. 2018 18 | sigma_epsilon_0: 1.5 19 | # constant for level shift deviation in db. equation 5 andersen et al. 2018 20 | alpha_0_db: 13 21 | # constant for time shift deviation in seconds. equation 6 andersen et al. 2018 22 | tau_0: 0.0016 23 | # constant for level shift deviation. power for calculation of sigma delta gamma 24 | # in equation 5 andersen et al. 2018. 25 | level_shift_deviation: 1.6 26 | 27 | -------------------------------------------------------------------------------- /evaluation/avse4/objective_evaluation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from original code by Clarity Challenge 3 | https://github.com/claritychallenge/clarity 4 | ''' 5 | 6 | import hydra 7 | from omegaconf import DictConfig 8 | import os 9 | from tqdm import tqdm 10 | import csv 11 | import json 12 | from soundfile import SoundFile 13 | from pesq import pesq 14 | from pystoi import stoi 15 | from concurrent.futures import ProcessPoolExecutor 16 | 17 | from mbstoi.mbstoi import mbstoi 18 | 19 | def create_dir(directory): 20 | if not os.path.exists(directory): 21 | os.makedirs(directory) 22 | 23 | def read_audio(filename): 24 | """Read a wavefile and return as numpy array of floats. 25 | Args: 26 | filename (string): Name of file to read 27 | Returns: 28 | ndarray: audio signal 29 | """ 30 | try: 31 | wave_file = SoundFile(filename) 32 | except: 33 | # Ensure incorrect error (24 bit) is not generated 34 | raise Exception(f"Unable to read {filename}.") 35 | return wave_file.read() 36 | 37 | def compute_pesq(target, enhanced, sr, mode): 38 | """Compute PESQ from: https://github.com/ludlows/python-pesq/blob/master/README.md 39 | Args: 40 | target (string): Name of file to read 41 | enhanced (string): Name of file to read 42 | sr (int): sample rate of files 43 | mode (string): 'wb' = wide-band (16KHz); 'nb' narrow-band (8KHz) 44 | Returns: 45 | PESQ metric (float) 46 | """ 47 | return pesq(sr, target, enhanced, mode) 48 | 49 | def compute_stoi(target, enhanced, sr): 50 | """Compute STOI from: https://github.com/mpariente/pystoi 51 | Args: 52 | target (string): Name of file to read 53 | enhanced (string): Name of file to read 54 | sr (int): sample rate of files 55 | Returns: 56 | STOI metric (float) 57 | """ 58 | return stoi(target, enhanced, sr) 59 | 60 | def compute_mbstoi(clean_signal, enhanced_signal, sr): 61 | """compute MBSTOI""" 62 | left_ear_clean = clean_signal[:,0] 63 | right_ear_clean = clean_signal[:,1] 64 | left_ear_noisy= enhanced_signal[:,0] 65 | right_ear_noisy= enhanced_signal[:,1] 66 | 67 | #to modify mbstoi parameters see mbstoi/parameters.yaml 68 | mbstoi_score = mbstoi(left_ear_clean, right_ear_clean, left_ear_noisy, right_ear_noisy, sr_signal=sr) # signal sample rate 69 | return mbstoi_score 70 | 71 | 72 | def run_metrics(scene, enhanced, target, cfg): 73 | 74 | # Retrieve the scene name 75 | scene_name = scene["scene"] 76 | 77 | enh_file = os.path.join(enhanced, f"{scene_name}{cfg['enhanced_suffix']}.wav") 78 | tgt_file = os.path.join(target, f"{scene_name}{cfg['target_suffix']}.wav") 79 | scene_metrics_file = os.path.join(cfg["metrics_results"], f"{scene_name}.csv") 80 | 81 | # Skip processing with files don't exist or metrics have already been computed 82 | if ( not os.path.isfile(enh_file) ) or ( not os.path.isfile(tgt_file) ) or ( os.path.isfile(scene_metrics_file)) : 83 | return 84 | 85 | # Read enhanced 
signal 86 | enh = read_audio(enh_file) 87 | # Read clean/target signal 88 | clean = read_audio(tgt_file) 89 | 90 | # Check that both files are the same length, otherwise computing the metrics results in an error 91 | if len(clean) != len(enh): 92 | raise Exception( 93 | f"Wav files {enh_file} and {tgt_file} should have the same length" 94 | ) 95 | 96 | # Compute binaural metrics 97 | m_mbstoi = compute_mbstoi(clean, enh, cfg["objective_metrics"]["fs"]) 98 | 99 | # Store scene metrics in a tmp file 100 | with open(scene_metrics_file, "w") as csv_f: 101 | csv_writer = csv.writer(csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) 102 | csv_writer.writerow([scene_name, m_mbstoi]) 103 | 104 | @hydra.main(config_path=".", config_name="config", version_base="1.1") 105 | def compute_metrics(cfg: DictConfig) -> None: 106 | # paths to data 107 | enhanced = os.path.join(cfg["enhanced"]) 108 | target = os.path.join(cfg["target"]) 109 | # json file with info about scenes 110 | scenes_eval = json.load(open(cfg["scenes_names"])) 111 | # csv file to store metrics 112 | create_dir(cfg["metrics_results"]) 113 | metrics_file = os.path.join(cfg["metrics_results"], "objective_metrics.csv") 114 | csv_lines = ["scene", "mbstoi"] 115 | 116 | futures = [] 117 | ncores = 20 118 | with ProcessPoolExecutor(max_workers=ncores) as executor: 119 | for scene in scenes_eval: 120 | futures.append(executor.submit(run_metrics, scene, enhanced, target, cfg)) 121 | proc_list = [future.result() for future in tqdm(futures)] 122 | 123 | # Store results in one file 124 | with open(metrics_file, "w") as csv_f: 125 | csv_writer = csv.writer(csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) 126 | csv_writer.writerow(csv_lines) 127 | for scene in tqdm(scenes_eval): 128 | scene_name = scene["scene"] 129 | scene_metrics_file = os.path.join(cfg["metrics_results"], f"{scene_name}.csv") 130 | with open(scene_metrics_file, newline='') as csv_f: 131 | scene_metrics = csv.reader(csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) 132 | for row in scene_metrics: 133 | csv_writer.writerow(row) 134 | # remove tmp file 135 | os.system(f"rm {scene_metrics_file}") 136 | 137 | if __name__ == "__main__": 138 | 139 | compute_metrics() 140 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | hydra-core==1.3.2 3 | hydra-submitit-launcher==1.2.0 4 | librosa==0.8.1 5 | matplotlib==3.5.1 6 | numpy==1.20.3 7 | omegaconf==2.3.0 8 | pandas==1.3.5 9 | pyloudnorm==0.1.0 10 | scikit-learn==1.0.2 11 | scipy==1.7.3 12 | SoundFile==0.10.3.post1 13 | tqdm==4.67.1 14 | pesq==0.0.4 15 | pystoi==0.3.3 --------------------------------------------------------------------------------
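Finally, once `evaluation/avse4/objective_evaluation.py` has written its summary CSV, the pinned pandas release above is enough to aggregate the scores. A small sketch (the results path is an assumption and should point at your `metrics_results` directory):

```python
# Sketch: summarise the per-scene MBSTOI scores written by
# evaluation/avse4/objective_evaluation.py (adjust the path to your metrics_results).
import pandas as pd

results = pd.read_csv("/tmp/avsec4/objective_evaluation/objective_metrics.csv")  # EDIT_THIS
print(f"Scenes evaluated: {len(results)}")
print(f"Mean MBSTOI: {results['mbstoi'].mean():.3f}")
print(f"Std MBSTOI:  {results['mbstoi'].std():.3f}")
```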