├── README.md ├── versions ├── tango.py ├── RePaint.py ├── DDNMplus.py └── utils.py └── DM4ASC.ipynb /README.md: -------------------------------------------------------------------------------- 1 | ## Diffusion Models for Audio Semantic Communication 2 | 3 | ### Eleonora Grassucci, Christian Marinoni, Andrea Rodriguez, and Danilo Comminiello 4 | 5 | ### IEEE ICASSP 2024 [[Paper on IEEEXplore](https://ieeexplore.ieee.org/document/10447612)] 6 | 7 | This repository is under construction and we will update it soon! 8 | -------------------------------------------------------------------------------- /versions/tango.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import json 4 | import torch 5 | from tqdm import tqdm 6 | from huggingface_hub import snapshot_download 7 | from models import AudioDiffusion, DDPMScheduler 8 | from audioldm.audio.stft import TacotronSTFT 9 | from audioldm.variational_autoencoder import AutoencoderKL 10 | import IPython 11 | import soundfile as sf 12 | 13 | class Tango: 14 | def __init__(self, path_to_tango, path_to_weights='', device="cuda:0"): 15 | 16 | if path_to_weights=='': 17 | path_to_weights = os.path.join(path_to_tango, 'weights', 'tango-full-ft-audiocaps') 18 | 19 | vae_config = json.load(open("{}/vae_config.json".format(path_to_weights))) 20 | stft_config = json.load(open("{}/stft_config.json".format(path_to_weights))) 21 | main_config = json.load(open("{}/main_config.json".format(path_to_weights))) 22 | 23 | self.vae = AutoencoderKL(**vae_config).to(device) 24 | self.stft = TacotronSTFT(**stft_config).to(device) 25 | self.model = AudioDiffusion(**main_config).to(device) 26 | 27 | vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path_to_weights), map_location=device) 28 | stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path_to_weights), map_location=device) 29 | main_weights = torch.load("{}/pytorch_model_main.bin".format(path_to_weights), map_location=device) 30 | 31 | self.vae.load_state_dict(vae_weights) 32 | self.stft.load_state_dict(stft_weights) 33 | self.model.load_state_dict(main_weights) 34 | 35 | print ("Successfully loaded checkpoint from:", path_to_weights) 36 | 37 | self.vae.eval() 38 | self.stft.eval() 39 | self.model.eval() 40 | 41 | self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler") 42 | 43 | def chunks(self, lst, n): 44 | """ Yield successive n-sized chunks from a list. """ 45 | for i in range(0, len(lst), n): 46 | yield lst[i:i + n] 47 | 48 | def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True): 49 | """ Generate audio for a single prompt string. """ 50 | with torch.no_grad(): 51 | latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress) 52 | mel = self.vae.decode_first_stage(latents) 53 | wave = self.vae.decode_to_waveform(mel) 54 | return wave[0] 55 | 56 | def generate_for_batch(self, prompts, steps=100, guidance=3, samples=1, batch_size=8, disable_progress=True): 57 | """ Generate audio for a list of prompt strings. 
""" 58 | outputs = [] 59 | for k in tqdm(range(0, len(prompts), batch_size)): 60 | batch = prompts[k: k+batch_size] 61 | with torch.no_grad(): 62 | latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress) 63 | mel = self.vae.decode_first_stage(latents) 64 | wave = self.vae.decode_to_waveform(mel) 65 | outputs += [item for item in wave] 66 | if samples == 1: 67 | return outputs 68 | else: 69 | return list(self.chunks(outputs, samples)) -------------------------------------------------------------------------------- /versions/RePaint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import librosa.display 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from tqdm import tqdm 7 | 8 | 9 | 10 | @torch.no_grad() 11 | def inference_with_mask(self, prompt, inference_scheduler, guidance_scale=3, num_samples_per_prompt=1, 12 | disable_progress=True, mask=None, original=None, t_T=1000, jump_len=10, jump_n_sample=10, noisy_prompts=False, psnr_prompts=20): 13 | 14 | def compute_steps(t_T=1000, jump_len=10, jump_n_sample=10): 15 | jumps = {} 16 | for j in range(0, t_T - jump_len, jump_len): 17 | jumps[j] = jump_n_sample - 1 18 | 19 | t = t_T 20 | ts = [] 21 | 22 | while t >= 1: 23 | t = t-1 24 | ts.append(t) 25 | 26 | if jumps.get(t, 0) > 0: 27 | jumps[t] = jumps[t] - 1 28 | for _ in range(jump_len): 29 | t = t + 1 30 | ts.append(t) 31 | 32 | ts.append(-1) 33 | return ts 34 | 35 | 36 | def get_psnr(snr): 37 | 38 | SNR_DICT = {100: 0.0, 39 | 30: 0.05, 40 | 25: 0.08, 41 | 20: 0.13, 42 | 17.5: 0.175, 43 | 15: 0.22, 44 | 10: 0.36, 45 | 5: 0.6, 46 | 1: 0.9} 47 | 48 | return SNR_DICT[snr] 49 | 50 | device = self.text_encoder.device 51 | 52 | ### START 53 | assert mask is not None, "A mask is needed" 54 | assert original is not None, "The original audio is needed" 55 | mask = mask.to(device) 56 | original = original.to(device) 57 | ### END 58 | 59 | classifier_free_guidance = guidance_scale > 1.0 60 | batch_size = len(prompt) * num_samples_per_prompt 61 | 62 | if classifier_free_guidance: 63 | prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt) 64 | else: 65 | prompt_embeds, boolean_prompt_mask = self.encode_text(prompt) 66 | prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0) 67 | boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0) 68 | 69 | prompt_embeds2 = prompt_embeds.flatten() 70 | print(f"Prompts | Min: {torch.min(prompt_embeds2)}, Max: {torch.max(prompt_embeds2)}, std: {torch.std(prompt_embeds)}") 71 | 72 | 73 | if noisy_prompts: 74 | lmin, lmax = torch.min(prompt_embeds), torch.max(prompt_embeds) 75 | noise_embeds_prompts = (prompt_embeds - lmin) / (lmax - lmin) 76 | noise_prompts = torch.randn_like(noise_embeds_prompts) * get_psnr(psnr_prompts) 77 | prompt_embeds = noise_embeds_prompts + noise_prompts 78 | prompt_embeds = (prompt_embeds * (lmax - lmin)) + lmin 79 | 80 | 81 | inference_scheduler.set_timesteps(t_T, device=device) # serve quando uso inference_scheduler per fare step 82 | timesteps = compute_steps(t_T=t_T, jump_len=jump_len, jump_n_sample=jump_n_sample) 83 | num_steps = len(timesteps) 84 | 85 | num_channels_latents = self.unet.in_channels 86 | latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device) 87 | 88 | num_warmup_steps = len(timesteps) - num_steps * 
inference_scheduler.order 89 | progress_bar = tqdm(range(num_steps), disable=disable_progress) 90 | 91 | for i, (t_last, t_cur) in enumerate(zip(timesteps[:-1], timesteps[1:])): 92 | 93 | if t_cur < t_last: 94 | 95 | t_last = torch.tensor(t_last, dtype=torch.int64).to(device) 96 | 97 | # expand the latents if we are doing classifier free guidance 98 | latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents 99 | latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t_last) # returns latent_model_input 100 | 101 | noise_pred = self.unet( 102 | latent_model_input, t_last, encoder_hidden_states=prompt_embeds, 103 | encoder_attention_mask=boolean_prompt_mask 104 | ).sample 105 | 106 | # perform guidance 107 | if classifier_free_guidance: 108 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 109 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 110 | 111 | # compute the previous noisy sample x_t -> x_t-1 112 | latents = inference_scheduler.step(noise_pred, t_last, latents).prev_sample 113 | 114 | ### START 115 | noise = torch.randn_like(original) if t_last>0 else torch.zeros_like(latents) 116 | latent_orig = self.noise_scheduler.add_noise(original, noise, t_last) if t_last>0 else original 117 | 118 | latents = (latent_orig * mask + latents * (1.0 - mask)) 119 | 120 | else: 121 | 122 | ns = self.noise_scheduler 123 | beta_t = ns.betas[t_last+1] # t_last + 1 = t_cur 124 | 125 | noise = torch.randn_like(latents) 126 | 127 | latents = ((1 - beta_t) ** 0.5) * latents + (beta_t ** 0.5) * noise 128 | ### END 129 | 130 | 131 | # call the callback, if provided 132 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0): 133 | progress_bar.update(1) 134 | 135 | if self.set_from == "pre-trained": 136 | latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous() 137 | 138 | return latents -------------------------------------------------------------------------------- /versions/DDNMplus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import librosa.display 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from tqdm import tqdm 7 | 8 | #torch.manual_seed(1999) 9 | #np.random.seed(1999) 10 | 11 | @torch.no_grad() 12 | def inference_with_mask(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1, 13 | disable_progress=True, mask=None, original=None, sigma_y=0, travel_length=10, noisy_prompts=False, psnr_prompts=20): 14 | 15 | def get_psnr(snr): 16 | 17 | SNR_DICT = {100: 0.0, 18 | 30: 0.05, 19 | 25: 0.08, 20 | 20: 0.13, 21 | 17.5: 0.175, 22 | 15: 0.22, 23 | 10: 0.36, 24 | 5: 0.6, 25 | 1: 0.9} 26 | 27 | return SNR_DICT[snr] 28 | 29 | device = self.text_encoder.device 30 | 31 | ### START 32 | assert mask is not None, "A mask is needed" 33 | assert original is not None, "The original audio is needed" 34 | # assert sigma_y >= 0 and sigma_y <= 1, "sigma_y must be between 0 and 1 included" 35 | mask = mask.to(device) 36 | original = original.to(device) 37 | ### END 38 | 39 | classifier_free_guidance = guidance_scale > 1.0 40 | batch_size = len(prompt) * num_samples_per_prompt 41 | 42 | if classifier_free_guidance: 43 | prompt_embeds, boolean_prompt_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt) 44 | else: 45 | prompt_embeds, boolean_prompt_mask = self.encode_text(prompt) 46 
| prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0) 47 | boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0) 48 | 49 | if noisy_prompts: 50 | lmin, lmax = torch.min(prompt_embeds), torch.max(prompt_embeds) 51 | noise_embeds_prompts = (prompt_embeds - lmin) / (lmax - lmin) 52 | noise_prompts = torch.randn_like(noise_embeds_prompts) * get_psnr(psnr_prompts) 53 | prompt_embeds = noise_embeds_prompts + noise_prompts 54 | prompt_embeds = (prompt_embeds * (lmax - lmin)) + lmin 55 | 56 | if sigma_y == -1: 57 | sigma_y = (torch.max(prompt_embeds)-torch.min(prompt_embeds))*torch.std(prompt_embeds) 58 | 59 | inference_scheduler.set_timesteps(num_steps, device=device) 60 | timesteps = inference_scheduler.timesteps 61 | 62 | num_channels_latents = self.unet.in_channels 63 | latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device) 64 | 65 | num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order 66 | progress_bar = tqdm(range(num_steps), disable=disable_progress) 67 | 68 | for i, t in enumerate(timesteps): 69 | 70 | L = min(num_steps - 1 - t, travel_length) 71 | 72 | ns = self.noise_scheduler 73 | 74 | beta_t = ns.betas[t] 75 | alpha_t = ns.alphas[t] 76 | alpha_cumprod_t = ns.alphas_cumprod[t] 77 | alpha_cumprod_tL = ns.alphas_cumprod[t+L] 78 | 79 | alpha_cumprod_L = alpha_cumprod_tL / alpha_cumprod_t 80 | noise = torch.randn_like(latents) 81 | 82 | latents = (alpha_cumprod_L ** 0.5) * latents + ((1 - alpha_cumprod_L) ** 0.5) * noise 83 | 84 | for j in range(L, -1, -1): 85 | 86 | # expand the latents if we are doing classifier free guidance 87 | latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents 88 | latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t+j) 89 | 90 | noise_pred = self.unet( 91 | latent_model_input, t+j, encoder_hidden_states=prompt_embeds, 92 | encoder_attention_mask=boolean_prompt_mask 93 | ).sample 94 | 95 | # perform guidance 96 | if classifier_free_guidance: 97 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 98 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 99 | 100 | ### START 101 | 102 | # compute the previous noisy sample x_t -> x_t-1 103 | # latents = inference_scheduler.step(noise_pred, t, latents).prev_sample 104 | 105 | ns = self.noise_scheduler 106 | 107 | beta_tj = ns.betas[t+j] 108 | alpha_tj = ns.alphas[t+j] 109 | alpha_cumprod_tj = ns.alphas_cumprod[t+j] 110 | alpha_cumprod_tj_prev = ns.alphas_cumprod[t+j-1] if (t + j - 1) >= 0 else ns.one 111 | 112 | x_0tj = (alpha_cumprod_tj ** 0.5) * latents - ((1 - alpha_cumprod_tj) ** 0.5) * noise_pred 113 | 114 | a_tj = ((alpha_cumprod_tj_prev ** 0.5) * beta_tj) / (1. - alpha_cumprod_tj) 115 | c3 = (((1 - alpha_cumprod_tj_prev) / (1 - alpha_cumprod_tj)) * beta_tj) ** 0.5 # sigma_t 116 | 117 | if c3 >= (a_tj * sigma_y): 118 | lambda_tj = 1. 119 | gamma_tj = (c3 ** 2) - ((a_tj * sigma_y) ** 2) 120 | else: 121 | lambda_tj = c3 / (a_tj * sigma_y) 122 | gamma_tj = 0. 
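 # DDNM+ noise-aware scaling: c3 is the DDPM posterior std sigma_t at step t+j, while
 # a_tj * sigma_y is the std that the noisy masked observation injects through the x_0 estimate.
 # If sigma_t already dominates, the full range-space correction is kept (lambda_tj = 1) and only
 # the leftover variance gamma_tj is re-added as fresh noise below; otherwise the correction is
 # shrunk by lambda_tj and no extra noise is injected (gamma_tj = 0).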
123 |
124 |
125 |
126 | # formula 13 DDNM
127 | x_0tj_hat = x_0tj - (lambda_tj * (mask * ((mask * x_0tj) - original)))
128 |
129 | c1 = ((alpha_cumprod_tj_prev ** 0.5) * beta_tj) / (1 - alpha_cumprod_tj)
130 | c2 = ((alpha_tj ** 0.5) * (1 - alpha_cumprod_tj_prev)) / (1 - alpha_cumprod_tj)
131 |
132 | noise = torch.randn_like(latents) if t + j > 0 else torch.zeros_like(latents)
133 |
134 | # formula 7 DDPM, formulae 11 and 14 DDNM
135 | latents = c1 * x_0tj_hat + c2 * latents + (gamma_tj ** 0.5) * noise
136 | ### END
137 |
138 | # call the callback, if provided
139 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
140 | progress_bar.update(1)
141 |
142 | if self.set_from == "pre-trained":
143 | latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()
144 | return latents
145 |
--------------------------------------------------------------------------------
/versions/utils.py:
--------------------------------------------------------------------------------
1 | from DDNMplus import inference_with_mask as iwm_ddnmplus
2 | # from RePaintDDNM import inference_with_mask as iwm_repaintddnm  # RePaintDDNM.py is not included in this repository
3 | from RePaint import inference_with_mask as iwm_repaint
4 |
5 | import tools.torch_tools as torch_tools
6 | from scipy.io.wavfile import write
7 | import numpy as np
8 | import torch
9 | from torchmetrics.functional.audio import signal_noise_ratio, signal_distortion_ratio
10 | import torchaudio
11 | import os
12 | import librosa
13 | from matplotlib import pyplot as plt
14 | from tqdm import tqdm
15 |
16 | def get_version(name_version):
17 | '''
18 | Return the inference function of the chosen method: 'ddnm+' or 'repaint'.
19 | '''
20 | assert name_version in ['ddnm+', 'repaint']
21 |
22 | if name_version == 'ddnm+':
23 | print('DDNM+')
24 | return iwm_ddnmplus
25 | elif name_version == 'repaint':
26 | print("RePaint")
27 | return iwm_repaint
28 |
29 | def get_d_audios(file_list, num_to_select, dataset_path, caption_filter=None, max_f=200):
30 |
31 | def check_caption(caption_filter, caption):
32 | if caption_filter is None:
33 | return True
34 | for word in caption_filter:
35 | if word in caption:
36 | return True
37 | return False
38 |
39 | selected_audios = file_list.head(max_f)
40 |
41 | d_audios = {}
42 | count = 0
43 | for i, audio in enumerate(selected_audios.itertuples()):
44 | if os.path.isfile(f"{dataset_path}/{audio._asdict()['Index']}_{audio._asdict()['start_time']}.wav") \
45 | and check_caption(caption_filter, audio._asdict()['caption']):
46 | d_audios[count] = audio._asdict()
47 | count += 1
48 | if count == 50:
49 | break
50 |
51 | return d_audios
52 |
53 | def get_psnr(snr):
54 |
55 | SNR_DICT = {100: 0.0,
56 | 30: 0.05,
57 | 25: 0.08,
58 | 20: 0.13,
59 | 17.5: 0.175,
60 | 15: 0.22,
61 | 10: 0.36,
62 | 5: 0.6,
63 | 1: 0.9}
64 |
65 | return SNR_DICT[snr]
66 |
67 | def get_mel_spectrogram(original_audio_path, tango):
68 | original_mels, _, _ = torch_tools.wav_to_fbank(original_audio_path, 1024, tango.stft)
69 | return original_mels
70 |
71 | def get_original_latents(original_audio_paths, tango):
72 | original_mels, _, _ = torch_tools.wav_to_fbank(original_audio_paths, 1024, tango.stft)
73 | original_mels = original_mels.unsqueeze(1)
74 | original_latents = tango.vae.get_first_stage_encoding(tango.vae.encode_first_stage(original_mels))
75 |
76 | return original_mels, original_latents
77 |
78 | def apply_noise(original_latents, snr, verbose=False):
79 |
80 | lmin, lmax = torch.min(original_latents), torch.max(original_latents)
81 |
82 | noisy_latents =
(original_latents - lmin) / (lmax - lmin) 83 | 84 | noise = torch.randn(noisy_latents.shape, device=noisy_latents.device)*get_psnr(snr) 85 | noisy_latents += noise 86 | 87 | original_psnr = 10*torch.log(torch.max(original_latents)/torch.var(original_latents)) 88 | print("Original range", original_psnr, torch.var(original_latents)) if verbose else None 89 | noisy_psnr = 10*torch.log(torch.max(noisy_latents)/torch.var(noisy_latents)) 90 | print("Noisy norm", noisy_psnr,torch.var(noisy_latents)) if verbose else None 91 | 92 | #newlmin, newlmax = torch.min(noisy_latents), torch.max(noisy_latents) 93 | 94 | noisy_latents = (noisy_latents * (lmax - lmin)) + lmin 95 | 96 | noisy_psnr = 10*torch.log(torch.max(noisy_latents)/torch.var(noisy_latents)) 97 | print("Noisy change range", noisy_psnr,torch.var(noisy_latents)) if verbose else None 98 | 99 | 100 | return noisy_latents 101 | 102 | 103 | def get_sigma_y(noisy_latents, mask=None): 104 | 105 | if mask is not None: 106 | sigma_y = -1 # As a consequence, sigma_y is derived from the prompt embeddings 107 | else: 108 | sigma_y = (torch.max(noisy_latents)-torch.min(noisy_latents))*torch.std(noisy_latents) 109 | 110 | return sigma_y 111 | 112 | 113 | def get_mask(time_mask_percentage, mask_type='time'): 114 | 115 | min_p, max_p = time_mask_percentage 116 | 117 | assert mask_type in ['time', 'mel', 'wave'] 118 | 119 | if sum(time_mask_percentage)==0: 120 | return None 121 | else: 122 | 123 | if mask_type == 'time': 124 | mask = torch.ones(1,1,256,16).to('cuda') 125 | mask[:, :, int(min_p*256):int(max_p*256), :] = 0 126 | elif mask_type == 'mel': 127 | mask = torch.ones(1,1,1024,64).to('cuda') 128 | mask[:, :, int(min_p*1024):int(max_p*1024), :] = 0 129 | elif mask_type == 'wave': 130 | mask = np.ones((1, 163872), dtype=np.int16) 131 | mask[:, int(min_p*163872):int(max_p*163872)] = 0 132 | 133 | return mask 134 | 135 | 136 | 137 | def save_audio(path, waves, id_file=None, descr='generated'): 138 | 139 | for i, wave in enumerate(waves): 140 | write(f"{path}/{id_file if id_file else i}_{descr}.wav", 16000, wave) 141 | 142 | return 143 | 144 | 145 | 146 | def compute_snr(prediction, target, verbose): 147 | 148 | torch.manual_seed(1) 149 | 150 | snr = signal_noise_ratio(prediction, target).item() 151 | print("SNR: ", "{:.2f}".format(snr)) if verbose else None 152 | 153 | return snr 154 | 155 | 156 | 157 | def compute_sdr(prediction, target, verbose): 158 | 159 | torch.manual_seed(1) 160 | 161 | sdr = signal_distortion_ratio(prediction, target).item() 162 | if verbose: 163 | print("SDR {:.5f}".format(sdr)) 164 | 165 | return sdr 166 | 167 | 168 | 169 | def compute_fad(path_to_clean_audio, path_to_generated_audio, verbose): 170 | from frechet_audio_distance import FrechetAudioDistance 171 | 172 | fad_embeddings = 'vggish' # either 'vggish' or 'pann' 173 | 174 | if fad_embeddings == 'vggish': 175 | frechet = FrechetAudioDistance( 176 | model_name="vggish", 177 | use_pca=False, 178 | use_activation=False, 179 | verbose=False 180 | ) 181 | elif fad_embeddings == 'pann': 182 | frechet = FrechetAudioDistance( 183 | model_name="pann", 184 | use_pca=False, 185 | use_activation=False, 186 | verbose=False 187 | ) 188 | 189 | fad_score = frechet.score(path_to_clean_audio, path_to_generated_audio, dtype="float32") 190 | 191 | if verbose: 192 | print("FAD {:.5f}".format(fad_score)) 193 | 194 | return fad_score 195 | 196 | 197 | def load_audio_old(prediction_path, target_path): 198 | prediction, _ = torchaudio.load(prediction_path) 199 | target, _ = 
torchaudio.load(target_path) 200 | 201 | pmin, pmax = torch.min(prediction), torch.max(prediction) 202 | tmin, tmax = torch.min(target), torch.max(target) 203 | 204 | prediction = (prediction - pmin) / (pmax - pmin) 205 | prediction = (prediction * (tmax - tmin)) + tmin 206 | 207 | return prediction, target 208 | 209 | def load_audio(prediction_path, target_path, segment_length): 210 | prediction = torch_tools.read_wav_file(prediction_path, segment_length) 211 | target = torch_tools.read_wav_file(target_path, segment_length) 212 | return prediction, target 213 | 214 | 215 | def mask_wav(folder_audios, output_dir, segment_length=160000, mask=None): 216 | assert mask is not None 217 | 218 | if not os.path.isdir(f'{folder_audios}/{output_dir}'): 219 | os.mkdir(f'{folder_audios}/{output_dir}') 220 | 221 | for audio in tqdm(os.listdir(f'{folder_audios}/generated')): 222 | audio_path = f'{folder_audios}/generated/{audio}' 223 | audio_wav = torch_tools.read_wav_file(audio_path, segment_length) 224 | audio_wav = audio_wav[:,int(mask[0]*segment_length):int(mask[1]*segment_length)] 225 | 226 | save_audio(f'{folder_audios}/{output_dir}', audio_wav.cpu().numpy(), id_file=audio, descr='') 227 | 228 | return 229 | 230 | 231 | def compute_metrics(prediction_path, target_path, verbose=False, exclude=None): 232 | 233 | if os.path.isfile(prediction_path) and os.path.isfile(target_path): 234 | 235 | prediction, target = load_audio(prediction_path, target_path, 160000) 236 | 237 | snr = compute_snr(prediction, target, verbose) 238 | sdr = compute_sdr(prediction, target, verbose) 239 | 240 | print("\n--------\nBEWARE OF PICKPOCKETS!\n \ 241 | To calculate the FAD, it is necessary to provide the path to two folders") 242 | 243 | else: 244 | 245 | snr_all = [] 246 | sdr_all = [] 247 | 248 | # We assume that files have been equally named in both the 249 | # prediction and target folders. 
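 # (Generated files are saved by save_audio as "<id>_<descr>.wav" and targets are expected as
 # "<id>.wav", so the shared id is recovered from the part of the filename before the first "_".)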
250 | for audio_file in os.listdir(prediction_path): 251 | 252 | audio_file_target = audio_file.split("_")[0] 253 | print(audio_file_target) if verbose else None 254 | if os.path.isfile(f'{target_path}/{audio_file_target}.wav') and (exclude is None or int(audio_file_target) not in exclude): 255 | prediction, target = load_audio(f'{prediction_path}/{audio_file}',f'{target_path}/{audio_file_target}.wav', 160000) 256 | 257 | # print(f'PRED {audio_file_target} min {torch.min(prediction)}, max {torch.max(prediction)}') 258 | # print(f'TARGET {audio_file_target} min {torch.min(target)}, max {torch.max(target)}') 259 | snr = compute_snr(prediction, target, verbose) 260 | sdr = compute_sdr(prediction, target, verbose) 261 | 262 | snr_all.append(snr) 263 | sdr_all.append(sdr) 264 | 265 | fad = compute_fad(prediction_path, target_path, verbose) 266 | 267 | return snr_all, sdr_all, fad 268 | 269 | 270 | def plot_spectrogram(original_mel, noisy_mel, output_mel, mel_mask=None, save_fig=False, id_s=0, output_path=None): 271 | 272 | original_mel_signal = original_mel.cpu()[0][0].transpose(0,1) 273 | original_spectrogram = np.abs(original_mel_signal) 274 | original_power_to_db = librosa.power_to_db(original_spectrogram, ref=np.max) 275 | 276 | if mel_mask is not None: 277 | noisy_mel = noisy_mel * mel_mask 278 | noisy_mel_signal = noisy_mel.cpu()[0][0].transpose(0,1) 279 | noisy_spectrogram = np.abs(noisy_mel_signal) 280 | noisy_power_to_db = librosa.power_to_db(noisy_spectrogram, ref=np.max) 281 | 282 | output_mel_signal = output_mel.cpu()[0][0].transpose(0,1) 283 | output_spectrogram = np.abs(output_mel_signal) 284 | output_power_to_db = librosa.power_to_db(output_spectrogram, ref=np.max) 285 | 286 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 3)) 287 | #fig.colorbar(label='dB') 288 | 289 | librosa.display.specshow(original_power_to_db, sr=16000, x_axis='time', y_axis='mel', cmap='magma', hop_length=160, ax=ax1) 290 | librosa.display.specshow(noisy_power_to_db, sr=16000, x_axis='time', y_axis='mel', cmap='magma', hop_length=160, ax=ax2) 291 | librosa.display.specshow(output_power_to_db, sr=16000, x_axis='time', y_axis='mel', cmap='magma', hop_length=160, ax=ax3) 292 | 293 | ax1.set_title('Mel-Spectrogram Original', fontdict=dict(size=12)) 294 | ax2.set_title('Mel-Spectrogram Corrupted', fontdict=dict(size=12)) 295 | ax3.set_title('Mel-Spectrogram Output', fontdict=dict(size=12)) 296 | 297 | xticks = np.arange(0, 11, 1.0) 298 | ax1.set_xticks(xticks) 299 | ax2.set_xticks(xticks) 300 | ax3.set_xticks(xticks) 301 | 302 | if not save_fig: 303 | plt.show() 304 | 305 | if save_fig: 306 | fig.savefig(f"{output_path}/{id_s}_mel-spectrograms.png") 307 | 308 | fig.clf() 309 | return 310 | 311 | 312 | def print_metrics_report(snr_all, sdr_all, fad): 313 | 314 | mean_snr = np.mean(snr_all) 315 | mean_sdr = np.mean(sdr_all) 316 | min_snr = np.min(snr_all) 317 | min_sdr = np.min(sdr_all) 318 | max_snr = np.max(snr_all) 319 | max_sdr = np.max(sdr_all) 320 | std_snr = np.std(snr_all) 321 | std_sdr = np.std(sdr_all) 322 | 323 | # Adjust for a better visualization 324 | print("\033[1m{:<12} Results \033[0m ".format(" ")) 325 | print("{:<6} {:<9} {:<10} ".format(' ',' SNR',' SDR')) 326 | print("{:<6} {:<9.4f} {:<10.4f}".format('mean',mean_snr,mean_sdr)) 327 | print("{:<6} {:<9.4f} {:<10.4f}".format('std',std_snr,std_sdr)) 328 | print("{:<6} {:<9.4f} {:<10.4f}".format('min',min_snr,min_sdr)) 329 | print("{:<6} {:<9.4f} {:<10.4f} ".format('max',max_snr,max_sdr)) 330 | print("") 331 | print("{:<12} {:<20.4f} 
".format('FAD',fad)) 332 | 333 | return -------------------------------------------------------------------------------- /DM4ASC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "701920bc", 6 | "metadata": {}, 7 | "source": [ 8 | "# Diffusion Models for Audio Semantic Communication" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "ae6fb974-42a4-446a-ae48-7645c90f6264", 14 | "metadata": {}, 15 | "source": [ 16 | "### Prerequisites" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "db736330-cc8b-4233-8300-6290e26ffc0b", 22 | "metadata": {}, 23 | "source": [ 24 | "Run once to set all the needed stuff" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "5e26ea29", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "!pip install librosa==0.9.2\n", 35 | "!pip install huggingface_hub==0.13.3\n", 36 | "!pip install einops==0.6.1\n", 37 | "!pip install transformers==4.27.0\n", 38 | "!pip install progressbar\n", 39 | "!pip install pandas\n", 40 | "!pip install matplotlib\n", 41 | "!pip install torchmetrics\n", 42 | "!pip install frechet-audio-distance\n", 43 | "\n", 44 | "# Install PyTorch version 1.13.1 with CUDA 11.7 support\n", 45 | "!pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117\n", 46 | "\n", 47 | "%cd diffusers/\n", 48 | "!pip install -e .\n", 49 | " \n", 50 | "!git clone https://github.com/declare-lab/tango\n", 51 | "%cd tango\n", 52 | "%cd 'diffusers'\n", 53 | "!pip install -e .\n", 54 | "%cd ../tango/\n", 55 | "%mkdir weights\n", 56 | "%cd weights\n", 57 | "!git clone https://huggingface.co/declare-lab/tango-full-ft-audiocaps" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "8d7cddc9-a311-4d90-81c4-519a9c8a0a26", 63 | "metadata": {}, 64 | "source": [ 65 | "## Initialize tango" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 1, 71 | "id": "0e4efc13-2cd3-494c-bd6a-172ad890cd3c", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "/mnt/media/christian/DiffInpainting/tango\n" 79 | ] 80 | }, 81 | { 82 | "name": "stderr", 83 | "output_type": "stream", 84 | "text": [ 85 | "/home/ispamm/miniconda3/envs/DiffInp/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n", 86 | " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n", 87 | "/home/ispamm/miniconda3/envs/DiffInp/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 88 | " from .autonotebook import tqdm as notebook_tqdm\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "%cd tango\n", 94 | "import os\n", 95 | "\n", 96 | "import json\n", 97 | "import torch\n", 98 | "from tqdm import tqdm\n", 99 | "from huggingface_hub import snapshot_download\n", 100 | "from models import AudioDiffusion, DDPMScheduler\n", 101 | "from audioldm.audio.stft import TacotronSTFT\n", 102 | "from audioldm.variational_autoencoder import AutoencoderKL\n", 103 | "import IPython\n", 104 | "import soundfile as sf" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 2, 110 | "id": "dfaa9b67", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "class Tango:\n", 115 | " def __init__(self, path_to_weights='', device=\"cuda:0\"):\n", 116 | "\n", 117 | " if path_to_weights=='':\n", 118 | " path_to_weights = './weights/tango-full-ft-audiocaps'\n", 119 | "\n", 120 | " vae_config = json.load(open(\"{}/vae_config.json\".format(path_to_weights)))\n", 121 | " stft_config = json.load(open(\"{}/stft_config.json\".format(path_to_weights)))\n", 122 | " main_config = json.load(open(\"{}/main_config.json\".format(path_to_weights)))\n", 123 | "\n", 124 | " self.vae = AutoencoderKL(**vae_config).to(device)\n", 125 | " self.stft = TacotronSTFT(**stft_config).to(device)\n", 126 | " self.model = AudioDiffusion(**main_config).to(device)\n", 127 | "\n", 128 | " vae_weights = torch.load(\"{}/pytorch_model_vae.bin\".format(path_to_weights), map_location=device)\n", 129 | " stft_weights = torch.load(\"{}/pytorch_model_stft.bin\".format(path_to_weights), map_location=device)\n", 130 | " main_weights = torch.load(\"{}/pytorch_model_main.bin\".format(path_to_weights), map_location=device)\n", 131 | "\n", 132 | " self.vae.load_state_dict(vae_weights)\n", 133 | " self.stft.load_state_dict(stft_weights)\n", 134 | " self.model.load_state_dict(main_weights)\n", 135 | "\n", 136 | " print (\"Successfully loaded checkpoint from:\", path_to_weights)\n", 137 | "\n", 138 | " self.vae.eval()\n", 139 | " self.stft.eval()\n", 140 | " self.model.eval()\n", 141 | "\n", 142 | " self.scheduler = DDPMScheduler.from_pretrained(main_config[\"scheduler_name\"], subfolder=\"scheduler\")\n", 143 | "\n", 144 | " def chunks(self, lst, n):\n", 145 | " \"\"\" Yield successive n-sized chunks from a list. \"\"\"\n", 146 | " for i in range(0, len(lst), n):\n", 147 | " yield lst[i:i + n]\n", 148 | "\n", 149 | " def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):\n", 150 | " \"\"\" Generate audio for a single prompt string. \"\"\"\n", 151 | " with torch.no_grad():\n", 152 | " latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)\n", 153 | " mel = self.vae.decode_first_stage(latents)\n", 154 | " wave = self.vae.decode_to_waveform(mel)\n", 155 | " return wave[0]\n", 156 | "\n", 157 | " def generate_for_batch(self, prompts, steps=100, guidance=3, samples=1, batch_size=8, disable_progress=True):\n", 158 | " \"\"\" Generate audio for a list of prompt strings. 
\"\"\"\n", 159 | " outputs = []\n", 160 | " for k in tqdm(range(0, len(prompts), batch_size)):\n", 161 | " batch = prompts[k: k+batch_size]\n", 162 | " with torch.no_grad():\n", 163 | " latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)\n", 164 | " mel = self.vae.decode_first_stage(latents)\n", 165 | " wave = self.vae.decode_to_waveform(mel)\n", 166 | " outputs += [item for item in wave]\n", 167 | " if samples == 1:\n", 168 | " return outputs\n", 169 | " else:\n", 170 | " return list(self.chunks(outputs, samples))" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 3, 176 | "id": "f2b7df8e-a49a-46cc-9237-b6a563c79a9e", 177 | "metadata": { 178 | "scrolled": true 179 | }, 180 | "outputs": [ 181 | { 182 | "name": "stderr", 183 | "output_type": "stream", 184 | "text": [ 185 | "/mnt/media/christian/DiffInpainting/tango/audioldm/audio/stft.py:42: FutureWarning: Pass size=1024 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", 186 | " fft_window = pad_center(fft_window, filter_length)\n", 187 | "/mnt/media/christian/DiffInpainting/tango/audioldm/audio/stft.py:151: FutureWarning: Pass sr=16000, n_fft=1024, n_mels=64, fmin=0, fmax=8000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", 188 | " mel_basis = librosa_mel_fn(\n" 189 | ] 190 | }, 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "UNet initialized randomly.\n" 196 | ] 197 | }, 198 | { 199 | "name": "stderr", 200 | "output_type": "stream", 201 | "text": [ 202 | "Some weights of the model checkpoint at google/flan-t5-large were not used when initializing T5EncoderModel: ['decoder.block.9.layer.1.layer_norm.weight', 'decoder.block.20.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.21.layer.0.SelfAttention.q.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.11.layer.1.layer_norm.weight', 'decoder.block.14.layer.2.DenseReluDense.wo.weight', 'decoder.block.18.layer.1.EncDecAttention.q.weight', 'decoder.block.23.layer.2.DenseReluDense.wo.weight', 'decoder.block.7.layer.0.SelfAttention.o.weight', 'decoder.block.15.layer.1.layer_norm.weight', 'decoder.block.22.layer.0.SelfAttention.k.weight', 'decoder.block.15.layer.1.EncDecAttention.v.weight', 'decoder.block.8.layer.1.EncDecAttention.q.weight', 'decoder.block.15.layer.1.EncDecAttention.q.weight', 'decoder.block.21.layer.2.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'decoder.block.10.layer.1.EncDecAttention.o.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.embed_tokens.weight', 'decoder.block.15.layer.1.EncDecAttention.k.weight', 'decoder.block.21.layer.0.SelfAttention.k.weight', 'decoder.final_layer_norm.weight', 'decoder.block.6.layer.1.EncDecAttention.k.weight', 'decoder.block.4.layer.1.EncDecAttention.v.weight', 'decoder.block.12.layer.1.EncDecAttention.k.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.22.layer.0.SelfAttention.q.weight', 'decoder.block.11.layer.1.EncDecAttention.q.weight', 'decoder.block.7.layer.2.layer_norm.weight', 'decoder.block.12.layer.0.SelfAttention.k.weight', 'decoder.block.11.layer.1.EncDecAttention.o.weight', 'decoder.block.20.layer.1.EncDecAttention.o.weight', 
... (remaining unused decoder weight names omitted) ]\n",
203 | "- This IS expected if you are initializing T5EncoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
204 | "- This IS NOT expected if you are initializing T5EncoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
205 | ]
206 | },
207 | {
208 | "name": "stdout",
209 | "output_type": "stream",
210 | "text": [
211 | "Successfully loaded checkpoint from: ./weights/tango-full-ft-audiocaps\n"
212 | ]
213 | }
214 | ],
215 | "source": [
216 | "tango = Tango()"
217 | ]
218 | },
219 | {
220 | "cell_type": "markdown",
221 | "id": "c5d4e570",
222 | "metadata": {},
223 | "source": [
224 | "## DDNM+ & RePaint"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "id": "febbce01-c74c-435b-aaac-f76d30e87747",
230 | "metadata": {},
231 | "source": [
232 | "Change the following parameters to test DDNM+ on the denoising or inpainting task (RePaint supports inpainting only).\n",
233 | "\n",
234 | "You should first download the AudioCaps dataset and change the paths below accordingly. Please note that different seeds can impact the quality of final results."
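, "\n",
 "For example, the main switch between the two tasks is `time_mask_percentage`; a minimal sketch, using the same values as the parameter cell below:\n",
 "\n",
 "```python\n",
 "# inpainting: zero out the central 10% of the clip along the time axis of the latents\n",
 "time_mask_percentage = (0.45, 0.55)\n",
 "\n",
 "# denoising only: no mask is built (utils.get_mask returns None when the fractions sum to zero)\n",
 "# time_mask_percentage = (0., 0.)\n",
 "```"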
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 4,
240 | "id": "49fab913-9a2e-46e5-bda1-b4845a048a72",
241 | "metadata": {},
242 | "outputs": [
243 | {
244 | "name": "stdout",
245 | "output_type": "stream",
246 | "text": [
247 | "/mnt/media/christian/DiffInpainting/versions\n"
248 | ]
249 | },
250 | {
251 | "name": "stderr",
252 | "output_type": "stream",
253 | "text": [
254 | "/home/ispamm/miniconda3/envs/DiffInp/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
255 | " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
256 | ]
257 | }
258 | ],
259 | "source": [
260 | "%cd ../versions"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 5,
266 | "id": "97ff368f-e774-40f4-b0b6-adf2adbce95e",
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "model = 'ddnm+' # 'ddnm+' or 'repaint'\n",
271 | "snr = 20 # 17.5, 20, 30 # Specify the PSNR level of the noisy audio\n",
272 | "noisy_prompts=True # Whether to apply noise to text prompts or not\n",
273 | "prompt_psnr = snr # Here you can control the amount of noise to apply to the prompt embedding\n",
274 | "time_mask_percentage = (0.45, 0.55) # Inpainting mask (start, end) as fractions of the clip; set (0., 0.) for no inpainting\n",
275 | "mask_type = 'time' # type of data you want to apply the mask to\n",
276 | "print_report = False # True to print final report\n",
277 | "num_samples = 1 # Number of audio files with which to test the method\n",
278 | "save_output_audio = True # Save the output audio file\n",
279 | "audio_description = 'ddnmp_snr20' # filename of the saved files\n",
280 | "save_noisy_audio = False # Whether to save noise latents converted to audio or not\n",
281 | "caption_filter = None # Search for specific words in the captions (list of words)\n",
282 | "\n",
283 | "name_experiment = 'DDNM+ inpaint SNR20 test'\n",
284 | "output_path = '../output'\n",
285 | "dataset_path = '../AudioCaps'\n",
286 | "dataset_subset = 'AudioCaps_Val'"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 6,
292 | "id": "cc384970",
293 | "metadata": {},
294 | "outputs": [
295 | {
296 | "name": "stdout",
297 | "output_type": "stream",
298 | "text": [
299 | "DDNM+\n"
300 | ]
301 | },
302 | {
303 | "name": "stderr",
304 | "output_type": "stream",
305 | "text": [
306 | "100%|██████████████████████████████████████████████████| 1000/1000 [07:13<00:00, 2.31it/s]\n"
307 | ]
308 | },
309 | {
310 | "data": {
311 | "text/plain": [
312 | "
" 313 | ] 314 | }, 315 | "metadata": {}, 316 | "output_type": "display_data" 317 | } 318 | ], 319 | "source": [ 320 | "import pandas as pd\n", 321 | "import shutil\n", 322 | "import utils\n", 323 | "\n", 324 | "\n", 325 | "output_path_generated = os.path.join(output_path, f'{name_experiment}/generated')\n", 326 | "output_path_noisy = os.path.join(output_path, f'{name_experiment}/noisy')\n", 327 | "output_path_clean = os.path.join(output_path, f'{name_experiment}/clean')\n", 328 | "output_path_spectrograms = os.path.join(output_path, f'{name_experiment}/spectrograms')\n", 329 | "\n", 330 | "usecols = [\"youtube_id\", \"start_time\", \"caption\"]\n", 331 | "file_list = pd.read_csv(os.path.join(dataset_path, \"val.csv\"), index_col=\"youtube_id\", usecols=usecols)\n", 332 | "\n", 333 | "d_audios = utils.get_d_audios(file_list, num_to_select=num_samples, dataset_path = os.path.join(dataset_path, dataset_subset), caption_filter=caption_filter, max_f=200)\n", 334 | "\n", 335 | "if not os.path.exists(os.path.join(output_path, name_experiment)):\n", 336 | " os.mkdir(os.path.join(output_path, name_experiment))\n", 337 | " os.mkdir(output_path_generated)\n", 338 | " os.mkdir(output_path_noisy)\n", 339 | " os.mkdir(output_path_clean)\n", 340 | " os.mkdir(output_path_spectrograms)\n", 341 | "\n", 342 | "inference_with_mask = utils.get_version(model)\n", 343 | "\n", 344 | "for i in range(0, num_samples):\n", 345 | " original_audio_paths = [os.path.join(dataset_path, dataset_subset, f\"{d_audios[i]['Index']}_{d_audios[i]['start_time']}.wav\")]\n", 346 | " caption = d_audios[i]['caption']\n", 347 | " shutil.copyfile(original_audio_paths[0], os.path.join(output_path_clean, f\"{i}.wav\"))\n", 348 | " \n", 349 | " # Get the embeddings of the original audio\n", 350 | " _, original_latents = utils.get_original_latents(original_audio_paths, tango)\n", 351 | "\n", 352 | " # Apply noise to the latents\n", 353 | " noisy_latents = utils.apply_noise(original_latents, snr, verbose=False)\n", 354 | "\n", 355 | " # Get the mask (time or mel) to apply to the latents (needed for inpainting)\n", 356 | " if sum(time_mask_percentage)==0:\n", 357 | " ipainting = True\n", 358 | " mask = utils.get_mask(time_mask_percentage, mask_type=mask_type)\n", 359 | "\n", 360 | " # Set the sigma_y value\n", 361 | " sigma_y = utils.get_sigma_y(noisy_latents, mask)\n", 362 | "\n", 363 | " with torch.no_grad():\n", 364 | " \n", 365 | " if model == 'ddnm+':\n", 366 | " latents = inference_with_mask(tango.model, [caption], tango.scheduler,\n", 367 | " num_steps=1000, guidance_scale=3, num_samples_per_prompt=1, disable_progress=False,\n", 368 | " mask=mask, original= original_latents if mask is None else noisy_latents, \n", 369 | " sigma_y=sigma_y, travel_length=0,\n", 370 | " noisy_prompts=noisy_prompts, psnr_prompts=prompt_psnr)\n", 371 | " elif model == 'repaint' and mask is not None:\n", 372 | " latents = inference_with_mask(tango.model, [caption], tango.scheduler,\n", 373 | " t_T=1000, guidance_scale=3, num_samples_per_prompt=1, disable_progress=False,\n", 374 | " mask=mask, original=original_latents)\n", 375 | " else:\n", 376 | " print(\"Please select a proper method for the desired task\")\n", 377 | "\n", 378 | "\n", 379 | " output_mels = tango.vae.decode_first_stage(latents)\n", 380 | " waves = tango.vae.decode_to_waveform(output_mels)\n", 381 | "\n", 382 | " utils.save_audio(output_path_generated, waves, id_file=f'{i}', descr=audio_description)\n", 383 | "\n", 384 | " original_mels = tango.vae.decode_first_stage(original_latents)\n", 
385 | " \n", 386 | " \n", 387 | " noisy_mels = tango.vae.decode_first_stage(noisy_latents)\n", 388 | " noisy_waves = tango.vae.decode_to_waveform(noisy_mels)\n", 389 | " \n", 390 | " if save_noisy_audio:\n", 391 | " utils.save_audio(output_path_noisy, noisy_waves, id_file=f'{i}', descr=audio_description)\n", 392 | " \n", 393 | "\n", 394 | " mel_mask = utils.get_mask(time_mask_percentage, mask_type='mel')\n", 395 | " utils.plot_spectrogram(original_mels, noisy_mels, output_mels, mel_mask=mel_mask, save_fig=True, id_s=i, output_path=output_path_spectrograms)\n" 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "id": "55b3de60-5992-42a9-8373-d95a15c5b07c", 401 | "metadata": {}, 402 | "source": [ 403 | "## Evaluate" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 8, 409 | "id": "6cac7b24-9a2e-4dbd-a322-7f0004317003", 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "name": "stderr", 414 | "output_type": "stream", 415 | "text": [ 416 | "Using cache found in /home/christian/.cache/torch/hub/harritaylor_torchvggish_master\n" 417 | ] 418 | }, 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "\u001b[1m Results \u001b[0m \n", 424 | " SNR SDR \n", 425 | "mean -2.2181 -10.5631 \n", 426 | "std 0.0895 0.4068 \n", 427 | "min -2.3077 -10.9699 \n", 428 | "max -2.1286 -10.1563 \n", 429 | "\n", 430 | "FAD 4.2125 \n" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "import utils\n", 436 | "\n", 437 | "folders = ['generated', 'clean']\n", 438 | "\n", 439 | "# specify set of audios to exclude from the metrics computation\n", 440 | "exclude_audios = []\n", 441 | "\n", 442 | "snr_all, sdr_all, fad = utils.compute_metrics(os.path.join(output_path, name_experiment, folders[0]), \n", 443 | " os.path.join(output_path, name_experiment, folders[1]), exclude=exclude_audios)\n", 444 | "utils.print_metrics_report(snr_all, sdr_all, fad)" 445 | ] 446 | } 447 | ], 448 | "metadata": { 449 | "kernelspec": { 450 | "display_name": "DiffInp", 451 | "language": "python", 452 | "name": "diffinp" 453 | }, 454 | "language_info": { 455 | "codemirror_mode": { 456 | "name": "ipython", 457 | "version": 3 458 | }, 459 | "file_extension": ".py", 460 | "mimetype": "text/x-python", 461 | "name": "python", 462 | "nbconvert_exporter": "python", 463 | "pygments_lexer": "ipython3", 464 | "version": "3.10.14" 465 | } 466 | }, 467 | "nbformat": 4, 468 | "nbformat_minor": 5 469 | } 470 | --------------------------------------------------------------------------------