├── __init__.py ├── assets ├── task.png └── model.png ├── .gitignore ├── requirements.txt ├── make_filelist.sh ├── rename_dirs.sh ├── train.py ├── params.py ├── sde.py ├── dataset.py ├── utils.py ├── sampler.py ├── README.md ├── inference.py ├── learner.py └── model.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoonjinXD/T-FOLEY/HEAD/assets/task.png -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoonjinXD/T-FOLEY/HEAD/assets/model.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | logs/ 3 | pretrained/ 4 | results/ 5 | DCASE_2023_Challenge_Task_7_Dataset/ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.7.0 2 | ffmpeg==1.4 3 | librosa==0.10.1 4 | llvmlite==0.41.1 5 | matplotlib==3.7.5 6 | numpy==1.24.4 7 | pillow==10.2.0 8 | protobuf==4.25.3 9 | pydub==0.25.1 10 | pyparsing 11 | python-dateutil==2.9.0.post0 12 | scikit-learn==1.3.2 13 | scipy==1.10.1 14 | soundfile==0.12.1 15 | tensorboard==2.14.0 16 | tqdm==4.66.2 17 | typing_extensions==4.10.0 18 | urllib3==2.2.1 -------------------------------------------------------------------------------- /make_filelist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Base directories 4 | BASE_DIR_DEV="./DCASE_2023_Challenge_Task_7_Dataset/dev" 5 | BASE_DIR_EVAL="./DCASE_2023_Challenge_Task_7_Dataset/eval" 6 | 7 | # Generate file list for training 8 | find "$BASE_DIR_DEV" -type f > "./DCASE_2023_Challenge_Task_7_Dataset/train.txt" 9 | 10 | # Generate file list for evaluation 11 | find "$BASE_DIR_EVAL" -type f > "./DCASE_2023_Challenge_Task_7_Dataset/eval.txt" 12 | 13 | echo "Filelists are created in the dataset directory." 14 | -------------------------------------------------------------------------------- /rename_dirs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Base directories 4 | BASE_DIR_DEV="./DCASE_2023_Challenge_Task_7_Dataset/dev" 5 | BASE_DIR_EVAL="./DCASE_2023_Challenge_Task_7_Dataset/eval" 6 | 7 | # Function to rename directories 8 | rename_dirs() { 9 | local base_dir=$1 10 | mv "${base_dir}/dog_bark" "${base_dir}/DogBark" 11 | mv "${base_dir}/footstep" "${base_dir}/Footstep" 12 | mv "${base_dir}/gunshot" "${base_dir}/GunShot" 13 | mv "${base_dir}/keyboard" "${base_dir}/Keyboard" 14 | mv "${base_dir}/moving_motor_vehicle" "${base_dir}/MovingMotorVehicle" 15 | mv "${base_dir}/rain" "${base_dir}/Rain" 16 | mv "${base_dir}/sneeze_cough" "${base_dir}/Sneeze_Cough" 17 | } 18 | 19 | # Rename directories in both dev and eval 20 | rename_dirs "$BASE_DIR_DEV" 21 | rename_dirs "$BASE_DIR_EVAL" 22 | 23 | echo "Directories have been renamed in both dev and eval." 
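# Note: the CamelCase names above must match the LABELS list used in learner.py and
# inference.py ('DogBark', 'Footstep', 'GunShot', 'Keyboard', 'MovingMotorVehicle',
# 'Rain', 'Sneeze_Cough'), since dataset.py derives each clip's class index from its
# parent directory name.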
24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch.cuda import device_count 2 | from torch.multiprocessing import spawn 3 | 4 | from learner import train, train_distributed 5 | from params import params 6 | 7 | 8 | def _get_free_port(): 9 | import socketserver 10 | 11 | with socketserver.TCPServer(("localhost", 0), None) as s: 12 | return s.server_address[1] 13 | 14 | def main(): 15 | replica_count = device_count() 16 | if replica_count > 1: 17 | if params['batch_size'] % replica_count != 0: 18 | raise ValueError( 19 | f"Batch size {params['batch_size']} is not evenly divisble by # GPUs {replica_count}." 20 | ) 21 | params['batch_size'] = params['batch_size'] // replica_count 22 | port = _get_free_port() 23 | spawn( 24 | train_distributed, 25 | args=(replica_count, port, params), 26 | nprocs=replica_count, 27 | join=True, 28 | ) 29 | else: 30 | train(params) 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | params = { 2 | 3 | # --- Data --- : provide lists of folders that contain .wav files 4 | 'train_dirs': ['./DCASE_2023_Challenge_Task_7_Dataset/train.txt'], 5 | 'test_dirs': ['./DCASE_2023_Challenge_Task_7_Dataset/eval.txt'], 6 | 'sample_rate': 22050, 7 | 'audio_length': 88200, # traning data seconds * sample_rate 8 | 'num_workers': 4, 9 | 10 | # --- Model --- 11 | 'model_dir': 'logs/', 12 | 'sequential': 'lstm', 13 | 'factors': [2,2,3,3,5,5,7], 14 | 'dims': [32,64,128,128,256,256,512,512], 15 | 16 | # --- Condition --- 17 | 'time_emb_dim': 512, 18 | 'class_emb_dim': 512, 19 | 'mid_dim': 512, 20 | 'film_type': 'block', # {None, temporal, block} 21 | 'block_nums': [49,49,49,49,49,49,14], 22 | 'event_type': 'rms', # {rms, power, onset} 23 | 'event_dims': {'rms': 690, 'power': 88200, 'onset': 88200}, 24 | 'cond_prob': [0.1, 0.1], # [class prob, event prob] 25 | 26 | # --- Training --- 27 | 'lr': 1e-4, 28 | 'batch_size': 16, 29 | 'ema_rate': 0.999, 30 | 'scheduler_patience_epoch': 25, 31 | 'scheduler_factor': 0.8, 32 | 'scheduler_threshold': 0.01, 33 | 'restore': False, 34 | 35 | # --- Logging --- 36 | 'checkpoint_id': None, 37 | 'num_epochs_to_save': 10, 38 | 'num_steps_to_test': 250, 39 | 'n_bins': 5, 40 | 41 | } 42 | -------------------------------------------------------------------------------- /sde.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class SDE(torch.nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def sigma(self, t: torch.Tensor): 10 | raise NotImplementedError 11 | 12 | def mean(self, t: torch.Tensor): 13 | raise NotImplementedError 14 | 15 | def perturb(self, x, t, noise=None): 16 | if noise is None: 17 | noise = torch.randn_like(x) 18 | mean = self.mean(t) 19 | sigma = self.sigma(t) 20 | return mean * x + sigma * noise 21 | 22 | 23 | class VpSdeCos(SDE): 24 | def __init__(self): 25 | self.t_min = 0.007 26 | self.t_max = 1 - 0.007 27 | 28 | def sigma_inverse(self, sigma: torch.tensor): 29 | return 1 / np.pi * torch.acos(1 - 2 * sigma) 30 | 31 | def sigma(self, t: torch.Tensor): 32 | return 0.5 * (1 - torch.cos(np.pi * t)) 33 | 34 | def mean(self, t: torch.Tensor): 35 | return (1 - self.sigma(t)**2)**0.5 36 | 37 | def sigma_derivative(self, t: 
torch.Tensor): 38 | return 0.5 * np.pi * torch.sin(np.pi * t) 39 | 40 | def beta(self, t: torch.Tensor): 41 | return 2 * self.sigma(t) * self.sigma_derivative(t) / ( 42 | 1 - self.sigma(t)**2) 43 | 44 | def g(self, t: torch.Tensor): 45 | return self.beta(t)**0.5 46 | 47 | 48 | class SubVpSdeCos(SDE): 49 | def __init__(self): 50 | self.t_min = 0.006 51 | self.t_max = 1 - 0.006 52 | 53 | def sigma_inverse(self, sigma: torch.tensor): 54 | return 1 / np.pi * torch.acos(1 - 2 * sigma) 55 | 56 | def sigma(self, t: torch.Tensor): 57 | return 0.5 * (1 - torch.cos(np.pi * t)) 58 | 59 | def mean(self, t: torch.Tensor): 60 | return (1 - self.sigma(t))**0.5 61 | 62 | def sigma_from_mean_approx(self, mean_t, nu_t): 63 | return (1-mean_t**2-nu_t**2) 64 | 65 | def sigma_derivative(self, t: torch.Tensor): 66 | return 0.5 * np.pi * torch.sin(np.pi * t) 67 | 68 | def beta(self, t: torch.Tensor): 69 | return self.sigma_derivative(t) / (1 - self.sigma(t)) 70 | 71 | def g(self, t: torch.Tensor): 72 | return (self.sigma_derivative(t) * self.sigma(t) * 73 | (2 - self.sigma(t)) / (1 - self.sigma(t)))**0.5 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torchaudio 7 | from torch.utils.data.distributed import DistributedSampler 8 | 9 | from utils import get_event_cond 10 | 11 | 12 | class AudioDataset(torch.utils.data.Dataset): 13 | def __init__(self, paths, params, labels): 14 | super().__init__() 15 | self.filenames = [] 16 | self.audio_length = params['audio_length'] 17 | self.labels = labels 18 | self.event_type = params['event_type'] 19 | for path in paths: 20 | self.filenames += self.parse_filelist(path) 21 | 22 | def __len__(self): 23 | return len(self.filenames) 24 | 25 | def __getitem__(self, idx): 26 | audio_filename = self.filenames[idx] 27 | signal, _ = torchaudio.load(audio_filename) 28 | signal = signal[0, :self.audio_length] 29 | 30 | # extract class cond 31 | cls_name = os.path.dirname(audio_filename).split('/')[-1] 32 | cls = torch.tensor(self.labels.index(cls_name)) 33 | 34 | # extract event cond 35 | event = signal.clone().detach() 36 | event = get_event_cond(event, self.event_type) 37 | 38 | return { 39 | 'audio': signal, 40 | 'class': cls, 41 | 'event': event 42 | } 43 | 44 | def parse_filelist(self, filelist_path): 45 | # if filelist_path is txt file 46 | if filelist_path.endswith('.txt'): 47 | with open(filelist_path, 'r') as f: 48 | filelist = [line.strip() for line in f.readlines()] 49 | return filelist 50 | 51 | # if filelist_path is csv file 52 | if filelist_path.endswith('.csv'): 53 | with open(filelist_path, 'r') as f: 54 | reader = csv.reader(f) 55 | filelist = [row[0] for row in reader] 56 | f.close() 57 | return filelist 58 | 59 | def moving_avg(self, input, window_size): 60 | if type(input) != list: input = list(input) 61 | result = [] 62 | for i in range(1, window_size+1): 63 | result.append(sum(input[:i])/i) 64 | 65 | moving_sum = sum(input[:window_size]) 66 | result.append(moving_sum/window_size) 67 | for i in range(len(input) - window_size): 68 | moving_sum += (input[i+window_size] - input[i]) 69 | result.append(moving_sum/window_size) 70 | return np.array(result) 71 | 72 | 73 | 74 | def from_path(data_dirs, params, labels, distributed=False): 75 | dataset = AudioDataset(data_dirs, params, labels) 76 | if distributed: 77 | return torch.utils.data.DataLoader( 78 | dataset, 
79 | batch_size=params['batch_size'], 80 | collate_fn=None, 81 | shuffle=False, 82 | num_workers=params['num_workers'], 83 | pin_memory=True, 84 | drop_last=True, 85 | sampler=DistributedSampler(dataset)) 86 | return torch.utils.data.DataLoader( 87 | dataset, 88 | batch_size=params['batch_size'], 89 | collate_fn=None, 90 | shuffle=True, 91 | num_workers=os.cpu_count()//4, 92 | pin_memory=True, 93 | drop_last=True) 94 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | import torchaudio.transforms as T 6 | from scipy.signal import ellip, filtfilt, firwin, lfilter 7 | from torchaudio.transforms import MelSpectrogram 8 | 9 | 10 | # --- Preprocess Event --- 11 | def get_event_cond(x, event_type='rms'): 12 | assert event_type in ['rms', 'power', 'onset'] 13 | if event_type == 'rms': 14 | return get_rms(x) 15 | if event_type == 'power': 16 | return get_power(x) 17 | if event_type == 'onset': 18 | return get_onset(x) 19 | 20 | def get_rms(signal): 21 | rms = librosa.feature.rms(y=signal, frame_length=512, hop_length=128) 22 | rms = rms[0] 23 | rms = zero_phased_filter(rms) 24 | return torch.tensor(rms.copy(), dtype=torch.float32) 25 | 26 | def get_power(signal): 27 | if torch.is_tensor(signal): 28 | signal_copy_grad = signal.clone().detach().requires_grad_(signal.requires_grad) 29 | return signal_copy_grad*signal_copy_grad 30 | else: 31 | return torch.tensor(signal*signal, dtype=torch.float32) 32 | 33 | def get_onset(y, sr=22050): 34 | y = np.array(y) 35 | o_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median, fmax=8000, n_mels=256) 36 | onset_frames = librosa.onset.onset_detect(onset_envelope=o_env, sr=sr, normalize=True, delta=0.3, units='samples') 37 | onsets = np.zeros(y.shape) 38 | onsets[onset_frames] = 1.0 39 | return torch.tensor(onsets, dtype=torch.float32) 40 | 41 | def resample_audio(audio, original_sr, target_sr): 42 | resampler = T.transforms.Resample(original_sr, target_sr, resampling_method='sinc_interpolation') 43 | return resampler(audio) 44 | 45 | def adjust_audio_length(audio, length): 46 | if audio.shape[1] >= length: 47 | return audio[0, :length] 48 | return torch.cat((audio[0, :], torch.zeros(length - audio.shape[1])), dim=-1) 49 | 50 | 51 | # --- Post-process Audio --- 52 | def normalize(x): 53 | return x / torch.max(torch.abs(x)).item() 54 | 55 | def high_pass_filter(x, sr=22050): 56 | b = firwin(101, cutoff=20, fs=sr, pass_zero='highpass') 57 | x= lfilter(b, [1,0], x) 58 | return x 59 | 60 | def zero_phased_filter(x): 61 | b, a = ellip(4, 0.01, 120, 0.125) 62 | x = filtfilt(b, a, x, method="gust") 63 | return x 64 | 65 | def pooling(x, block_num=49): 66 | block_size = x.shape[-1] // block_num 67 | 68 | device = x.device 69 | pooling = torch.nn.MaxPool1d(block_size, stride=block_size) 70 | x = x.unsqueeze(1) 71 | pooled_x = pooling(x).to(device) 72 | 73 | return pooled_x 74 | 75 | 76 | # --- Plot --- 77 | def save_figure_to_numpy(fig): 78 | # save it to a numpy array. 
79 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 80 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 81 | return data 82 | 83 | def plot_spec(waveform, sample_rate): 84 | # Transform to mel-spec 85 | transform = MelSpectrogram(sample_rate) 86 | mel_spec = transform(waveform) 87 | 88 | # Plot 89 | plt.style.use('default') 90 | fig, ax = plt.subplots(figsize=(12, 3)) 91 | im = ax.imshow(mel_spec) 92 | plt.colorbar(im, ax=ax) 93 | plt.tight_layout() 94 | 95 | # Turn into numpy format to upload to tensorboard 96 | fig.canvas.draw() 97 | data = save_figure_to_numpy(fig) 98 | plt.close() 99 | return data 100 | 101 | def plot_env(waveform): 102 | # Plot 103 | plt.style.use('default') 104 | fig, ax = plt.subplots(figsize=(12, 3)) 105 | plt.plot(waveform) 106 | plt.tight_layout() 107 | 108 | # Turn into numpy format to upload to tensorboard 109 | fig.canvas.draw() 110 | data = save_figure_to_numpy(fig) 111 | plt.close() 112 | return data -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.integrate import solve_ivp 3 | 4 | 5 | class SDESampling: 6 | def __init__(self, model, sde): 7 | self.model = model 8 | self.sde = sde 9 | 10 | def create_schedules(self, nb_steps): 11 | t_schedule = torch.arange(0, nb_steps + 1) / nb_steps 12 | t_schedule = (self.sde.t_max - self.sde.t_min) * \ 13 | t_schedule + self.sde.t_min 14 | sigma_schedule = self.sde.sigma(t_schedule) 15 | m_schedule = self.sde.mean(t_schedule) 16 | 17 | return sigma_schedule, m_schedule 18 | 19 | def conditional_inference(self, audio, sigma, classes, events, cond_scale=1): 20 | cond_drop_prob = [0.0, 1.0] if events == None else [0.0, 0.0] 21 | cond_score = self.model(audio, sigma, classes, events, cond_drop_prob=cond_drop_prob) 22 | if cond_scale != 1: 23 | uncond_score = self.model(audio, sigma, classes, events, cond_drop_prob=[1.0, 1.0]) 24 | cond_score = uncond_score + (cond_score - uncond_score) * cond_scale 25 | 26 | return cond_score 27 | 28 | def predict( 29 | self, 30 | audio, 31 | nb_steps, 32 | classes, 33 | amplitude, 34 | cond_scale = 3. 
35 | ): 36 | 37 | with torch.no_grad(): 38 | 39 | sigma, m = self.create_schedules(nb_steps) 40 | 41 | for n in range(nb_steps - 1, 0, -1): 42 | # begins at t = 1 (n = nb_steps - 1) 43 | # stops at t = 2/nb_steps (n=1) 44 | 45 | cond_score = self.conditional_inference(audio, sigma[n], classes, amplitude, cond_scale) 46 | audio = m[n-1] / m[n] * audio + (m[n] / m[n-1] * (sigma[n-1])**2 / sigma[n] - m[n-1] / m[n] * sigma[n]) * cond_score 47 | 48 | if n > 0: # everytime 49 | noise = torch.randn_like(audio) 50 | audio += sigma[n-1]*(1 - (sigma[n-1]*m[n] / 51 | (sigma[n]*m[n-1]))**2)**0.5 * noise 52 | 53 | # The noise level is now sigma(1/nb_steps) = sigma[0] 54 | # Jump step 55 | cond_score = self.conditional_inference(audio, sigma[0], classes, amplitude, cond_scale) 56 | audio = (audio - sigma[0] * cond_score) / m[0] 57 | 58 | return audio 59 | 60 | class SDESampling_batch: 61 | def __init__(self, model, sde, batch_size, device): 62 | self.model = model 63 | self.sde = sde 64 | self.batch_size = batch_size 65 | self.device = device 66 | 67 | def create_schedules(self, nb_steps, batch): 68 | t_schedule = torch.arange(0, nb_steps + 1) / nb_steps 69 | t_schedule = (self.sde.t_max - self.sde.t_min) * \ 70 | t_schedule + self.sde.t_min 71 | t_schedule = t_schedule.expand(batch, -1) 72 | sigma_schedule = self.sde.sigma(t_schedule) 73 | m_schedule = self.sde.mean(t_schedule) 74 | 75 | return sigma_schedule, m_schedule 76 | 77 | def conditional_inference(self, audio, sigma, classes, events, cond_scale=1): 78 | cond_drop_prob = [0.0, 1.0] if events == None else [0.0, 0.0] 79 | cond_score = self.model(audio, sigma, classes, events, cond_drop_prob=cond_drop_prob) 80 | if cond_scale != 1: 81 | uncond_score = self.model(audio, sigma, classes, events, cond_drop_prob=[1.0, 1.0]) 82 | cond_score = uncond_score + (cond_score - uncond_score) * cond_scale 83 | 84 | return cond_score 85 | 86 | def predict( 87 | self, 88 | audio, 89 | nb_steps, 90 | classes, 91 | amplitude, 92 | cond_scale = 3. 
93 | ): 94 | 95 | with torch.no_grad(): 96 | 97 | sigma, m = self.create_schedules(nb_steps, self.batch_size) 98 | sigma = sigma.permute((1,0)).unsqueeze(2).to(self.device) 99 | m = m.permute((1,0)).unsqueeze(2).to(self.device) 100 | 101 | for n in range(nb_steps - 1, 0, -1): 102 | # begins at t = 1 (n = nb_steps - 1) 103 | # stops at t = 2/nb_steps (n=1) 104 | 105 | cond_score = self.conditional_inference(audio, sigma[n], classes, amplitude, cond_scale) 106 | audio = m[n-1] / m[n] * audio + (m[n] / m[n-1] * (sigma[n-1])**2 / sigma[n] - m[n-1] / m[n] * sigma[n]) * cond_score 107 | 108 | if n > 0: # everytime 109 | noise = torch.randn_like(audio) 110 | audio += sigma[n-1]*(1 - (sigma[n-1]*m[n] / 111 | (sigma[n]*m[n-1]))**2)**0.5 * noise 112 | 113 | # The noise level is now sigma(1/nb_steps) = sigma[0] 114 | # Jump step 115 | cond_score = self.conditional_inference(audio, sigma[0], classes, amplitude, cond_scale) 116 | audio = (audio - sigma[0] * cond_score) / m[0] 117 | 118 | return audio 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # T-FOLEY: A Controllable Waveform-Domain Diffusion Model for Temporal-Event-Guided Foley Sound Synthesis 2 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2401.09294) [![githubio](https://img.shields.io/badge/GitHub.io-Demo_Page-blue?logo=Github&style=flat-square)](https://yoonjinxd.github.io/Event-guided_FSS_Demo.github.io/) *Yoonjin Chung\*, Junwon Lee\*, Juhan Nam* 3 | 4 |

5 | 6 |

7 | 8 | This repository contains the implementation of the paper *[T-FOLEY: A Controllable Waveform-Domain Diffusion Model for Temporal-Event-Guided Foley Sound Synthesis](https://arxiv.org/pdf/2401.09294.pdf)*, accepted at ICASSP 2024. 9 | 10 | In our paper, we propose ***T-Foley***, a ***T***emporal-event-guided waveform generation model for ***Foley*** sound synthesis, which generates high-quality audio conditioned on both the sound class and the temporal placement of sound events. 11 | 12 |
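The temporal event condition used throughout this repository is, in the default configuration (`event_type: 'rms'` in `params.py`), a frame-level RMS envelope of a reference sound (see `get_rms` in `utils.py`). A minimal sketch of what that feature looks like, assuming a 22.05 kHz clip at a placeholder path:

```python
import librosa

# Load a 4-second, 22.05 kHz mono clip (placeholder path) and compute its RMS envelope
# with the same settings as utils.get_rms (frame_length=512, hop_length=128).
y, sr = librosa.load("some_target_sound.wav", sr=22050, duration=4.0)
rms = librosa.feature.rms(y=y, frame_length=512, hop_length=128)[0]
print(rms.shape)  # ~690 frames for 88200 samples, matching params['event_dims']['rms'] = 690
```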

13 | 14 |

15 | 16 | ## Setup 17 | 18 | To get started, please prepare the code and Python environment. 19 | 20 | 1. Clone this repository: 21 | ```bash 22 | $ git clone https://github.com/YoonjinXD/T-foley.git 23 | $ cd ./T-foley 24 | ``` 25 | 26 | 2. Install the required dependencies by running the following commands: 27 | ```bash 28 | # (Optional) Create a conda virtual environment 29 | $ conda create -n tfoley python=3.8.0 30 | $ conda activate tfoley 31 | # Install dependencies with pip. Choose the appropriate CUDA version 32 | $ pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 33 | $ pip install -r requirements.txt 34 | ``` 35 | 36 | 37 | ## Dataset 38 | 39 | To train and evaluate our model, we used the [DCASE 2023 Challenge Task 7](https://zenodo.org/records/8091972) dataset, which was constructed for Foley sound synthesis. 40 | To evaluate our model, we also used [subsets](https://yoonjinxd.github.io/Event-guided_FSS_Demo.github.io/#-vocal-imitating-dataset) of [VocalImitationSet](https://zenodo.org/records/1340763) and [VocalSketch](https://zenodo.org/records/1251982). These vocal imitation sets consist of vocal recordings that mimic event-based or environmental sounds. 41 | Click the links above to download the corresponding datasets. 42 | 43 | 44 | ## Inference 45 | 46 | To perform inference with our model, follow these steps: 47 | 48 | 1. Download the pre-trained model weights and configuration from the following link: [pretrained.zip](https://zenodo.org/records/10826692). 49 | ```bash 50 | $ wget https://zenodo.org/records/10826692/files/pretrained.zip 51 | ``` 52 | 53 | 2. Unzip and place the downloaded model weights and config JSON file in the `./pretrained` directory. 54 | ```bash 55 | $ unzip pretrained.zip 56 | ``` 57 | 58 | 3. Run the inference script by executing the following command: 59 | ```bash 60 | $ python inference.py --class_name "DogBark" 61 | ``` 62 | 63 | The class_name **must be** one of the class names of the [2023 DCASE Task 7 dataset](https://dcase.community/challenge2023/task-foley-sound-synthesis): `"DogBark", "Footstep", "GunShot", "Keyboard", "MovingMotorVehicle", "Rain", "Sneeze_Cough"` 64 | 65 | 4. The generated samples will be saved in the `./results` directory. 66 | 5. For FAD evaluation, we used this toolkit: [FAD toolkit](https://github.com/jnwnlee/fadtk) 67 | 68 | 69 | ## Training 70 | 71 | To train the T-Foley model, follow these steps: 72 | 73 | 1. Download and unzip the [DCASE 2023 Task 7 dataset](https://zenodo.org/records/8091972). Because of a mismatch between the provided csv and the actual data files, please create valid filelists (.txt) using the provided scripts: 74 | ```bash 75 | $ wget http://zenodo.org/records/8091972/files/DCASE_2023_Challenge_Task_7_Dataset.tar.gz 76 | $ tar -zxvf DCASE_2023_Challenge_Task_7_Dataset.tar.gz 77 | $ sh rename_dirs.sh 78 | $ sh make_filelist.sh 79 | ``` 80 | 81 | If you use another dataset, prepare a file path list of your training data in .txt format and set its path in `params.py` (see the sketch below). 82 | 83 | 84 | 2. Run the training: 85 | ```bash 86 | $ python train.py 87 | ``` 88 | 89 | This will start the training process and save the trained model weights in the `logs/` directory.
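If you train on a custom dataset (as mentioned in step 1 above), the following sketch shows one way to build the filelist; the paths are placeholders, and the class folder names must match the `LABELS` list in `learner.py`, because `dataset.py` reads each clip's class from its parent directory name:

```python
import glob

# Collect all .wav files under a hypothetical ./my_dataset/<ClassName>/ layout.
wav_paths = sorted(glob.glob("./my_dataset/**/*.wav", recursive=True))
with open("./my_dataset/train.txt", "w") as f:
    f.write("\n".join(wav_paths) + "\n")

# Then point params.py at the new filelists, e.g.:
#   'train_dirs': ['./my_dataset/train.txt'],
#   'test_dirs':  ['./my_dataset/eval.txt'],
```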
90 | 91 | To see the training on tensorboard, run: 92 | ```bash 93 | $ tensorboard --logdir logs/ 94 | ``` 95 | 96 | 97 | ## Citation 98 | ```bibtex 99 | @inproceedings{t-foley, 100 | title={T-FOLEY: A Controllable Waveform-Domain Diffusion Model for Temporal-Event-Guided Foley Sound Synthesis}, 101 | author={Chung, Yoonjin and Lee, Junwon and Nam, Juhan}, 102 | booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 103 | year={2024}, 104 | organization={IEEE} 105 | } 106 | ``` 107 | 108 | 109 | ## License 110 | 111 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more information. -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torchaudio as T 8 | import pydub 9 | import soundfile as sf 10 | from model import UNet 11 | from sampler import SDESampling_batch 12 | from scipy.io.wavfile import write 13 | from sde import VpSdeCos 14 | from utils import (adjust_audio_length, get_event_cond, high_pass_filter, 15 | normalize, pooling, resample_audio) 16 | 17 | LABELS = ['DogBark', 'Footstep', 'GunShot', 'Keyboard', 'MovingMotorVehicle', 'Rain', 'Sneeze_Cough'] 18 | 19 | 20 | def load_ema_weights(model, model_path): 21 | checkpoint = torch.load(model_path) 22 | dic_ema = {} 23 | for (key, tensor) in zip(checkpoint['model'].keys(), checkpoint['ema_weights']): 24 | dic_ema[key] = tensor 25 | model.load_state_dict(dic_ema) 26 | return model 27 | 28 | def generate_samples(target_events, class_idx, sampler, cond_scale, device, N, audio_length): 29 | print(f"Generate {N} samples of class \'{LABELS[class_idx]}\'...") 30 | noise = torch.randn(N, audio_length, device=device) 31 | classes = torch.tensor([class_idx]*N, device=device) 32 | sampler.batch_size = N 33 | samples = sampler.predict(noise, 100, classes, target_events, cond_scale=cond_scale) 34 | return samples 35 | 36 | def save_samples(samples, output_dir, sr, class_name, stereo=False, target_audio=None): 37 | for j in range(samples.shape[0]): 38 | sample = samples[j].cpu() 39 | sample = high_pass_filter(sample) 40 | write(f"{output_dir}/{class_name}_{str(j+1).zfill(3)}.wav", sr, sample) 41 | 42 | if stereo: 43 | assert target_audio is not None, "Target audio is required for stereo output." 44 | left_audio = target_audio.cpu().numpy() 45 | right_audio = sample.copy() 46 | assert len(left_audio) == len(right_audio), "Length of target and generated audio must be the same." 47 | 48 | sf.write('temp_left.wav', left_audio, 22050, 'PCM_24') 49 | sf.write('temp_right.wav', right_audio, 22050, 'PCM_24') 50 | 51 | left_audio = pydub.AudioSegment.from_wav('temp_left.wav') 52 | right_audio = pydub.AudioSegment.from_wav('temp_right.wav') 53 | 54 | if left_audio.sample_width > 4: 55 | left_audio = left_audio.set_sample_width(4) 56 | if right_audio.sample_width > 4: 57 | right_audio = right_audio.set_sample_width(4) 58 | 59 | # pan the sound 60 | left_audio_panned = left_audio.pan(-1.) 61 | right_audio_panned = right_audio.pan(+1.) 
62 | 63 | mixed = left_audio_panned.overlay(right_audio_panned) 64 | mixed.export(f"{output_dir}/{class_name}_{str(j+1).zfill(3)}_stereo.wav", format='wav') 65 | 66 | # remove temp files 67 | os.remove('temp_left.wav') 68 | os.remove('temp_right.wav') 69 | 70 | def measure_el1_distance(sample, target, event_type): 71 | sample = normalize(sample).cpu() 72 | target = normalize(target).cpu() 73 | 74 | sample_event = get_event_cond(sample, event_type) 75 | target_event = get_event_cond(target, event_type) 76 | 77 | # sample_event = pooling(sample_event, block_num=49) 78 | # target_event = pooling(target_event, block_num=49) 79 | 80 | loss_fn = torch.nn.L1Loss() 81 | loss = loss_fn(sample_event, target_event) 82 | return loss.cpu().item() 83 | 84 | 85 | def main(args): 86 | os.makedirs(args.output_dir, exist_ok=True) 87 | 88 | # Set model and sampler 89 | T.set_audio_backend('sox_io') 90 | device=torch.device('cuda') 91 | 92 | with open(args.param_path) as f: 93 | params = json.load(f) 94 | sample_rate = params['sample_rate'] 95 | audio_length = sample_rate * 4 96 | model = UNet(len(LABELS), params).to(device) 97 | model = load_ema_weights(model, args.model_path) 98 | 99 | sde = VpSdeCos() 100 | sampler = SDESampling_batch(model, sde, batch_size=args.N, device=device) 101 | 102 | # Prepare target audio if exist 103 | if args.target_audio_path is not None: 104 | target_audio, sr = T.load(args.target_audio_path) 105 | if sr != sample_rate: 106 | target_audio = resample_audio(target_audio, sr, sample_rate) 107 | target_audio = adjust_audio_length(target_audio, audio_length) 108 | target_event = get_event_cond(target_audio, params['event_type']) 109 | target_event = target_event.repeat(args.N, 1).to(device) 110 | else: 111 | target_audio = None 112 | target_event = None 113 | 114 | # Generate N samples 115 | class_idx = LABELS.index(args.class_name) 116 | generated = generate_samples(target_event, class_idx, sampler, args.cond_scale, device, args.N, audio_length) 117 | save_samples(generated, args.output_dir, sample_rate, args.class_name, args.stereo, target_audio) 118 | print('Done!') 119 | 120 | # Measure E-L1 distance if target audio is given 121 | if args.target_audio_path is not None: 122 | dists = [] 123 | for sample in generated: 124 | dist = measure_el1_distance(sample, target_audio, params['event_type']) 125 | dists.append(dist) 126 | print(f"E-L1 distance: {np.mean(dists)}") 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('--model_path', type=str, default='./pretrained/block-49_epoch-500.pt') 132 | parser.add_argument('--param_path', type=str, default='./pretrained/params.json') 133 | parser.add_argument('--target_audio_path', type=str, help='Path to the target audio file.', default=None) 134 | parser.add_argument('--class_name', type=str, required=True, help='Class name to generate samples.', 135 | choices=LABELS) 136 | parser.add_argument('--output_dir', type=str, default="./results") 137 | parser.add_argument('--cond_scale', type=int, default=3) 138 | parser.add_argument('--N', type=int, default=3) 139 | parser.add_argument('--stereo', action='store_true', help='Output stereo audio (left: target / right: generated).', 140 | default=False) 141 | args = parser.parse_args() 142 | 143 | main(args) -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from 
glob import glob 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.parallel import DistributedDataParallel 10 | from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm 12 | 13 | from dataset import from_path as dataset_from_path 14 | from model import UNet 15 | from sampler import SDESampling 16 | from sde import SubVpSdeCos 17 | from utils import get_event_cond, high_pass_filter, normalize, plot_env 18 | 19 | LABELS = ['DogBark', 'Footstep', 'GunShot', 'Keyboard', 'MovingMotorVehicle', 'Rain', 'Sneeze_Cough'] 20 | 21 | def _nested_map(struct, map_fn): 22 | if isinstance(struct, tuple): 23 | return tuple(_nested_map(x, map_fn) for x in struct) 24 | if isinstance(struct, list): 25 | return [_nested_map(x, map_fn) for x in struct] 26 | if isinstance(struct, dict): 27 | return {k: _nested_map(v, map_fn) for k, v in struct.items()} 28 | return map_fn(struct) 29 | 30 | # --- Learner --- 31 | class Learner: 32 | def __init__( 33 | self, model_dir, model, train_set, test_set, params, distributed 34 | ): 35 | os.makedirs(model_dir, exist_ok=True) 36 | self.model_dir = model_dir 37 | self.model = model 38 | self.ema_weights = [param.clone().detach() 39 | for param in self.model.parameters()] 40 | self.lr = params['lr'] 41 | self.epoch = 0 42 | self.step = 0 43 | self.is_master = True 44 | self.distributed = distributed 45 | self.restore_from_checkpoint(params['checkpoint_id']) 46 | 47 | self.sde = SubVpSdeCos() 48 | self.ema_rate = params['ema_rate'] 49 | self.train_set = train_set 50 | self.test_set = test_set 51 | self.params = params 52 | 53 | self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr) 54 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 55 | self.optimizer, factor=params['scheduler_factor'], 56 | patience=params['scheduler_patience_epoch']*len(self.train_set)//params['num_steps_to_test'], 57 | threshold=params['scheduler_threshold'] 58 | ) 59 | self.params['total_params_num'] = sum(p.numel() for p in model.parameters() if p.requires_grad) 60 | 61 | self.loss_fn = nn.MSELoss() 62 | self.v_loss = nn.MSELoss(reduction="none") 63 | self.summary_writer = None 64 | self.n_bins = params['n_bins'] 65 | self.num_elems_in_bins_train = np.zeros(self.n_bins) 66 | self.sum_loss_in_bins_train = np.zeros(self.n_bins) 67 | self.num_elems_in_bins_test = np.zeros(self.n_bins) 68 | self.sum_loss_in_bins_test = np.zeros(self.n_bins) 69 | self.cum_grad_norms = 0 70 | 71 | 72 | # Train 73 | def train(self): 74 | device = next(self.model.parameters()).device 75 | while True: 76 | if self.distributed: self.train_set.sampler.set_epoch(self.epoch) 77 | for features in ( 78 | tqdm(self.train_set, 79 | desc=f"Epoch {self.epoch}") 80 | if self.is_master 81 | else self.train_set 82 | ): 83 | self.model.train() 84 | features = _nested_map( 85 | features, 86 | lambda x: x.to(device) if isinstance( 87 | x, torch.Tensor) else x, 88 | ) 89 | loss = self.train_step(features) 90 | if torch.isnan(loss).any(): 91 | raise RuntimeError( 92 | f"Detected NaN loss at step {self.step}.") 93 | 94 | # Logging by steps 95 | if self.is_master: 96 | if self.step % 250 == 249: 97 | self._write_summary(self.step) 98 | 99 | if self.step % self.params['num_steps_to_test'] == 0: 100 | self.test_set_evaluation() 101 | self.scheduler.step(sum(self.sum_loss_in_bins_test)/sum(self.num_elems_in_bins_test)) 102 | self._write_test_summary(self.step) 103 | self.step += 1 104 | 105 | # Logging by epochs 106 | if self.is_master: 107 | if self.epoch % 
self.params['num_epochs_to_save'] == 0: 108 | self._write_inference_summary(self.step, device) 109 | self.save_to_checkpoint(filename=f'epoch-{self.epoch}') 110 | self.epoch += 1 111 | 112 | def train_step(self, features): 113 | for param in self.model.parameters(): 114 | param.grad = None 115 | 116 | audio = features["audio"] 117 | classes = features["class"] 118 | events = features["event"] 119 | 120 | N, T = audio.shape 121 | 122 | t = torch.rand(N, 1, device=audio.device) 123 | t = (self.sde.t_max - self.sde.t_min) * t + self.sde.t_min 124 | noise = torch.randn_like(audio) 125 | noisy_audio = self.sde.perturb(audio, t, noise) 126 | sigma = self.sde.sigma(t) 127 | predicted = self.model(noisy_audio, sigma, classes, events) 128 | loss = self.loss_fn(noise, predicted) 129 | 130 | loss.backward() 131 | self.grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) 132 | self.optimizer.step() 133 | if self.is_master: 134 | self.update_ema_weights() 135 | 136 | t_detach = t.clone().detach().cpu().numpy() 137 | t_detach = np.reshape(t_detach, -1) 138 | 139 | vectorial_loss = self.v_loss(noise, predicted).detach() 140 | vectorial_loss = torch.mean(vectorial_loss, 1).cpu().numpy() 141 | vectorial_loss = np.reshape(vectorial_loss, -1) 142 | 143 | self.update_conditioned_loss(vectorial_loss, t_detach, True) 144 | self.cum_grad_norms += self.grad_norm 145 | return loss 146 | 147 | 148 | # Test 149 | def test_set_evaluation(self): 150 | with torch.no_grad(): 151 | self.model.eval() 152 | for features in self.test_set: 153 | audio = features["audio"].cuda() 154 | classes = features["class"].cuda() 155 | events = features["event"].cuda() 156 | 157 | N, T = audio.shape 158 | 159 | t = torch.rand(N, 1, device=audio.device) 160 | t = (self.sde.t_max - self.sde.t_min) * t + self.sde.t_min 161 | noise = torch.randn_like(audio) 162 | noisy_audio = self.sde.perturb(audio, t, noise) 163 | sigma = self.sde.sigma(t) 164 | predicted = self.model(noisy_audio, sigma, classes, events) 165 | 166 | vectorial_loss = self.v_loss(noise, predicted).detach() 167 | 168 | vectorial_loss = torch.mean(vectorial_loss, 1).cpu().numpy() 169 | vectorial_loss = np.reshape(vectorial_loss, -1) 170 | t = t.cpu().numpy() 171 | t = np.reshape(t, -1) 172 | self.update_conditioned_loss( 173 | vectorial_loss, t, False) 174 | 175 | 176 | # Update loss & ema weights 177 | def update_conditioned_loss(self, vectorial_loss, continuous_array, isTrain): 178 | continuous_array = np.trunc(self.n_bins * continuous_array) 179 | continuous_array = continuous_array.astype(int) 180 | if isTrain: 181 | for k in range(len(continuous_array)): 182 | self.num_elems_in_bins_train[continuous_array[k]] += 1 183 | self.sum_loss_in_bins_train[continuous_array[k] 184 | ] += vectorial_loss[k] 185 | else: 186 | for k in range(len(continuous_array)): 187 | self.num_elems_in_bins_test[continuous_array[k]] += 1 188 | self.sum_loss_in_bins_test[continuous_array[k] 189 | ] += vectorial_loss[k] 190 | 191 | def update_ema_weights(self): 192 | for ema_param, param in zip(self.ema_weights, self.model.parameters()): 193 | if param.requires_grad: 194 | ema_param -= (1 - self.ema_rate) * (ema_param - param.detach()) 195 | 196 | # Logging stuff 197 | def _write_summary(self, step): 198 | loss_in_bins_train = np.divide( 199 | self.sum_loss_in_bins_train, self.num_elems_in_bins_train 200 | ) 201 | dic_loss_train = {} 202 | for k in range(self.n_bins): 203 | dic_loss_train["loss_bin_" + str(k)] = loss_in_bins_train[k] 204 | 205 | sum_loss_n_steps = 
np.sum(self.sum_loss_in_bins_train) 206 | mean_grad_norms = self.cum_grad_norms / self.num_elems_in_bins_train.sum() * \ 207 | self.params['batch_size'] 208 | writer = self.summary_writer or SummaryWriter( 209 | self.model_dir, purge_step=step) 210 | 211 | writer.add_scalar('train/sum_loss_on_n_steps', 212 | sum_loss_n_steps, step) 213 | writer.add_scalar("train/mean_grad_norm", mean_grad_norms, step) 214 | writer.add_scalars("train/conditioned_loss", dic_loss_train, step) 215 | writer.add_scalar("train/learning_rate", self.optimizer.param_groups[0]['lr'], step) 216 | writer.flush() 217 | self.summary_writer = writer 218 | self.num_elems_in_bins_train = np.zeros(self.n_bins) 219 | self.sum_loss_in_bins_train = np.zeros(self.n_bins) 220 | self.cum_grad_norms = 0 221 | 222 | def _write_test_summary(self, step): 223 | loss_in_bins_test = np.divide( 224 | self.sum_loss_in_bins_test, self.num_elems_in_bins_test 225 | ) 226 | dic_loss_test = {} 227 | for k in range(self.n_bins): 228 | dic_loss_test["loss_bin_" + str(k)] = loss_in_bins_test[k] 229 | 230 | writer = self.summary_writer or SummaryWriter( 231 | self.model_dir, purge_step=step) 232 | writer.add_scalars("test/conditioned_loss", dic_loss_test, step) 233 | writer.flush() 234 | self.summary_writer = writer 235 | self.num_elems_in_bins_test = np.zeros(self.n_bins) 236 | self.sum_loss_in_bins_test = np.zeros(self.n_bins) 237 | 238 | def _write_inference_summary(self, step, device, cond_scale=3.): 239 | sde = SubVpSdeCos() 240 | sampler = SDESampling(self.model, sde) 241 | 242 | test_feature = self.get_random_test_feature() 243 | test_event = test_feature["event"].unsqueeze(0).to(device) 244 | 245 | event_loss = [] 246 | writer = self.summary_writer or SummaryWriter(self.model_dir, purge_step=step) 247 | writer.add_audio(f"test_sample/audio", test_feature["audio"], step, sample_rate=22050) 248 | writer.add_image(f"test_sample/envelope", plot_env(test_feature["audio"]), step, dataformats='HWC') 249 | 250 | for class_idx in range(len(LABELS)): 251 | noise = torch.randn(1, self.params['audio_length'], device=device) 252 | classes = torch.tensor([class_idx], device=device) 253 | 254 | sample = sampler.predict(noise, 100, classes, test_event, cond_scale=cond_scale) 255 | sample = sample.flatten().cpu() 256 | 257 | sample = normalize(sample) 258 | sample = high_pass_filter(sample, sr=22050) 259 | 260 | event_loss.append(self.loss_fn(test_event.squeeze(0).cpu(), get_event_cond(sample, self.params['event_type']))) 261 | writer.add_audio(f"{LABELS[class_idx]}/audio", sample, step, sample_rate=22050) 262 | writer.add_image(f"{LABELS[class_idx]}/envelope", plot_env(sample), step, dataformats='HWC') 263 | 264 | event_loss = sum(event_loss) / len(event_loss) 265 | writer.add_scalar(f"test/event_loss", event_loss, step) 266 | writer.flush() 267 | 268 | # Utils 269 | def get_random_test_feature(self): 270 | return self.test_set.dataset[random.choice(range(len(self.test_set.dataset)))] 271 | 272 | def log_params(self): 273 | with open(os.path.join(self.model_dir, 'params.json'), 'w') as fp: 274 | json.dump(self.params, fp, indent=4) 275 | fp.close() 276 | 277 | def state_dict(self): 278 | if hasattr(self.model, "module") and isinstance(self.model.module, nn.Module): 279 | model_state = self.model.module.state_dict() 280 | else: 281 | model_state = self.model.state_dict() 282 | return { 283 | "epoch": self.epoch, 284 | "step": self.step, 285 | "model": { 286 | k: v.cpu() if isinstance(v, torch.Tensor) else v 287 | for k, v in model_state.items() 288 | }, 289 
| "ema_weights": self.ema_weights, 290 | "lr": self.lr, 291 | } 292 | 293 | def load_state_dict(self, state_dict): 294 | if hasattr(self.model, "module") and isinstance(self.model.module, nn.Module): 295 | self.model.module.load_state_dict(state_dict["model"]) 296 | else: 297 | self.model.load_state_dict(state_dict["model"]) 298 | self.epoch = state_dict["epoch"] 299 | self.step = state_dict["step"] 300 | self.ema_weights = state_dict["ema_weights"] 301 | self.lr = state_dict["lr"] 302 | 303 | def restore_from_checkpoint(self, checkpoint_id=None): 304 | try: 305 | if checkpoint_id is None: 306 | # find latest checkpoint_id 307 | list_weights = glob(f'{self.model_dir}/epoch-*.pt') 308 | list_ids = [int(os.path.basename(weight_path).split('-')[-1].rstrip('.pt')) for weight_path in list_weights] 309 | checkpoint_id = list_ids.index(max(list_ids)) 310 | 311 | checkpoint = torch.load(list_weights[checkpoint_id]) 312 | self.load_state_dict(checkpoint) 313 | return True 314 | except (FileNotFoundError, ValueError): 315 | return False 316 | 317 | def save_to_checkpoint(self, filename="weights"): 318 | save_basename = f"{filename}_step-{self.step}.pt" 319 | save_name = f"{self.model_dir}/{save_basename}" 320 | torch.save(self.state_dict(), save_name) 321 | 322 | 323 | # --- Training functions --- 324 | def _train_impl(replica_id, model, train_set, test_set, params, distributed=False): 325 | torch.backends.cudnn.benchmark = True 326 | learner = Learner( 327 | params['model_dir'], model, train_set, test_set, params, distributed=distributed 328 | ) 329 | learner.is_master = replica_id == 0 330 | learner.log_params() 331 | learner.train() 332 | 333 | 334 | def train(params): 335 | model = UNet(num_classes=len(LABELS), params=params).cuda() 336 | train_set = dataset_from_path(params['train_dirs'], params, LABELS) 337 | test_set = dataset_from_path(params['test_dirs'], params, LABELS) 338 | 339 | _train_impl(0, model, train_set, test_set, params) 340 | 341 | 342 | def train_distributed(replica_id, replica_count, port, params): 343 | print(f"Replica {replica_id} of {replica_count} started") 344 | os.environ["MASTER_ADDR"] = "localhost" 345 | os.environ["MASTER_PORT"] = str(port) 346 | torch.distributed.init_process_group( 347 | "nccl", rank=replica_id, world_size=replica_count 348 | ) 349 | device = torch.device("cuda", replica_id) 350 | torch.cuda.set_device(device) 351 | 352 | model = UNet(num_classes=len(LABELS), params=params).cuda() 353 | train_set = dataset_from_path(params['train_dirs'], params, LABELS, distributed=True) 354 | test_set = dataset_from_path(params['test_dirs'], params, LABELS) 355 | model = DistributedDataParallel(model, device_ids=[replica_id], find_unused_parameters=True) 356 | 357 | _train_impl(replica_id, model, train_set, test_set, params, distributed=True) 358 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from torch import einsum, nn 6 | 7 | 8 | # --- Helper Utils --- 9 | def exists(x): 10 | return x is not None 11 | 12 | def default(val, d): 13 | if exists(val): 14 | return val 15 | return d() if callable(d) else d 16 | 17 | class Residual(nn.Module): 18 | def __init__(self, fn): 19 | super().__init__() 20 | self.fn = fn 21 | 22 | def forward(self, x, *args, **kwargs): 23 | return self.fn(x, *args, **kwargs) + x 24 | 25 | class 
LayerNorm(nn.Module): 26 | def __init__(self, dim): 27 | super().__init__() 28 | self.g = nn.Parameter(torch.ones(1, dim, 1)) 29 | 30 | def forward(self, x): 31 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 32 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 33 | mean = torch.mean(x, dim = 1, keepdim = True) 34 | return (x - mean) * (var + eps).rsqrt() * self.g 35 | 36 | class PreNorm(nn.Module): 37 | def __init__(self, dim, fn): 38 | super().__init__() 39 | self.fn = fn 40 | self.norm = LayerNorm(dim) 41 | 42 | def forward(self, x): 43 | x = self.norm(x) 44 | return self.fn(x) 45 | 46 | class Attention(nn.Module): 47 | def __init__(self, dim, heads = 4, dim_head = 64): 48 | super().__init__() 49 | self.scale = dim_head ** -0.5 50 | self.heads = heads 51 | hidden_dim = dim_head * heads 52 | 53 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False) 54 | self.to_out = nn.Conv1d(hidden_dim, dim, 1) 55 | 56 | def forward(self, x): 57 | b, c, l = x.shape 58 | qkv = self.to_qkv(x).chunk(3, dim = 1) # (B, 256, L) tuple 3개 59 | q, k, v = map(lambda t: rearrange(t, 'b (h c) l -> b h c l', h = self.heads), qkv) # (B, 4, 64, L) 60 | 61 | q = q * self.scale 62 | 63 | sim = einsum('b h d i, b h d j -> b h i j', q, k) # (B, 4, L, L) 64 | attn = sim.softmax(dim = -1) 65 | out = einsum('b h i j, b h d j -> b h i d', attn, v) 66 | 67 | out = rearrange(out, 'b h l d -> b (h d) l', l=l) 68 | return self.to_out(out) 69 | 70 | 71 | # --- classifier free guidance functions --- 72 | def uniform(shape, device): 73 | return torch.zeros(shape, device = device).float().uniform_(0, 1) 74 | 75 | def prob_mask_like(shape, prob, device): 76 | if prob == 1: 77 | return torch.ones(shape, device = device, dtype = torch.bool) 78 | elif prob == 0: 79 | return torch.zeros(shape, device = device, dtype = torch.bool) 80 | else: 81 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob 82 | 83 | 84 | # --- Random Fourier Feature MLP --- 85 | class RFF_MLP_Block(nn.Module): 86 | def __init__(self, time_emb_dim=512): 87 | super().__init__() 88 | self.RFF_freq = nn.Parameter( 89 | 16 * torch.randn([1, 32]), requires_grad=False) 90 | self.MLP = nn.ModuleList([ 91 | nn.Linear(64, 128), 92 | nn.Linear(128, 256), 93 | nn.Linear(256, time_emb_dim), 94 | ]) 95 | 96 | def forward(self, sigma): 97 | """ 98 | Arguments: 99 | sigma: 100 | (shape: [B, 1], dtype: float32) 101 | 102 | Returns: 103 | x: embedding of sigma 104 | (shape: [B, 512], dtype: float32) 105 | """ 106 | x = self._build_RFF_embedding(sigma) 107 | for layer in self.MLP: 108 | x = F.relu(layer(x)) 109 | return x 110 | 111 | def _build_RFF_embedding(self, sigma): 112 | """ 113 | Arguments: 114 | sigma: 115 | (shape: [B, 1], dtype: float32) 116 | Returns: 117 | table: 118 | (shape: [B, 64], dtype: float32) 119 | """ 120 | freqs = self.RFF_freq 121 | freqs = freqs.to(device=torch.device("cuda")) 122 | table = 2 * np.pi * sigma * freqs 123 | table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) 124 | return table 125 | 126 | # --- Conditioning Modules --- 127 | class Film(nn.Module): 128 | def __init__(self, input_dim, output_dim): 129 | super().__init__() 130 | self.output_layer = nn.Linear(input_dim, 2 * output_dim) 131 | 132 | def forward(self, sigma_encoding): 133 | sigma_encoding = self.output_layer(sigma_encoding) 134 | sigma_encoding = sigma_encoding.unsqueeze(-1) 135 | gamma, beta = torch.chunk(sigma_encoding, 2, dim=1) 136 | return gamma, beta 137 | 138 | class Film_withConds(nn.Module): 139 | def __init__(self, 
output_dim, time_emb_dim=512, classes_emb_dim=512): 140 | super().__init__() 141 | self.layer = nn.Sequential( 142 | nn.SiLU(), 143 | nn.Linear(int(time_emb_dim) + int(classes_emb_dim), output_dim * 2) 144 | ) if exists(time_emb_dim) or exists(classes_emb_dim) else None 145 | 146 | def forward(self, time_emb = None, class_emb = None): 147 | cond_emb = tuple(filter(exists, (time_emb, class_emb))) 148 | cond_emb = torch.cat(cond_emb, dim = -1) 149 | cond_emb = self.layer(cond_emb) 150 | cond_emb = cond_emb.unsqueeze(-1) 151 | gamma, beta = cond_emb.chunk(2, dim = 1) 152 | return gamma, beta 153 | 154 | class TFilm(nn.Module): 155 | """ 156 | Arguments: 157 | block_num: range(1, 88200), dtype: int 158 | output_dim: dtype: int 159 | Returns: 160 | norm_list = [(gamma_b1, beta_b1), (gamma_b2, beta_b2), ..., (gamma_bn, beta_bn)], shape: (block_num, 2), dtype: list(tuples) 161 | """ 162 | def __init__(self, block_num, input_dim, output_dim): 163 | super().__init__() 164 | self.block_num = block_num 165 | self.num_layers = 8 # default = 2 166 | self.lstm = nn.LSTM(input_dim, output_dim, num_layers=self.num_layers, batch_first=True, bidirectional=False) 167 | self.output_dim = output_dim 168 | self.layer = nn.Linear(output_dim, output_dim*2) 169 | 170 | def forward(self, x): 171 | block_size = x.shape[-1] // self.block_num 172 | 173 | pooling = nn.MaxPool1d(block_size, stride=block_size).to(x.device) 174 | x = pooling(x.unsqueeze(1)).squeeze(1) 175 | h0 = torch.randn(self.num_layers, x.shape[0], self.output_dim, device=x.device) 176 | c0 = torch.randn(self.num_layers, x.shape[0], self.output_dim, device=x.device) 177 | x, _ = self.lstm(x.unsqueeze(-1), (h0, c0)) 178 | x = self.layer(x) 179 | x = x.permute(0, 2, 1) 180 | gamma, beta = x.chunk(2, dim=1) 181 | 182 | return gamma, beta 183 | 184 | class BFilm(nn.Module): 185 | """ 186 | Arguments: 187 | block_num: range(1, 88200), dtype: int 188 | output_dim: dtype: int 189 | Returns: 190 | norm_list = [(gamma_b1, beta_b1), (gamma_b2, beta_b2), ..., (gamma_bn, beta_bn)], shape: (block_num, 2), dtype: list(tuples) 191 | """ 192 | def __init__(self, block_num, input_dim, output_dim): 193 | super().__init__() 194 | self.block_num = block_num 195 | self.layer = nn.Linear(input_dim, output_dim*2) 196 | 197 | def forward(self, x): 198 | block_size = x.shape[-1] // self.block_num 199 | 200 | device = x.device 201 | pooling = nn.MaxPool1d(block_size, stride=block_size) 202 | x = x.unsqueeze(1) 203 | pooled_x = pooling(x).to(device) 204 | 205 | if pooled_x.shape[-1] != self.block_num: 206 | block_size += 1 207 | padding = (block_size*self.block_num - x.shape[-1]+1)//2 208 | x = F.interpolate(x, size=x.shape[-1]+2*padding, mode='nearest') 209 | 210 | pooling = nn.MaxPool1d(block_size, stride=block_size) 211 | pooled_x = pooling(x).to(device) 212 | 213 | pooled_x = pooled_x.squeeze(1) 214 | x = self.layer(pooled_x.unsqueeze(-1)) 215 | x = x.permute(0, 2, 1) 216 | gamma, beta = x.chunk(2, dim = 1) 217 | 218 | return gamma, beta 219 | 220 | # --- Down&Up-sampling blocks --- 221 | class Conv1d(nn.Conv1d): 222 | def __init__(self, *args, **kwargs): 223 | super().__init__(*args, **kwargs) 224 | self.reset_parameters() 225 | 226 | def reset_parameters(self): 227 | nn.init.orthogonal_(self.weight) 228 | nn.init.zeros_(self.bias) 229 | 230 | 231 | # --- Block Modules --- 232 | class FilmConvBlock(nn.Module): 233 | '''FiLM + Activation + Conv''' 234 | def __init__(self, in_channel, out_channel, factor=1): 235 | super().__init__() 236 | self.proj = Conv1d(in_channel, 
out_channel, 3, dilation=1, padding=1) 237 | # self.norm = nn.GroupNorm(2 if factor < 0 else 1, out_channel) 238 | self.norm = nn.GroupNorm(8, out_channel) 239 | self.act = nn.SiLU() 240 | 241 | self.factor = factor 242 | if factor < 0: 243 | self.conv = nn.ConvTranspose1d(out_channel, out_channel, 3, stride=abs(factor), padding=1, output_padding=abs(factor)-1) 244 | else: 245 | self.conv = Conv1d(out_channel, out_channel, 3, stride=factor, padding=1) 246 | 247 | def forward(self, x, gamma, beta): 248 | x = self.proj(x) 249 | x = gamma * x + beta 250 | 251 | x = self.norm(x) 252 | x = self.act(x) 253 | 254 | x = self.conv(x) 255 | return x 256 | 257 | class TFilmConvBlock(nn.Module): 258 | '''FiLM + Activation + Conv''' 259 | def __init__(self, in_channel, out_channel): 260 | super().__init__() 261 | self.proj = Conv1d(in_channel, out_channel, 3, padding=1) 262 | self.act = nn.SiLU() 263 | self.conv = Conv1d(out_channel, out_channel, 3, padding=1) 264 | 265 | def forward(self, x, gamma, beta): 266 | x = self.proj(x) 267 | 268 | chunks = list(x.chunk(gamma.shape[-1], dim=-1)) 269 | for i, chunk in enumerate(chunks): 270 | g = gamma[:,:,i].unsqueeze(-1) 271 | b = beta[:,:,i].unsqueeze(-1) 272 | chunks[i] = chunk * g + b 273 | 274 | x = torch.cat(chunks, dim=-1) 275 | x = self.act(x) 276 | 277 | x = self.conv(x) 278 | return x 279 | 280 | class BFilmConvBlock(nn.Module): 281 | '''FiLM + Activation + Conv''' 282 | def __init__(self, in_channel, out_channel): 283 | super().__init__() 284 | self.proj = Conv1d(in_channel, out_channel, 3, padding=1) 285 | self.act = nn.SiLU() 286 | self.conv = Conv1d(out_channel, out_channel, 3, padding=1) 287 | 288 | def forward(self, x, gamma, beta): 289 | x = self.proj(x) 290 | 291 | chunks = list(x.chunk(gamma.shape[-1], dim=-1)) 292 | for i, chunk in enumerate(chunks): 293 | g = gamma[:,:,i].unsqueeze(-1) 294 | b = beta[:,:,i].unsqueeze(-1) 295 | chunks[i] = chunk * g + b 296 | 297 | x = torch.cat(chunks, dim=-1) 298 | x = self.act(x) 299 | 300 | x = self.conv(x) 301 | return x 302 | 303 | class GBlock(nn.Module): 304 | def __init__(self, in_channel, out_channel, factor, block_num, film_type, event_dim): 305 | super().__init__() 306 | self.factor = factor 307 | if self.factor < 0: in_channel = in_channel * 2 308 | 309 | self.residual_dense1 = Conv1d(in_channel, out_channel, 1) 310 | self.factor_convs = nn.ModuleList([ 311 | FilmConvBlock(in_channel, in_channel, factor), 312 | FilmConvBlock(in_channel, out_channel), 313 | ]) 314 | self.factor_films = nn.ModuleList([ 315 | Film_withConds(in_channel), 316 | Film_withConds(out_channel), 317 | ]) 318 | 319 | self.residual_dense2 = Conv1d(out_channel, out_channel, 1) 320 | assert film_type in [None, 'film', 'temporal', 'block'] 321 | if film_type == None: 322 | self.t_convs = None 323 | self.t_films = None 324 | elif film_type == 'film': 325 | self.t_convs = nn.ModuleList([FilmConvBlock(out_channel, out_channel), FilmConvBlock(out_channel, out_channel)]) 326 | self.t_films = nn.ModuleList([Film(event_dim, out_channel), Film(event_dim, out_channel)]) 327 | elif film_type == 'temporal': 328 | self.t_convs = nn.ModuleList([TFilmConvBlock(out_channel, out_channel), TFilmConvBlock(out_channel, out_channel)]) 329 | self.t_films = nn.ModuleList([TFilm(block_num, 1, out_channel), TFilm(block_num, 1, out_channel)]) 330 | elif film_type == 'block': 331 | self.t_convs = nn.ModuleList([BFilmConvBlock(out_channel, out_channel), BFilmConvBlock(out_channel, out_channel)]) 332 | self.t_films = nn.ModuleList([BFilm(block_num, 1, 
out_channel), BFilm(block_num, 1, out_channel)]) 333 | 334 | def forward(self, x, sigma, c, a): 335 | size = self._output_size(x.shape[-1]) 336 | 337 | residual = F.interpolate(x, size=size) 338 | residual = self.residual_dense1(residual) 339 | for film, layer in zip(self.factor_films, self.factor_convs): 340 | gamma, beta = film(sigma, c) 341 | x = layer(x, gamma, beta) 342 | x = x + residual 343 | 344 | if self.t_films != None: 345 | residual = F.interpolate(x, size=size) 346 | residual = self.residual_dense2(residual) 347 | for t_film, layer in zip(self.t_films, self.t_convs): 348 | gamma, beta = t_film(a) 349 | x = layer(x, gamma, beta) 350 | x = x + residual 351 | return x 352 | 353 | def _output_size(self, input_size): 354 | return input_size * abs(self.factor) if self.factor < 0 else input_size // self.factor 355 | 356 | 357 | # --- U-Net --- 358 | class UNet(nn.Module): 359 | def __init__(self, num_classes, params): 360 | super().__init__() 361 | print("Model initializing... This can take a few minutes.") 362 | 363 | # Hyperparameter Settings 364 | sequential = params['sequential'] 365 | assert sequential in ['lstm', 'attn', None], "Choose sequential between \'lstm\' or \'attn\', None." 366 | 367 | dims = params['dims'] 368 | factors =params['factors'] 369 | assert len(dims)-1 == len(factors) 370 | 371 | block_nums = params['block_nums'] 372 | time_emb_dim = params['time_emb_dim'] 373 | class_emb_dim = params['class_emb_dim'] 374 | event_dim = params['event_dims'][params['event_type']] 375 | 376 | cond_drop_prob = params['cond_prob'] 377 | film_type = params['film_type'] 378 | 379 | # Pre-conv/emb Layers 380 | self.conv_1 = Conv1d(1, dims[0], 5, padding=2) 381 | self.embedding = RFF_MLP_Block(time_emb_dim) 382 | 383 | # Up/DownSample Block Layers 384 | DBlock_list = [] 385 | for in_dim, out_dim, factor, block_num in zip(dims[:-1], dims[1:], factors, block_nums): 386 | DBlock_list.append(GBlock(in_dim, out_dim, factor, block_num, film_type, event_dim)) 387 | self.downsample = nn.ModuleList(DBlock_list) 388 | 389 | UBlock_list = [] 390 | for in_dim, out_dim, factor, block_num in zip(dims[:0:-1], dims[-2::-1], factors[::-1], block_nums[::-1]): 391 | UBlock_list.append(GBlock(in_dim, out_dim, -1*factor, block_num, film_type, event_dim)) 392 | self.upsample = nn.ModuleList(UBlock_list) 393 | self.last_conv = Conv1d(dims[0], 1, 3, padding=1) 394 | 395 | # Bottleneck layer 396 | self.sequential = sequential 397 | if sequential: 398 | self.mid_dim = params['mid_dim'] 399 | if sequential == 'lstm': 400 | self.lstm = nn.LSTM(self.mid_dim, self.mid_dim, num_layers=2, batch_first=True, bidirectional=True) 401 | self.lstm_mlp = nn.Sequential( 402 | nn.Linear(self.mid_dim*2, self.mid_dim), 403 | nn.SiLU(), 404 | nn.Linear(self.mid_dim, self.mid_dim) 405 | ) 406 | 407 | if sequential == 'attn': 408 | self.mid_attn = Residual(PreNorm(self.mid_dim, Attention(self.mid_dim))) 409 | 410 | # Classifier-free guidance 411 | self.cond_drop_prob = cond_drop_prob 412 | 413 | self.classes_emb = nn.Embedding(num_classes, class_emb_dim) 414 | self.null_classes_emb = nn.Parameter(torch.randn(class_emb_dim)) 415 | self.null_event_emb = nn.Parameter(torch.randn(event_dim)) 416 | 417 | classes_dim = class_emb_dim * 4 418 | self.classes_mlp = nn.Sequential( 419 | nn.Linear(class_emb_dim, classes_dim), 420 | nn.SiLU(), 421 | nn.Linear(classes_dim, class_emb_dim) 422 | ) 423 | 424 | print("Model successfully initialized!") 425 | 426 | def forward(self, audio, sigma, classes, events, cond_drop_prob=None): 427 | batch, 
device = audio.shape[0], audio.device 428 | x = audio.unsqueeze(1) 429 | x = self.conv_1(x) 430 | downsampled = [] 431 | sigma_encoding = self.embedding(sigma) 432 | 433 | # Prepare Conditions(class, event) 434 | cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) 435 | classes_emb = self.classes_emb(classes) 436 | if cond_drop_prob[0] > 0: 437 | keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob[0], device = device) 438 | null_classes_emb = repeat(self.null_classes_emb, 'd -> b d', b = batch) 439 | 440 | classes_emb = torch.where( 441 | rearrange(keep_mask, 'b -> b 1'), 442 | classes_emb, 443 | null_classes_emb 444 | ) 445 | c = self.classes_mlp(classes_emb) 446 | 447 | if cond_drop_prob[1] > 0: 448 | keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob[1], device = device) 449 | null_event= repeat(self.null_event_emb, 'd -> b d', b = batch) 450 | 451 | events = torch.where( 452 | rearrange(keep_mask, 'b -> b 1'), 453 | events, 454 | null_event 455 | ) if events != None else null_event 456 | 457 | # Downsample 458 | for layer in self.downsample: 459 | x = layer(x, sigma_encoding, c, events) 460 | downsampled.append(x) 461 | 462 | # Bottleneck 463 | if self.sequential: 464 | if self.sequential == 'lstm': 465 | h0 = torch.randn(4, batch, self.mid_dim, device=device) 466 | c0 = torch.randn(4, batch, self.mid_dim, device=device) 467 | x = x.permute(0, 2, 1) 468 | x, _ = self.lstm(x, (h0, c0)) 469 | x = self.lstm_mlp(x) 470 | x = x.permute(0, 2, 1) 471 | 472 | if self.sequential == 'attn': 473 | x = self.mid_attn(x) 474 | 475 | x = x + downsampled[-1] # residual connection 476 | 477 | # Upsample 478 | for layer, x_dblock in zip(self.upsample, reversed(downsampled)): 479 | x = torch.cat([x, x_dblock], dim=1) 480 | x = layer(x, sigma_encoding, c, events) 481 | 482 | x = self.last_conv(x) 483 | x = x.squeeze(1) 484 | return x 485 | 486 | def forward_with_cond_scale(self, audio, sigma, classes, event, cond_scale=1.): 487 | cond_score = self.forward(audio, sigma, classes, event, cond_drop_prob=[0.0, 0.0]) 488 | if cond_scale == 1: return cond_score 489 | uncond_score = self.forward(audio, sigma, classes, event, cond_drop_prob=[1.0, 1.0]) 490 | return uncond_score + (cond_score - uncond_score) * cond_scale 491 | --------------------------------------------------------------------------------
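For readers who prefer calling the sampler from Python rather than the CLI, the snippet below is a minimal sketch of the pipeline that `inference.py` wires up; it assumes the pretrained checkpoint and `params.json` from the README are unzipped into `./pretrained` and that a CUDA device is available.

```python
import json
import torch
from scipy.io.wavfile import write

from inference import LABELS, load_ema_weights  # helpers defined in inference.py
from model import UNet
from sampler import SDESampling_batch
from sde import VpSdeCos
from utils import high_pass_filter

device = torch.device("cuda")
with open("./pretrained/params.json") as f:  # paths follow the README's pretrained.zip layout
    params = json.load(f)

model = UNet(len(LABELS), params).to(device)
model = load_ema_weights(model, "./pretrained/block-49_epoch-500.pt")
sampler = SDESampling_batch(model, VpSdeCos(), batch_size=1, device=device)

# Start from Gaussian noise of 4 seconds (as in inference.py) and run 100 reverse steps,
# conditioned on the class only (event=None falls back to the null event embedding).
noise = torch.randn(1, params["sample_rate"] * 4, device=device)
classes = torch.tensor([LABELS.index("DogBark")], device=device)
audio = sampler.predict(noise, 100, classes, None, cond_scale=3)

sample = high_pass_filter(audio.flatten().cpu())  # post-process as in inference.save_samples
write("DogBark_demo.wav", params["sample_rate"], sample)
```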