├── LICENSE
├── README.md
├── conf
│   ├── conf.yaml
│   ├── diff_params
│   │   └── edm.yaml
│   ├── dset
│   │   ├── librispeech.yaml
│   │   ├── maestro_allyears.yaml
│   │   └── musicnet.yaml
│   ├── exp
│   │   ├── librispeech16k_8s.yaml
│   │   ├── maestro22k_131072.yaml
│   │   ├── maestro22k_8s.yaml
│   │   ├── musicnet44k_4s.yaml
│   │   ├── musicnet44k_8s.yaml
│   │   └── test_cqtdiff_22k.yaml
│   ├── logging
│   │   ├── base_logging.yaml
│   │   ├── debug_logging.yaml
│   │   ├── frequent_logging.yaml
│   │   └── huge_model_logging.yaml
│   ├── network
│   │   ├── ADP_raw_patching.yaml
│   │   ├── paper_1912_unet_cqt_oct_attention_44k_2.yaml
│   │   ├── paper_1912_unet_cqt_oct_attention_adaLN_2.yaml
│   │   ├── paper_1912_unet_cqt_oct_noattention_adaln.yaml
│   │   └── unet_cqtdiff_original.yaml
│   └── tester
│       ├── edm_2ndorder_stochastic.yaml
│       ├── inpainting_tester.yaml
│       └── inpainting_tester_shortgaps.yaml
├── datasets
│   ├── audiofolder.py
│   ├── audiofolder_test.py
│   ├── librispeech.py
│   ├── maestro_dataset.py
│   └── maestro_dataset_test.py
├── diff_params
│   └── edm.py
├── networks
│   └── unet_cqt_oct_with_projattention_adaLN_2.py
├── notebooks
│   └── demo_inpainting_spectrogram.ipynb
├── test.py
├── testing.sh
├── testing
│   ├── edm_sampler.py
│   ├── edm_sampler_inpainting.py
│   ├── tester.py
│   └── tester_inpainting.py
├── testing_shortgaps.sh
├── train.py
├── training.sh
├── training
│   ├── __init__.py
│   └── trainer.py
└── utils
    ├── dnnlib
    │   ├── __init__.py
    │   └── util.py
    ├── logging.py
    ├── setup.py
    ├── torch_utils
    │   ├── __init__.py
    │   ├── distributed.py
    │   ├── misc.py
    │   └── training_stats.py
    ├── training_utils.py
    └── utils_notebook.py
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Eloi Moliner Juanpere
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Diffusion-Based Audio Inpainting
2 | 
3 | Official repository of the paper:
4 | > E. Moliner and V. Välimäki, "Diffusion-Based Audio Inpainting", Journal of the Audio Engineering Society, March 2024
5 | 
6 | Audio examples available at the [companion website](http://research.spa.aalto.fi/publications/papers/jaes-diffusion-inpainting/)
7 | 
8 | Trained models available at [HuggingFace](https://huggingface.co/Eloimoliner/audio-inpainting-diffusion)
-------------------------------------------------------------------------------- /conf/conf.yaml: --------------------------------------------------------------------------------
1 | defaults:
2 |   - dset: maestro_allyears
3 |   - network: paper_1912_unet_cqt_oct_attention_adaLN_2
4 |   - diff_params: edm
5 |   - tester: inpainting_tester
6 |   - exp: maestro22k_8s
7 |   - logging: huge_model_logging
8 | 
9 | model_dir: "experiments/cqt"
10 | 
11 | dry_run: False #print training options and exit
12 | 
13 | #testing (demos)
14 | 
15 | 
16 | hydra:
17 |   job:
18 |     config:
19 |       # configuration for the ${hydra.job.override_dirname} runtime variable
20 |       override_dirname:
21 |         kv_sep: '='
22 |         item_sep: ','
23 |         # Remove all paths, as the / in them would mess things up
24 |         exclude_keys: ['path_experiment',
25 |                        'hydra.job_logging.handlers.file.filename']
-------------------------------------------------------------------------------- /conf/diff_params/edm.yaml: --------------------------------------------------------------------------------
1 | callable: "diff_params.edm.EDM"
2 | 
3 | sigma_data: 0.063 #default for maestro
4 | sigma_min: 1e-5
5 | sigma_max: 10
6 | P_mean: -1.2 #mean of the log-normal noise-level distribution used to sample sigma during training (EDM)
7 | P_std: 1.2 #std of the log-normal noise-level distribution used to sample sigma during training (EDM)
8 | ro: 13 #rho exponent of the Karras et al. sigma schedule used at sampling time
9 | ro_train: 10
10 | Schurn: 5
11 | Snoise: 1
12 | Stmin: 0
13 | Stmax: 50
14 | 
15 | aweighting:
16 |   use_aweighting: False
17 |   ntaps: 101
18 | 
19 | 
20 | 
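The `sigma_min`/`sigma_max`/`ro` fields above are the parameters of the rho-warped noise schedule from Karras et al. (EDM), which this repo's `diff_params.edm.EDM` implements. A minimal sketch of how such a discretization is typically computed — the standalone function and its name are illustrative, not this repo's API:

```python
import numpy as np

def karras_sigma_schedule(T=35, sigma_min=1e-5, sigma_max=10.0, ro=13.0):
    """Noise levels sigma_0 > ... > sigma_{T-1} following the EDM rho-schedule:

    sigma_i = (sigma_max^(1/rho) + i/(T-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    """
    i = np.arange(T)
    inv = 1.0 / ro
    return (sigma_max**inv + i / (T - 1) * (sigma_min**inv - sigma_max**inv)) ** ro
```

A large `ro` (here 13) concentrates most of the steps at small noise levels, where the fine temporal structure of audio is formed.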
-------------------------------------------------------------------------------- /conf/dset/librispeech.yaml: --------------------------------------------------------------------------------
1 | name: "librispeech"
2 | type: "audio"
3 | callable: "datasets.librispeech.LibrispeechTrain"
4 | #path: "/u/25/molinee2/unix/datasets/MusicNet/train"
5 | path: "/scratch/work/molinee2/datasets/LibriSpeech"
6 | 
7 | train_dirs: ["train-clean-100", "train-clean-360"]
8 | 
9 | overfit: False
10 | 
11 | test:
12 |   callable: "datasets.librispeech.LibrispeechTest"
13 |   #path: "/u/25/molinee2/unix/datasets/MusicNet/train"
14 |   path: "/scratch/work/molinee2/datasets/LibriSpeech/dev-clean"
15 |   num_samples: 4
16 |   batch_size: 1
17 | 
-------------------------------------------------------------------------------- /conf/dset/maestro_allyears.yaml: --------------------------------------------------------------------------------
1 | name: "maestro_allyears"
2 | callable: "datasets.maestro_dataset.MaestroDataset_fs"
3 | type: "audio"
4 | path: "/scratch/shareddata/dldata/maestro/v3.0.0/maestro-v3.0.0"
5 | #years: [2004, 2006, 2008, 2009, 2011, 2013, 2014, 2015, 2017, 2018]
6 | years: [2015, 2017, 2018] #use only these years, as the most recent ones are sampled at 48 kHz
7 | years_test: [2009] #fewer years to make testing quicker to compute; otherwise sampling takes very long
8 | cache: True #cache dataset in CPU memory
9 | 
10 | 
11 | load_len: 405000
12 | 
13 | test:
14 |   callable: "datasets.maestro_dataset_test.MaestroDatasetTestChunks"
15 |   num_samples: 4
16 |   batch_size: 1
-------------------------------------------------------------------------------- /conf/dset/musicnet.yaml: --------------------------------------------------------------------------------
1 | name: "musicnet"
2 | type: "audio"
3 | callable: "datasets.audiofolder.AudioFolderDataset"
4 | #path: "/u/25/molinee2/unix/datasets/MusicNet/train"
5 | path: "/scratch/work/molinee2/datasets/musicnet/train_data"
6 | 
7 | overfit: False
8 | 
9 | test:
10 |   callable: "datasets.audiofolder_test.AudioFolderDatasetTest"
11 |   #path: "/u/25/molinee2/unix/datasets/MusicNet/train"
12 |   path: "/scratch/work/molinee2/datasets/musicnet/test_data"
13 |   num_samples: 4
14 |   batch_size: 1
15 | 
-------------------------------------------------------------------------------- /conf/exp/librispeech16k_8s.yaml: --------------------------------------------------------------------------------
1 | exp_name: "16k_8s"
2 | 
3 | trainer_callable: "training.trainer.Trainer"
4 | 
5 | wandb:
6 |   entity: "eloimoliner"
7 |   project: "A-diffusion"
8 | 
9 | model_dir: None
10 | #main options
11 | #related to optimization
12 | optimizer:
13 |   type: "adam" #only "adam" implemented
14 |   beta1: 0.9
15 |   beta2: 0.999
16 |   eps: 1e-8 #for numerical stability; we may need to modify it if using fp16
17 | 
18 | lr: 2e-4 #learning rate
19 | lr_rampup_it: 10000 #learning rate rampup duration
20 | 
21 | #for lr scheduler (not noise schedule!!) TODO (I think)
22 | scheduler_step_size: 60000
23 | scheduler_gamma: 0.8
24 | 
25 | 
26 | #save_model: True #whether to save the checkpoints of the model in this experiment
27 | 
28 | 
29 | # Training related.
30 | #total_its: 100000 #training duration
31 | batch: 4 #total batch size
32 | batch_gpu: 4 #limit batch size per GPU
33 | num_accumulation_rounds: 1 #gradient accumulation, truncated backprop
34 | 
35 | 
36 | # Performance-related.
37 | use_fp16: False #enable mixed-precision training
38 | ls: 1 #loss scaling
39 | bench: True #enable cuDNN benchmarking
40 | num_workers: 4 #DataLoader worker processes
41 | 
42 | # I/O-related. moved to logging
43 | seed: 42 #random seed
44 | transfer: None #transfer learning from network pickle (PKL|URL)
45 | #resume: True #resume from previous training state
46 | resume: True
47 | resume_checkpoint: None
48 | 
49 | 
50 | #audio data related
51 | sample_rate: 16000
52 | audio_len: 184184
53 | resample_factor: 1 #useful for the maestro dataset, which is sampled at 44.1 kHz or 48 kHz and we want to resample it to 22.05 kHz
54 | 
55 | 
56 | #training functionality parameters
57 | device: "cpu" #it will be updated in the code, no worries
58 | 
59 | #training use_cqt_DC_correction: False #if True, the loss will be corrected for the DC component and the nyquist frequency.
This is important because we are discarding the DC component and the nyquist frequency in the cqt 60 | 61 | #ema_rate: "0.9999" # comma-separated list of EMA values 62 | ema_rate: 0.9999 #unused 63 | ema_rampup: 10000 #linear rampup to ema_rate #help='EMA half-life' 64 | 65 | 66 | #gradient clipping 67 | use_grad_clip: True 68 | max_grad_norm: 1 69 | 70 | restore : False 71 | checkpoint_id: None 72 | 73 | #pre-emph. This should not go here! either logging or network 74 | 75 | #augmentation related 76 | augmentations: 77 | rev_polarity: True 78 | pitch_shift: 79 | use: False 80 | min_semitones: -6 81 | max_semitones: 6 82 | gain: 83 | use: False 84 | min_db: -3 85 | max_db: 3 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /conf/exp/maestro22k_131072.yaml: -------------------------------------------------------------------------------- 1 | exp_name: "22k_8s" 2 | 3 | trainer_callable: "training.trainer.Trainer" 4 | 5 | wandb: 6 | entity: "eloimoliner" 7 | project: "A-diffusion" 8 | 9 | model_dir: None 10 | #main options 11 | #related to optimization 12 | optimizer: 13 | type: "adam" #only "adam implemented 14 | beta1: 0.9 15 | beta2: 0.999 16 | eps: 1e-8 #for numerical stability, we may need to modify it if usinf fp16 17 | 18 | lr: 2e-4 # help='Learning rate', 19 | lr_rampup_it: 10000 #, help='Learning rate rampup duration' 20 | 21 | #for lr scheduler (not noise schedule!!) TODO (I think) 22 | scheduler_step_size: 60000 23 | scheduler_gamma: 0.8 24 | 25 | 26 | #save_model: True #wether to save the checkpoints of the model in this experiment 27 | 28 | 29 | # Training related. 30 | #total_its: 100000 #help='Training duration' 31 | batch: 8 # help='Total batch size' 32 | batch_gpu: 8 #, help='Limit batch size per GPU' 33 | num_accumulation_rounds: 1 #gradient accumulation, truncated backprop 34 | 35 | 36 | # Performance-related. 37 | use_fp16: False #', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 38 | ls: 1 #', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 39 | bench: True #', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 40 | num_workers: 4 #', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) 41 | 42 | # I/O-related. moved to logging 43 | seed: 42 #', help='Random seed [default: random]', metavar='INT', type=int) 44 | transfer: None #', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) 45 | #resume: True #', help='Resume from previous training state', metavar='PT', type=str) 46 | resume: True 47 | resume_checkpoint: None 48 | 49 | 50 | #audio data related 51 | sample_rate: 22050 52 | audio_len: 131072 53 | resample_factor: 2 #useful for the maestro dataset, which is sampled at 44.1kHz or 48kHz and we want to resample it to 22.05kHz 54 | 55 | 56 | #training functionality parameters 57 | device: "cpu" #it will be updated in the code, no worries 58 | 59 | #training 60 | use_cqt_DC_correction: False #if True, the loss will be corrected for the DC component and the nyquist frequency. 
This is important because we are discarding the DC component and the nyquist frequency in the cqt 61 | 62 | #ema_rate: "0.9999" # comma-separated list of EMA values 63 | ema_rate: 0.9999 #unused 64 | ema_rampup: 10000 #linear rampup to ema_rate #help='EMA half-life' 65 | 66 | 67 | #gradient clipping 68 | use_grad_clip: True 69 | max_grad_norm: 1 70 | 71 | restore : False 72 | checkpoint_id: None 73 | 74 | #pre-emph. This should not go here! either logging or network 75 | 76 | #augmentation related 77 | augmentations: 78 | rev_polarity: True 79 | pitch_shift: 80 | use: False 81 | min_semitones: -6 82 | max_semitones: 6 83 | gain: 84 | use: False 85 | min_db: -3 86 | max_db: 3 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /conf/exp/maestro22k_8s.yaml: -------------------------------------------------------------------------------- 1 | exp_name: "22k_8s" 2 | 3 | trainer_callable: "training.trainer.Trainer" 4 | 5 | wandb: 6 | entity: "eloimoliner" 7 | project: "A-diffusion" 8 | 9 | model_dir: None 10 | #main options 11 | #related to optimization 12 | optimizer: 13 | type: "adam" #only "adam implemented 14 | beta1: 0.9 15 | beta2: 0.999 16 | eps: 1e-8 #for numerical stability, we may need to modify it if usinf fp16 17 | 18 | lr: 2e-4 # help='Learning rate', 19 | lr_rampup_it: 10000 #, help='Learning rate rampup duration' 20 | 21 | #for lr scheduler (not noise schedule!!) TODO (I think) 22 | scheduler_step_size: 60000 23 | scheduler_gamma: 0.8 24 | 25 | 26 | #save_model: True #wether to save the checkpoints of the model in this experiment 27 | 28 | 29 | # Training related. 30 | #total_its: 100000 #help='Training duration' 31 | batch: 4 # help='Total batch size' 32 | batch_gpu: 4 #, help='Limit batch size per GPU' 33 | num_accumulation_rounds: 1 #gradient accumulation, truncated backprop 34 | 35 | 36 | # Performance-related. 37 | use_fp16: False #', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 38 | ls: 1 #', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 39 | bench: True #', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 40 | num_workers: 4 #', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) 41 | 42 | # I/O-related. moved to logging 43 | seed: 42 #', help='Random seed [default: random]', metavar='INT', type=int) 44 | transfer: None #', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) 45 | #resume: True #', help='Resume from previous training state', metavar='PT', type=str) 46 | resume: True 47 | resume_checkpoint: None 48 | 49 | 50 | #audio data related 51 | sample_rate: 22050 52 | audio_len: 184184 53 | resample_factor: 2 #useful for the maestro dataset, which is sampled at 44.1kHz or 48kHz and we want to resample it to 22.05kHz 54 | 55 | 56 | #training functionality parameters 57 | device: "cpu" #it will be updated in the code, no worries 58 | 59 | #training 60 | use_cqt_DC_correction: False #if True, the loss will be corrected for the DC component and the nyquist frequency. 
This is important because we are discarding the DC component and the nyquist frequency in the cqt 61 | 62 | #ema_rate: "0.9999" # comma-separated list of EMA values 63 | ema_rate: 0.9999 #unused 64 | ema_rampup: 10000 #linear rampup to ema_rate #help='EMA half-life' 65 | 66 | 67 | #gradient clipping 68 | use_grad_clip: True 69 | max_grad_norm: 1 70 | 71 | restore : False 72 | checkpoint_id: None 73 | 74 | #pre-emph. This should not go here! either logging or network 75 | 76 | #augmentation related 77 | augmentations: 78 | rev_polarity: True 79 | pitch_shift: 80 | use: False 81 | min_semitones: -6 82 | max_semitones: 6 83 | gain: 84 | use: False 85 | min_db: -3 86 | max_db: 3 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /conf/exp/musicnet44k_4s.yaml: -------------------------------------------------------------------------------- 1 | exp_name: "44k_4s" 2 | 3 | trainer_callable: "training.trainer.Trainer" 4 | 5 | wandb: 6 | entity: "eloimoliner" 7 | project: "A-diffusion" 8 | 9 | model_dir: None 10 | #main options 11 | #related to optimization 12 | optimizer: 13 | type: "adam" #only "adam implemented 14 | beta1: 0.9 15 | beta2: 0.999 16 | eps: 1e-8 #for numerical stability, we may need to modify it if usinf fp16 17 | 18 | lr: 2e-4 # help='Learning rate', 19 | lr_rampup_it: 10000 #, help='Learning rate rampup duration' 20 | 21 | #for lr scheduler (not noise schedule!!) TODO (I think) 22 | scheduler_step_size: 60000 23 | scheduler_gamma: 0.8 24 | 25 | 26 | #save_model: True #wether to save the checkpoints of the model in this experiment 27 | 28 | 29 | # Training related. 30 | #total_its: 100000 #help='Training duration' 31 | batch: 4 # help='Total batch size' 32 | batch_gpu: 4 #, help='Limit batch size per GPU' 33 | num_accumulation_rounds: 1 #gradient accumulation, truncated backprop 34 | 35 | 36 | # Performance-related. 37 | use_fp16: False #', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 38 | ls: 1 #', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 39 | bench: True #', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 40 | num_workers: 4 #', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) 41 | 42 | # I/O-related. moved to logging 43 | seed: 42 #', help='Random seed [default: random]', metavar='INT', type=int) 44 | transfer: None #', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) 45 | #resume: True #', help='Resume from previous training state', metavar='PT', type=str) 46 | resume: True 47 | resume_checkpoint: None 48 | 49 | 50 | #audio data related 51 | sample_rate: 44100 52 | audio_len: 184184 53 | resample_factor: 1 #useful for the maestro dataset, which is sampled at 44.1kHz or 48kHz and we want to resample it to 22.05kHz 54 | 55 | 56 | #training functionality parameters 57 | device: "cpu" #it will be updated in the code, no worries 58 | 59 | #training 60 | use_cqt_DC_correction: False #if True, the loss will be corrected for the DC component and the nyquist frequency. 
This is important because we are discarding the DC component and the nyquist frequency in the cqt 61 | 62 | #ema_rate: "0.9999" # comma-separated list of EMA values 63 | ema_rate: 0.9999 #unused 64 | ema_rampup: 10000 #linear rampup to ema_rate #help='EMA half-life' 65 | 66 | 67 | #gradient clipping 68 | use_grad_clip: True 69 | max_grad_norm: 1 70 | 71 | restore : False 72 | checkpoint_id: None 73 | 74 | #pre-emph. This should not go here! either logging or network 75 | 76 | #augmentation related 77 | augmentations: 78 | rev_polarity: True 79 | pitch_shift: 80 | use: False 81 | min_semitones: -6 82 | max_semitones: 6 83 | gain: 84 | use: False 85 | min_db: -3 86 | max_db: 3 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /conf/exp/musicnet44k_8s.yaml: -------------------------------------------------------------------------------- 1 | exp_name: "44k_8s" 2 | 3 | trainer_callable: "training.trainer.Trainer" 4 | 5 | wandb: 6 | entity: "eloimoliner" 7 | project: "A-diffusion" 8 | 9 | model_dir: None 10 | #main options 11 | #related to optimization 12 | optimizer: 13 | type: "adam" #only "adam implemented 14 | beta1: 0.9 15 | beta2: 0.999 16 | eps: 1e-8 #for numerical stability, we may need to modify it if usinf fp16 17 | 18 | lr: 2e-4 # help='Learning rate', 19 | lr_rampup_it: 10000 #, help='Learning rate rampup duration' 20 | 21 | #for lr scheduler (not noise schedule!!) TODO (I think) 22 | scheduler_step_size: 60000 23 | scheduler_gamma: 0.8 24 | 25 | 26 | #save_model: True #wether to save the checkpoints of the model in this experiment 27 | 28 | 29 | # Training related. 30 | #total_its: 100000 #help='Training duration' 31 | batch: 4 # help='Total batch size' 32 | batch_gpu: 4 #, help='Limit batch size per GPU' 33 | num_accumulation_rounds: 1 #gradient accumulation, truncated backprop 34 | 35 | 36 | # Performance-related. 37 | use_fp16: False #', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 38 | ls: 1 #', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 39 | bench: True #', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 40 | num_workers: 4 #', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) 41 | 42 | # I/O-related. moved to logging 43 | seed: 42 #', help='Random seed [default: random]', metavar='INT', type=int) 44 | transfer: None #', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) 45 | #resume: True #', help='Resume from previous training state', metavar='PT', type=str) 46 | resume: True 47 | resume_checkpoint: None 48 | 49 | 50 | #audio data related 51 | sample_rate: 44100 52 | audio_len: 368368 53 | resample_factor: 1 #useful for the maestro dataset, which is sampled at 44.1kHz or 48kHz and we want to resample it to 22.05kHz 54 | 55 | 56 | #training functionality parameters 57 | device: "cpu" #it will be updated in the code, no worries 58 | 59 | #training 60 | use_cqt_DC_correction: False #if True, the loss will be corrected for the DC component and the nyquist frequency. 
This is important because we are discarding the DC component and the nyquist frequency in the cqt 61 | 62 | #ema_rate: "0.9999" # comma-separated list of EMA values 63 | ema_rate: 0.9999 #unused 64 | ema_rampup: 10000 #linear rampup to ema_rate #help='EMA half-life' 65 | 66 | 67 | #gradient clipping 68 | use_grad_clip: True 69 | max_grad_norm: 1 70 | 71 | restore : False 72 | checkpoint_id: None 73 | 74 | #pre-emph. This should not go here! either logging or network 75 | 76 | #augmentation related 77 | augmentations: 78 | rev_polarity: True 79 | pitch_shift: 80 | use: False 81 | min_semitones: -6 82 | max_semitones: 6 83 | gain: 84 | use: False 85 | min_db: -3 86 | max_db: 3 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /conf/exp/test_cqtdiff_22k.yaml: -------------------------------------------------------------------------------- 1 | exp_name: "testing_training_code" 2 | 3 | trainer_callable: "training.trainer.Trainer" 4 | 5 | wandb: 6 | entity: "eloimoliner" 7 | project: "A-diffusion" 8 | 9 | model_dir: None 10 | #main options 11 | #related to optimization 12 | optimizer: 13 | type: "adam" #only "adam implemented 14 | beta1: 0.9 15 | beta2: 0.999 16 | eps: 1e-8 #for numerical stability, we may need to modify it if usinf fp16 17 | 18 | lr: 2e-4 # help='Learning rate', 19 | lr_rampup_it: 10000 #, help='Learning rate rampup duration' 20 | 21 | #for lr scheduler (not noise schedule!!) TODO (I think) 22 | scheduler_step_size: 60000 23 | scheduler_gamma: 0.8 24 | 25 | 26 | #save_model: True #wether to save the checkpoints of the model in this experiment 27 | 28 | 29 | # Training related. 30 | #total_its: 100000 #help='Training duration' 31 | batch: 4 # help='Total batch size' 32 | batch_gpu: 4 #, help='Limit batch size per GPU' 33 | num_accumulation_rounds: 1 #gradient accumulation, truncated backprop 34 | 35 | 36 | # Performance-related. 37 | use_fp16: False #', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 38 | ls: 1 #', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 39 | bench: True #', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 40 | num_workers: 4 #', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) 41 | 42 | # I/O-related. moved to logging 43 | seed: 42 #', help='Random seed [default: random]', metavar='INT', type=int) 44 | transfer: None #', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) 45 | ##resume: True #', help='Resume from previous training state', metavar='PT', type=str) 46 | resume: True 47 | resume_checkpoint: None 48 | 49 | #audio data related 50 | sample_rate: 22050 51 | audio_len: 65536 52 | resample_factor: 2 #useful for the maestro dataset, which is sampled at 44.1kHz or 48kHz and we want to resample it to 22.05kHz 53 | 54 | 55 | #training functionality parameters 56 | device: "cpu" #it will be updated in the code, no worries 57 | 58 | #training 59 | use_cqt_DC_correction: False #if True, the loss will be corrected for the DC component and the nyquist frequency. 
This is important because we are discarding the DC component and the nyquist frequency in the cqt 60 | 61 | #ema_rate: "0.9999" # comma-separated list of EMA values 62 | ema_rate: 0.9999 #unused 63 | ema_rampup: 10000 #linear rampup to ema_rate #help='EMA half-life' 64 | 65 | 66 | #gradient clipping 67 | use_grad_clip: True 68 | max_grad_norm: 0.1 69 | 70 | restore : False 71 | checkpoint_id: None 72 | 73 | #pre-emph. This should not go here! either logging or network 74 | 75 | #augmentation related 76 | augmentations: 77 | rev_polarity: True 78 | pitch_shift: 79 | use: False 80 | min_semitones: -6 81 | max_semitones: 6 82 | gain: 83 | use: False 84 | min_db: -3 85 | max_db: 3 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /conf/logging/base_logging.yaml: -------------------------------------------------------------------------------- 1 | 2 | #logging params 3 | log: True 4 | log_interval: 1000 5 | heavy_log_interval: 50000 #same as save_interval 6 | save_model: True 7 | save_interval: 50000 8 | 9 | 10 | #about special logs 11 | num_sigma_bins: 20 12 | freq_cqt_logging: 100 13 | 14 | 15 | print_model_summary: True 16 | 17 | profiling: 18 | enabled: True 19 | wait: 5 20 | warmup: 10 21 | active: 2 22 | repeat: 1 23 | 24 | #stft 25 | stft: 26 | win_size: 1024 27 | hop_size: 256 28 | 29 | cqt: 30 | hop_length: 1024 31 | num_octs: 6 32 | fmin: 70 33 | bins_per_oct: 1 34 | 35 | log_feature_stats: True 36 | log_feature_stats_interval: 50000 37 | -------------------------------------------------------------------------------- /conf/logging/debug_logging.yaml: -------------------------------------------------------------------------------- 1 | 2 | #logging params 3 | log: True 4 | log_interval: 50 5 | heavy_log_interval: 100 #same as save_interval 6 | save_model: True 7 | save_interval: 100 #hihly recommended to be the same as heavy_log_interval, or at least a multiple of it 8 | 9 | remove_last_checkpoint: True 10 | #about special logs 11 | num_sigma_bins: 20 12 | freq_cqt_logging: 1 13 | 14 | 15 | print_model_summary: True 16 | 17 | profiling: 18 | enabled: False 19 | wait: 10 20 | warmup: 100 21 | active: 2 22 | repeat: 1 23 | 24 | #stft 25 | stft: 26 | win_size: 1024 27 | hop_size: 256 28 | 29 | cqt: 30 | hop_length: 1024 31 | num_octs: 8 32 | fmin: 32.70 33 | bins_per_oct: 1 34 | 35 | log_feature_stats: False 36 | log_feature_stats_interval: 20000000 37 | -------------------------------------------------------------------------------- /conf/logging/frequent_logging.yaml: -------------------------------------------------------------------------------- 1 | 2 | #logging params 3 | log: True 4 | log_interval: 1000 5 | heavy_log_interval: 20000 #same as save_interval 6 | save_model: True 7 | save_interval: 20000 8 | 9 | 10 | #about special logs 11 | num_sigma_bins: 20 12 | freq_cqt_logging: 100 13 | 14 | 15 | print_model_summary: True 16 | 17 | profiling: 18 | enabled: True 19 | wait: 5 20 | warmup: 10 21 | active: 2 22 | repeat: 1 23 | 24 | #stft 25 | stft: 26 | win_size: 1024 27 | hop_size: 256 28 | 29 | cqt: 30 | hop_length: 1024 31 | num_octs: 8 32 | fmin: 32.70 33 | bins_per_oct: 2 34 | 35 | log_feature_stats: True 36 | log_feature_stats_interval: 20000 37 | -------------------------------------------------------------------------------- /conf/logging/huge_model_logging.yaml: -------------------------------------------------------------------------------- 1 | 2 | #logging params 3 | log: True 4 | log_interval: 1000 5 | 
heavy_log_interval: 10000 #same as save_interval 6 | save_model: True 7 | save_interval: 10000 8 | 9 | remove_last_checkpoint: True 10 | 11 | 12 | #about special logs 13 | num_sigma_bins: 20 14 | freq_cqt_logging: 100 15 | 16 | 17 | print_model_summary: True 18 | 19 | profiling: 20 | enabled: True 21 | wait: 5 22 | warmup: 10 23 | active: 2 24 | repeat: 1 25 | 26 | #stft 27 | stft: 28 | win_size: 1024 29 | hop_size: 256 30 | 31 | cqt: 32 | hop_length: 1024 33 | num_octs: 8 34 | fmin: 32.70 35 | bins_per_oct: 2 36 | 37 | log_feature_stats: False 38 | log_feature_stats_interval: 1000000 39 | -------------------------------------------------------------------------------- /conf/network/ADP_raw_patching.yaml: -------------------------------------------------------------------------------- 1 | #add all the parameters of the layers you want to use in 1d U-Net 2 | name: "ADP_raw_patching" 3 | callable: "networks.flavio_models.modules.UNet1d" 4 | 5 | use_cqt_DC_correction: False 6 | 7 | channels: 128 8 | patch_factor: 16 9 | patch_blocks: 1 10 | resnet_groups: 8 11 | kernel_multiplier_downsample: 2 12 | multipliers: [1, 2, 4, 4, 4, 4, 4] 13 | factors: [2, 2, 2, 2, 2, 2] 14 | num_blocks: [2, 2, 2, 2, 2, 2] 15 | attentions: [0, 0, 0, 0, 1, 1, 1] 16 | attention: 17 | attention_heads: 16 18 | attention_features: 64 19 | attention_multiplier: 4 20 | attention_use_rel_pos: False 21 | use_nearest_upsample: False 22 | use_skip_scale: True 23 | use_context_time: True 24 | -------------------------------------------------------------------------------- /conf/network/paper_1912_unet_cqt_oct_attention_44k_2.yaml: -------------------------------------------------------------------------------- 1 | #network from the paper: Solving audio inverse problems with a diffusion model 2 | name: "unet_cqt_oct_with_attention" 3 | callable: "networks.unet_cqt_oct_with_projattention_adaLN_2.Unet_CQT_oct_with_attention" 4 | 5 | 6 | use_fencoding: False 7 | use_norm: True 8 | 9 | filter_out_cqt_DC_Nyq: True 10 | 11 | depth: 8 #total depth of the network (including the first stage) 12 | 13 | emb_dim: 256 #dimensionality of the RFF embeddings 14 | 15 | 16 | #dimensions of the first stage (the length of this vector should be equal to num_octs) 17 | #Ns: [64, 96 ,96, 128, 128,256, 256] #it is hardcoded 18 | Ns: [64, 64,96, 96, 128, 128,256, 256] #it is hardcoded 19 | #Ns: [8, 8 ,8, 8, 16,16, 16] #it is hardcoded 20 | 21 | attention_layers: [0, 0, 0, 0, 0, 1, 1, 1, 1] #num_octs+bottleneck 22 | #attention_Ns: [0, 0, 0, 0,256 ,512,1024 ,1024] 23 | 24 | #Ns: [8,8,16,16,32,32,64] 25 | Ss: [2,2,2, 2, 2, 2, 2] #downsample factors at the first stage, now it is ignored 26 | 27 | num_dils: [2,3,4,5,6,7,8, 8] 28 | 29 | cqt: 30 | window: "kaiser" 31 | beta: 1 32 | num_octs: 8 33 | bins_per_oct: 64 #this needs to be lower than 64, otherwise the time attention is inpractical 34 | 35 | 36 | #inner_Ns: [64, 64, 64, 64] 37 | #if 4x2, then down factor of 16! 
38 | 39 | bottleneck_type: "res_dil_convs" 40 | 41 | num_bottleneck_layers: 1 42 | 43 | #transformer: 44 | # num_heads: 8 45 | # dim_head: 64 46 | # num_layers: 16 47 | # channels: 512 48 | # attn_dropout: 0.1 49 | # multiplier_ff: 4 50 | # activation: "gelu" #fixed 51 | 52 | 53 | 54 | #for now, only the last two layers have attention + bottleneck 55 | 56 | 57 | attention_dict: 58 | num_heads: 8 59 | attn_dropout: 0.0 60 | bias_qkv: False 61 | N: 0 62 | rel_pos_num_buckets: 32 63 | rel_pos_max_distance: 64 64 | use_rel_pos: False 65 | Nproj: 8 66 | #the number of channels is the same as the Ns of the corresponding layer 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /conf/network/paper_1912_unet_cqt_oct_attention_adaLN_2.yaml: -------------------------------------------------------------------------------- 1 | #network from the paper: Solving audio inverse problems with a diffusion model 2 | name: "unet_cqt_oct_with_attention" 3 | callable: "networks.unet_cqt_oct_with_projattention_adaLN_2.Unet_CQT_oct_with_attention" 4 | 5 | 6 | use_fencoding: False 7 | use_norm: True 8 | 9 | filter_out_cqt_DC_Nyq: True 10 | 11 | depth: 7 #total depth of the network (including the first stage) 12 | 13 | emb_dim: 256 #dimensionality of the RFF embeddings 14 | 15 | 16 | #dimensions of the first stage (the length of this vector should be equal to num_octs) 17 | #Ns: [64, 96 ,96, 128, 128,256, 256] #it is hardcoded 18 | Ns: [64,96, 96, 128, 128,256, 256] #it is hardcoded 19 | #Ns: [8, 8 ,8, 8, 16,16, 16] #it is hardcoded 20 | 21 | attention_layers: [0, 0, 0, 0, 1, 1, 1, 1] #num_octs+bottleneck 22 | #attention_Ns: [0, 0, 0, 0,256 ,512,1024 ,1024] 23 | 24 | #Ns: [8,8,16,16,32,32,64] 25 | Ss: [2,2,2, 2, 2, 2, 2] #downsample factors at the first stage, now it is ignored 26 | 27 | num_dils: [2,3,4,5,6,7,7] 28 | 29 | cqt: 30 | window: "kaiser" 31 | beta: 1 32 | num_octs: 7 33 | bins_per_oct: 64 #this needs to be lower than 64, otherwise the time attention is inpractical 34 | 35 | 36 | #inner_Ns: [64, 64, 64, 64] 37 | #if 4x2, then down factor of 16! 
38 | 39 | bottleneck_type: "res_dil_convs" 40 | 41 | num_bottleneck_layers: 1 42 | #transformer: 43 | # num_heads: 8 44 | # dim_head: 64 45 | # num_layers: 16 46 | # channels: 512 47 | # attn_dropout: 0.1 48 | # multiplier_ff: 4 49 | # activation: "gelu" #fixed 50 | 51 | 52 | 53 | #for now, only the last two layers have attention + bottleneck 54 | 55 | 56 | attention_dict: 57 | num_heads: 8 58 | attn_dropout: 0.0 59 | bias_qkv: False 60 | N: 0 61 | rel_pos_num_buckets: 32 62 | rel_pos_max_distance: 64 63 | use_rel_pos: False 64 | Nproj: 8 65 | #the number of channels is the same as the Ns of the corresponding layer 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /conf/network/paper_1912_unet_cqt_oct_noattention_adaln.yaml: -------------------------------------------------------------------------------- 1 | #network from the paper: Solving audio inverse problems with a diffusion model 2 | name: "unet_cqt_oct_with_attention" 3 | #callable: "networks.unet_cqt_oct_with_projattention.Unet_CQT_oct_with_attention" 4 | callable: "networks.unet_cqt_oct_with_projattention_adaLN_2.Unet_CQT_oct_with_attention" 5 | 6 | 7 | use_fencoding: False 8 | use_norm: True 9 | 10 | filter_out_cqt_DC_Nyq: True 11 | 12 | depth: 7 #total depth of the network (including the first stage) 13 | 14 | emb_dim: 256 #dimensionality of the RFF embeddings 15 | 16 | 17 | #dimensions of the first stage (the length of this vector should be equal to num_octs) 18 | Ns: [64, 96 ,96, 128, 128,256, 256] #it is hardcoded 19 | #Ns: [8, 8 ,8, 8, 16,16, 16] #it is hardcoded 20 | 21 | #attention_layers: [0, 0, 0, 0, 1, 1, 1, 1] #num_octs+bottleneck 22 | attention_layers: [0, 0, 0, 0, 0, 0, 0, 0] #num_octs+bottleneck 23 | #attention_Ns: [0, 0, 0, 0,256 ,512,1024 ,1024] 24 | 25 | #Ns: [8,8,16,16,32,32,64] 26 | Ss: [2,2,2, 2, 2, 2, 2] #downsample factors at the first stage, now it is ignored 27 | 28 | num_dils: [2,3,4,5,6,7,7] 29 | 30 | cqt: 31 | window: "kaiser" 32 | beta: 1 33 | num_octs: 7 34 | bins_per_oct: 64 #this needs to be lower than 64, otherwise the time attention is inpractical 35 | 36 | 37 | #inner_Ns: [64, 64, 64, 64] 38 | #if 4x2, then down factor of 16! 
39 | 
40 | bottleneck_type: "res_dil_convs"
41 | 
42 | num_bottleneck_layers: 1
43 | #transformer:
44 | #    num_heads: 8
45 | #    dim_head: 64
46 | #    num_layers: 16
47 | #    channels: 512
48 | #    attn_dropout: 0.1
49 | #    multiplier_ff: 4
50 | #    activation: "gelu" #fixed
51 | 
52 | 
53 | 
54 | #for now, only the last two layers have attention + bottleneck
55 | 
56 | 
57 | attention_dict:
58 |   num_heads: 8
59 |   attn_dropout: 0.0
60 |   bias_qkv: False
61 |   N: 0
62 |   rel_pos_num_buckets: 32
63 |   rel_pos_max_distance: 64
64 |   use_rel_pos: True
65 |   Nproj: 8
66 | #the number of channels is the same as the Ns of the corresponding layer
67 | 
68 | 
69 | 
70 | 
71 | 
-------------------------------------------------------------------------------- /conf/network/unet_cqtdiff_original.yaml: --------------------------------------------------------------------------------
1 | #network from the paper: Solving audio inverse problems with a diffusion model
2 | name: "unet_cqtdiff_original"
3 | callable: "networks.CQTdiff_original.unet_cqt_fast.Unet_CQT"
4 | 
5 | 
6 | use_fencoding: True
7 | use_norm: False
8 | 
9 | filter_out_cqt_DC_Nyq: True
10 | 
11 | depth: 6 #it is hardcoded
12 | Ns: [32, 64, 64, 128, 128, 128, 128, 128] #it is hardcoded
13 | Ss: [2,2,2,2,2,2] #it is hardcoded
14 | 
15 | cqt:
16 |   num_octs: 7
17 |   bins_per_oct: 64
18 | 
-------------------------------------------------------------------------------- /conf/tester/edm_2ndorder_stochastic.yaml: --------------------------------------------------------------------------------
1 | do_test: True #boolean flag to run inference, False means no testing at all
2 | 
3 | name: "edm_2ndorder_stochastic" #same as the file name, try to do that for all testers
4 | 
5 | type: "classic"
6 | 
7 | callable: 'testing.tester.Tester'
8 | sampler_callable: 'testing.edm_sampler.Sampler'
9 | 
10 | modes: ['unconditional', 'bwe', 'inpainting'] #modes to test
11 | T: 35 #number of discretization steps
12 | order: 2 #order of the discretization. TODO: implement higher-order samplers such as the one used in eDiff-I
13 | 
14 | filter_out_cqt_DC_Nyq: False
15 | 
16 | checkpoint: None
17 | 
18 | unconditional:
19 |   num_samples: 4
20 |   audio_len: 65536
21 | 
22 | posterior_sampling:
23 |   xi: 0.25 #restoration guidance, 0 means no guidance
24 |   data_consistency: False
25 | 
26 | #new diffusion parameters (only for sampling):
27 | diff_params:
28 |   same_as_training: False
29 |   sigma_data: 0.063 #default for maestro
30 |   sigma_min: 1e-4
31 |   sigma_max: 1
32 |   P_mean: -1.2 #mean of the log-normal noise-level distribution (EDM); not used at sampling time
33 |   P_std: 1.2 #std of the log-normal noise-level distribution (EDM); not used at sampling time
34 |   ro: 13
35 |   ro_train: 13
36 |   Schurn: 10
37 |   Snoise: 1
38 |   Stmin: 0
39 |   Stmax: 50
40 | 
41 | 
42 | autoregressive:
43 |   overlap: 0.25
44 |   num_samples: 4
45 | 
46 | sampler: "stochastic" #whether deterministic or stochastic; unused, as Schurn defines the stochasticity
47 | 
48 | noise_in_observations_SNR: None
49 | bandwidth_extension:
50 |   decimate:
51 |     factor: 1
52 |   filter:
53 |     type: "firwin" #or "cheby1_fir"
54 |     fc: 1000 #cutoff frequency of the applied lpf
55 |     order: 500
56 |     fir_order: 500
57 |     beta: 1
58 |     ripple: 0.05 #for the cheby1
59 |   resample:
60 |     fs: 2000
61 |   biquad:
62 |     Q: 0.707
63 | inpainting:
64 |   gap_length: 1000 #in ms
65 |   start_gap_idx: None #in ms, None means at the middle
66 | comp_sens:
67 |   percentage: 5 #%
68 | phase_retrieval:
69 |   win_size: 1024
70 |   hop_size: 256
71 | max_thresh_grads: 1
72 | type_spec: "linear" #or "mel" for phase retrieval
73 | declipping:
74 |   SDR: 3 #in dB
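For reference, `order: 2` together with `Schurn`/`Snoise`/`Stmin`/`Stmax` matches the stochastic second-order (Heun) sampler of Karras et al. A rough sketch of one such step — `denoiser` and the function names are illustrative, not the signatures of `testing.edm_sampler.Sampler`:

```python
import torch

def heun_step(x, sigma, sigma_next, denoiser, Schurn=10.0, Snoise=1.0,
              Stmin=0.0, Stmax=50.0, T=35):
    # Optionally raise the noise level ("churn"), only inside [Stmin, Stmax].
    gamma = min(Schurn / T, 2**0.5 - 1) if Stmin <= sigma <= Stmax else 0.0
    sigma_hat = sigma * (1 + gamma)
    x_hat = x + (sigma_hat**2 - sigma**2)**0.5 * Snoise * torch.randn_like(x)

    # First-order (Euler) step along the probability-flow ODE direction.
    d = (x_hat - denoiser(x_hat, sigma_hat)) / sigma_hat
    x_next = x_hat + (sigma_next - sigma_hat) * d

    # Second-order (Heun) correction, skipped at the last step (sigma_next == 0).
    if sigma_next > 0:
        d_next = (x_next - denoiser(x_next, sigma_next)) / sigma_next
        x_next = x_hat + (sigma_next - sigma_hat) * 0.5 * (d + d_next)
    return x_next
```

With `Schurn: 0` this reduces to the deterministic Heun sampler; the `Schurn: 10` used here injects fresh noise at each step, which is the stochasticity the `sampler:` comment below refers to.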
-------------------------------------------------------------------------------- /conf/tester/inpainting_tester.yaml: --------------------------------------------------------------------------------
1 | do_test: True #boolean flag to run inference, False means no testing at all
2 | 
3 | name: "inpainting_tester" #same as the file name, try to do that for all testers
4 | 
5 | callable: 'testing.tester_inpainting.Tester'
6 | sampler_callable: 'testing.edm_sampler_inpainting.Sampler'
7 | 
8 | #modes: ['unconditional', 'inpainting'] #modes to test
9 | 
10 | modes: ['inpainting'] #basic time-domain inpainting, using the mask parameters specified in the inpainting section
11 | 
12 | #modes: ['inpainting_fordamushra'] #mode to prepare the data for the long-gap MUSHRA test (fixed hard-coded gap lengths). Also, hard-coded paths
13 | 
14 | #modes: ['inpainting_shortgaps'] #mode to prepare the data for the short-gap MUSHRA test. Requires a dedicated data_loader ("inpainting_musicnet.yaml"). Loads the mask from the data_loader, from some .mat files
15 | 
16 | #modes: ['spectrogram_inpainting'] #experimenting with spectrogram inpainting. The mask parameters are specified in the spectrogram_inpainting section
17 | #
18 | #modes: ['STN_inpainting'] #experimenting with STN inpainting. The mask parameters are specified in the STN_inpainting section
19 | 
20 | T: 35 #number of discretization steps
21 | order: 2 #order of the discretization. Only 1 or 2 for now
22 | 
23 | filter_out_cqt_DC_Nyq: True
24 | 
25 | checkpoint: "experiments/54/22k_8s-790000.pt"
26 | 
27 | unconditional:
28 |   num_samples: 4
29 |   audio_len: 184184
30 | 
31 | posterior_sampling:
32 |   xi: 0.25 #restoration guidance, 0 means no guidance
33 |   norm: 2 #1 or 2 or "smoothl1"
34 |   smoothl1_beta: 1
35 | 
36 |   data_consistency:
37 |     use: True
38 |     type: "always" #or "end" or "end_smooth"
39 |     smooth: True #apply a smoother mask for data consistency steps
40 |     hann_size: 50 #in samples
41 | 
42 | 
43 | 
44 | #new diffusion parameters (only for sampling):
45 | diff_params:
46 |   same_as_training: False
47 |   sigma_data: 0.063 #default for maestro
48 |   sigma_min: 1e-4
49 |   sigma_max: 1
50 |   P_mean: -1.2 #mean of the log-normal noise-level distribution (EDM); not used at sampling time
51 |   P_std: 1.2 #std of the log-normal noise-level distribution (EDM); not used at sampling time
52 |   ro: 13
53 |   ro_train: 13
54 |   Schurn: 10
55 |   Snoise: 1.000
56 |   Stmin: 0
57 |   Stmax: 50
58 | 
59 | 
60 | autoregressive:
61 |   overlap: 0.25
62 |   num_samples: 4
63 | 
64 | sampler: "stochastic" #whether deterministic or stochastic; unused, as Schurn defines the stochasticity
65 | 
66 | noise_in_observations_SNR: None
67 | 
68 | inpainting:
69 |   mask_mode: "long" #or "short"
70 |   long:
71 |     gap_length: 1500 #in ms
72 |     start_gap_idx: None #in ms, None means at the middle
73 |   short:
74 |     num_gaps: 4
75 |     gap_length: 25 #in ms
76 |     start_gap_idx: None #in ms, None means random. If not None this should be a list of length num_gaps
77 | 
78 | spectrogram_inpainting: #specifies a (rectangular for now) mask localized in time and frequency
79 |   stft:
80 |     window: "hann"
81 |     n_fft: 1024
82 |     hop_length: 256
83 |     win_length: 1024
84 |   time_mask_length: 2000 #in ms
85 |   time_start_idx: None #in ms, None means at the middle
86 |   min_masked_freq: 300 #in Hz (lowest frequency to mask)
87 |   max_masked_freq: 2000 #in Hz (max frequency to mask)
88 | 
89 | STN_inpainting: #TODO implement STN inpainting
90 |   STN_params:
91 |     nwin1: 4096
92 |     G1: 0.65
93 |     G2: 0.7
94 |   type: "T" #or "S" or "N"
95 | 
96 | 
97 | 
98 | 
99 | comp_sens:
100 |   percentage: 5 #%
101 | 
102 | 
103 | max_thresh_grads: 1
104 | type_spec: "linear" #or "mel" for phase retrieval
105 | declipping:
106 |   SDR: 3 #in dB
107 | 
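The `inpainting.long` block above specifies the gap in milliseconds, and data consistency re-imposes the observed samples outside the gap after each denoising step. A sketch of how such a mask and consistency step can be formed — helper names are hypothetical, not the repo's `tester_inpainting` code:

```python
import torch

def make_long_gap_mask(audio_len, fs, gap_length_ms=1500, start_gap_idx_ms=None):
    """1 = observed sample, 0 = sample inside the gap to be inpainted."""
    mask = torch.ones(audio_len)
    gap = int(fs * gap_length_ms / 1000)
    if start_gap_idx_ms is None:
        start = (audio_len - gap) // 2  # None means: gap at the middle
    else:
        start = int(fs * start_gap_idx_ms / 1000)
    mask[start:start + gap] = 0
    return mask

def data_consistency(x_denoised, y_observed, mask):
    # Keep the observed context untouched; keep the generated content only in the gap.
    return mask * y_observed + (1 - mask) * x_denoised
```

With `type: "always"` this projection would be applied at every sampling step; `"end"` would apply it only once, on the final estimate.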
-------------------------------------------------------------------------------- /conf/tester/inpainting_tester_shortgaps.yaml: --------------------------------------------------------------------------------
1 | do_test: True #boolean flag to run inference, False means no testing at all
2 | 
3 | name: "inpainting_tester" #same as the file name, try to do that for all testers
4 | 
5 | callable: 'testing.tester_inpainting.Tester'
6 | sampler_callable: 'testing.edm_sampler_inpainting.Sampler'
7 | 
8 | #modes: ['unconditional', 'inpainting'] #modes to test
9 | 
10 | #modes: ['inpainting'] #basic time-domain inpainting, using the mask parameters specified in the inpainting section
11 | 
12 | #modes: ['inpainting_fordamushra'] #mode to prepare the data for the long-gap MUSHRA test (fixed hard-coded gap lengths). Also, hard-coded paths
13 | 
14 | modes: ['inpainting_shortgaps'] #mode to prepare the data for the short-gap MUSHRA test. Requires a dedicated data_loader ("inpainting_musicnet.yaml"). Loads the mask from the data_loader, from some .mat files
15 | 
16 | #modes: ['spectrogram_inpainting'] #experimenting with spectrogram inpainting. The mask parameters are specified in the spectrogram_inpainting section
17 | #
18 | #modes: ['STN_inpainting'] #experimenting with STN inpainting. The mask parameters are specified in the STN_inpainting section
19 | 
20 | T: 70 #number of discretization steps
21 | order: 2 #order of the discretization. TODO: implement higher-order samplers such as the one used in eDiff-I
22 | 
23 | filter_out_cqt_DC_Nyq: True
24 | 
25 | checkpoint: "experiments/54/22k_8s-790000.pt"
26 | 
27 | unconditional:
28 |   num_samples: 4
29 |   audio_len: 184184
30 | 
31 | posterior_sampling:
32 |   xi: 0.25 #restoration guidance, 0 means no guidance
33 |   norm: 2 #1 or 2 or "smoothl1"
34 |   smoothl1_beta: 1
35 | 
36 |   data_consistency:
37 |     use: True
38 |     type: "always" #or "end" or "end_smooth"
39 |     smooth: True #apply a smoother mask for data consistency steps
40 |     hann_size: 100 #in samples
41 | 
42 | 
43 | 
44 | #new diffusion parameters (only for sampling):
45 | diff_params:
46 |   same_as_training: False
47 |   sigma_data: 0.063 #default for maestro
48 |   sigma_min: 1e-4
49 |   sigma_max: 1
50 |   P_mean: -1.2 #mean of the log-normal noise-level distribution (EDM); not used at sampling time
51 |   P_std: 1.2 #std of the log-normal noise-level distribution (EDM); not used at sampling time
52 |   ro: 13
53 |   ro_train: 13
54 |   Schurn: 10
55 |   Snoise: 1.000
56 |   Stmin: 0
57 |   Stmax: 50
58 | 
59 | 
60 | autoregressive:
61 |   overlap: 0.25
62 |   num_samples: 4
63 | 
64 | sampler: "stochastic" #whether deterministic or stochastic; unused, as Schurn defines the stochasticity
65 | 
66 | noise_in_observations_SNR: None
67 | 
68 | inpainting:
69 |   mask_mode: "long" #or "short"
70 |   long:
71 |     gap_length: 1500 #in ms
72 |     start_gap_idx: None #in ms, None means at the middle
73 |   short:
74 |     num_gaps: 4
75 |     gap_length: 25 #in ms
76 |     start_gap_idx: None #in ms, None means random. If not None this should be a list of length num_gaps
77 | 
78 | spectrogram_inpainting: #specifies a (rectangular for now) mask localized in time and frequency
79 |   stft:
80 |     window: "hann"
81 |     n_fft: 1024
82 |     hop_length: 256
83 |     win_length: 1024
84 |   time_mask_length: 2000 #in ms
85 |   time_start_idx: None #in ms, None means at the middle
86 |   min_masked_freq: 300 #in Hz (lowest frequency to mask)
87 |   max_masked_freq: 2000 #in Hz (max frequency to mask)
88 | 
89 | STN_inpainting: #TODO implement STN inpainting
90 |   STN_params:
91 |     nwin1: 4096
92 |     G1: 0.65
93 |     G2: 0.7
94 |   type: "TN" #or "S" or "T" or "N"
95 | 
96 | 
97 | 
98 | 
99 | comp_sens:
100 |   percentage: 5 #%
101 | 
102 | 
103 | max_thresh_grads: 1
104 | type_spec: "linear" #or "mel" for phase retrieval
105 | declipping:
106 |   SDR: 3 #in dB
107 | 
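`smooth: True` with `hann_size` suggests the 0/1 inpainting mask is tapered before the data-consistency mix, avoiding clicks at the gap boundaries. One possible construction — only a sketch, with assumed semantics for `hann_size`:

```python
import torch
import torch.nn.functional as F

def smooth_mask(mask, hann_size=100):
    """Soften the hard 0/1 edges of an inpainting mask by convolving it
    with a normalized Hann window.

    Flat regions stay at 0 or 1; transitions become smooth ramps of roughly
    hann_size samples. (Samples near the very ends of the array are slightly
    attenuated by the zero padding, which is harmless for gaps away from the edges.)
    """
    win = torch.hann_window(hann_size, periodic=False)
    win = win / win.sum()  # normalize so constant regions are preserved
    m = mask.float().view(1, 1, -1)
    out = F.conv1d(m, win.view(1, 1, -1), padding=hann_size // 2)
    return out.view(-1)[:mask.numel()]
```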
-------------------------------------------------------------------------------- /datasets/audiofolder.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Streaming random audio segments from a folder of .wav files."""
9 | 
10 | import os
11 | import numpy as np
12 | import zipfile
13 | #import PIL.Image
14 | import json
15 | import torch
16 | import utils.dnnlib as dnnlib
17 | import random
18 | import pandas as pd
19 | import glob
20 | import soundfile as sf
21 | 
22 | #try:
23 | #    import pyspng
24 | #except ImportError:
25 | #    pyspng = None
26 | 
27 | #----------------------------------------------------------------------------
28 | # Dataset subclass that streams random fixed-length segments from the audio
29 | # files found in the specified directory.
30 | class AudioFolderDataset(torch.utils.data.IterableDataset):
31 |     def __init__(self,
32 |                  dset_args,
33 |                  fs=44100,
34 |                  seg_len=131072,
35 |                  overfit=False,
36 |                  seed=42):
37 |         self.overfit=overfit
38 | 
39 |         super().__init__()
40 |         random.seed(seed)
41 |         np.random.seed(seed)
42 |         path=dset_args.path
43 | 
44 |         filelist=glob.glob(os.path.join(path,"*.wav"))
45 |         assert len(filelist)>0, "error in dataloading: empty or nonexistent folder"
46 | 
47 |         self.train_samples=filelist
48 | 
49 |         self.seg_len=int(seg_len)
50 |         self.fs=fs
51 |         if self.overfit:
52 |             file=self.train_samples[0]
53 |             data, samplerate = sf.read(file)
54 |             if len(data.shape)>1:
55 |                 data=np.mean(data,axis=1)
56 |             self.overfit_sample=data[10*samplerate:60*samplerate] #use only 50 s
57 | 
58 |     def __iter__(self):
59 |         if self.overfit:
60 |             data_clean=self.overfit_sample
61 |         while True:
62 |             if not self.overfit:
63 |                 num=random.randint(0,len(self.train_samples)-1)
64 |                 #for file in self.train_samples:
65 |                 file=self.train_samples[num]
66 |                 data, samplerate = sf.read(file)
67 |                 assert samplerate==self.fs, "wrong sampling rate"
68 |                 data_clean=data
69 |                 #Stereo to mono
70 |                 if len(data.shape)>1:
71 |                     data_clean=np.mean(data_clean,axis=1)
72 | 
73 |             #normalize
74 |             #no normalization!!
75 |             #data_clean=data_clean/np.max(np.abs(data_clean))
76 | 
77 |             #framify data clean files
78 |             num_frames=np.floor(len(data_clean)/self.seg_len)
79 | 
80 |             #if num_frames>4:
81 |             for i in range(8):
82 |                 #get 8 random segments per file to be a bit faster
83 |                 if not self.overfit:
84 |                     idx=np.random.randint(0,len(data_clean)-self.seg_len)
85 |                 else:
86 |                     idx=0
87 |                 segment=data_clean[idx:idx+self.seg_len]
88 |                 segment=segment.astype('float32')
89 |                 #b=np.mean(np.abs(segment))
90 |                 #segment= (10/(b*np.sqrt(2)))*segment #default rms of 0.1. Is this scaling correct??
91 | 
92 |                 #let's make this a bit robust to input scale
93 |                 #scale=np.random.uniform(1.75,2.25)
94 |                 #this way I estimate sigma_data (after pre_emph) to be around 1
95 | 
96 |                 #segment=10.0**(scale) *segment
97 |                 yield segment
98 |             #else:
99 |             #    pass
100 | 
101 | 
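A minimal sketch of how this iterable dataset is presumably consumed — the `dset_args` namespace and path here are improvised; the real entry point builds them from the Hydra config:

```python
from types import SimpleNamespace
from torch.utils.data import DataLoader

from datasets.audiofolder import AudioFolderDataset

dset_args = SimpleNamespace(path="/data/musicnet/train_data")  # hypothetical path
dataset = AudioFolderDataset(dset_args, fs=44100, seg_len=131072)
loader = DataLoader(dataset, batch_size=4, num_workers=0)

batch = next(iter(loader))  # float32 tensor of shape (4, 131072)
```

One caveat with this class: it seeds `random` and `np.random` identically in `__init__`, so running it with several DataLoader workers can make the workers yield correlated segments.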
-------------------------------------------------------------------------------- /datasets/audiofolder_test.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Loading a few fixed test segments from a folder of .wav files."""
9 | 
10 | import os
11 | import numpy as np
12 | import zipfile
13 | #import PIL.Image
14 | import json
15 | import torch
16 | import utils.dnnlib as dnnlib
17 | import random
18 | import pandas as pd
19 | import glob
20 | import soundfile as sf
21 | 
22 | #try:
23 | #    import pyspng
24 | #except ImportError:
25 | #    pyspng = None
26 | 
27 | #----------------------------------------------------------------------------
28 | # Map-style dataset that loads a fixed number of test segments from the audio
29 | # files found in the specified directory.
30 | class AudioFolderDatasetTest(torch.utils.data.Dataset):
31 |     def __init__(self,
32 |                  dset_args,
33 |                  fs=44100,
34 |                  seg_len=131072,
35 |                  num_samples=4,
36 |                  seed=42):
37 | 
38 |         super().__init__()
39 |         random.seed(seed)
40 |         np.random.seed(seed)
41 |         path=dset_args.test.path
42 | 
43 |         filelist=glob.glob(os.path.join(path,"*.wav"))
44 |         assert len(filelist)>0, "error in dataloading: empty or nonexistent folder"
45 |         self.train_samples=filelist
46 |         self.seg_len=int(seg_len)
47 |         self.fs=fs
48 | 
49 |         self.test_samples=[]
50 |         self.filenames=[]
51 |         self._fs=[]
52 |         for i in range(num_samples):
53 |             file=self.train_samples[i]
54 |             self.filenames.append(os.path.basename(file))
55 |             data, samplerate = sf.read(file)
56 |             self._fs.append(samplerate)
57 |             if len(data.shape)>1:
58 |                 data=np.mean(data,axis=1)
59 |             self.test_samples.append(data[2*samplerate:2*samplerate+self.seg_len]) #take seg_len samples starting at 2 s
60 | 
61 | 
62 |     def __getitem__(self, idx):
63 |         #return self.test_samples[idx]
64 |         return self.test_samples[idx], self._fs[idx], self.filenames[idx]
65 | 
66 |     def __len__(self):
67 |         return len(self.test_samples)
68 | 
-------------------------------------------------------------------------------- /datasets/librispeech.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Streaming audio segments from the LibriSpeech dataset."""
9 | 
10 | import os
11 | import numpy as np
12 | import zipfile
13 | #import PIL.Image
14 | import json
15 | import torch
16 | import utils.dnnlib as dnnlib
17 | import random
18 | import pandas as pd
19 | import glob
20 | import soundfile as sf
21 | 
22 | #try:
23 | #    import pyspng
24 | #except ImportError:
25 | #    pyspng = None
26 | 
27 | #----------------------------------------------------------------------------
28 | # Dataset subclasses that stream fixed-length speech segments from the .flac
29 | # files found recursively in the specified directories.
30 | class LibrispeechTrain(torch.utils.data.IterableDataset):
31 |     def __init__(self,
32 |                  dset_args,
33 |                  fs=44100,
34 |                  seg_len=131072,
35 |                  overfit=False,
36 |                  seed=42):
37 |         self.overfit=overfit
38 | 
39 |         super().__init__()
40 |         random.seed(seed)
41 |         np.random.seed(seed)
42 |         path=dset_args.path
43 |         filelist=[]
44 |         for d in dset_args.train_dirs:
45 |             filelist.extend(glob.glob(os.path.join(path,d,"*/*/*.flac")))
46 | 
47 |         assert len(filelist)>0, "error in dataloading: empty or nonexistent folder"
48 | 
49 |         self.train_samples=filelist
50 | 
51 |         self.seg_len=int(seg_len)
52 |         self.fs=fs
53 |         if self.overfit:
54 |             file=self.train_samples[0]
55 |             data, samplerate = sf.read(file)
56 |             if len(data.shape)>1:
57 |                 data=np.mean(data,axis=1)
58 |             self.overfit_sample=data[10*samplerate:60*samplerate] #use only 50 s
59 | 
60 |     def __iter__(self):
61 |         if self.overfit:
62 |             data_clean=self.overfit_sample
63 |         while True:
64 |             if not self.overfit:
65 |                 num=random.randint(0,len(self.train_samples)-1)
66 |                 #for file in self.train_samples:
67 |                 file=self.train_samples[num]
68 |                 data, samplerate = sf.read(file)
69 |                 assert samplerate==self.fs, "wrong sampling rate"
70 |                 segment=data
71 |                 #Stereo to mono
72 |                 if len(data.shape)>1:
73 |                     segment=np.mean(segment,axis=1)
74 | 
75 |                 #normalize
76 |                 #no normalization!!
77 |                 #data_clean=data_clean/np.max(np.abs(data_clean))
78 |                 L=len(segment)
79 |                 #print(L, self.seg_len)
80 |                 if L>self.seg_len:
81 |                     #get random segment
82 |                     idx=np.random.randint(0,L-self.seg_len)
83 |                     segment=segment[idx:idx+self.seg_len]
84 |                 elif L<=self.seg_len:
85 |                     #extend to the right length by wrapping the segment, at a random offset
86 |                     idx=np.random.randint(0,self.seg_len-L+1) #+1 so that L==seg_len does not crash randint
87 |                     #segment=np.pad(segment,(idx,self.seg_len-L-idx),'constant')
88 |                     #copy segment to get to the right length
89 |                     segment=np.pad(segment,(idx,self.seg_len-L-idx),'wrap')
90 | 
91 |                 #print the std of the segment
92 |                 #print(np.std(segment))
93 | 
94 | 
95 |                 yield segment
96 |             #else:
97 |             #    pass
98 | 
99 | 
100 | 
101 | class LibrispeechTest(torch.utils.data.Dataset):
102 |     def __init__(self,
103 |                  dset_args,
104 |                  fs=44100,
105 |                  seg_len=131072,
106 |                  num_samples=4,
107 |                  seed=42):
108 | 
109 |         super().__init__()
110 |         random.seed(seed)
111 |         np.random.seed(seed)
112 |         path=dset_args.test.path
113 | 
114 | 
115 |         filelist=glob.glob(os.path.join(path,"*/*/*.flac"))
116 | 
117 |         assert len(filelist)>0, "error in dataloading: empty or nonexistent folder"
118 | 
119 |         self.train_samples=filelist
120 |         self.seg_len=int(seg_len)
121 |         self.fs=fs
122 | 
123 |         self.test_samples=[]
124 |         self.filenames=[]
125 |         self._fs=[]
126 |         for i in range(num_samples):
127 |             file=self.train_samples[i]
128 |             self.filenames.append(os.path.basename(file))
129 |             segment, samplerate = sf.read(file)
130 |             print(self.fs, samplerate)
131 |             assert samplerate==self.fs, "wrong sampling rate"
132 |             if len(segment.shape)>1:
133 |                 segment=np.mean(segment,axis=1)
134 |             L=len(segment)
135 |             if L>self.seg_len:
136 |                 #take the first seg_len samples
137 |                 idx=0
138 |                 segment=segment[idx:idx+self.seg_len]
139 |             elif L<=self.seg_len:
140 |                 #extend to the right length by wrapping the segment
141 |                 idx=0
142 |                 #segment=np.pad(segment,(idx,self.seg_len-L-idx),'constant')
143 |                 segment=np.pad(segment,(idx,self.seg_len-L-idx),'wrap')
144 |             self._fs.append(samplerate)
145 | 
146 |             self.test_samples.append(segment)
147 | 
148 |     def __getitem__(self, idx):
149 |         #return self.test_samples[idx]
150 |         return self.test_samples[idx], self._fs[idx], self.filenames[idx]
151 | 
152 |     def __len__(self):
153 |         return len(self.test_samples)
154 | 
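Short LibriSpeech utterances are extended to `seg_len` with `np.pad(..., 'wrap')`, i.e. by tiling the utterance rather than zero-padding it. For example:

```python
import numpy as np

x = np.array([1., 2., 3.])
np.pad(x, (2, 1), 'wrap')      # -> [2., 3., 1., 2., 3., 1.]
np.pad(x, (2, 1), 'constant')  # -> [0., 0., 1., 2., 3., 0.]
```

Wrapping keeps the padded region speech-like instead of silent, which matters when the model is trained on full-length segments only.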
-------------------------------------------------------------------------------- /datasets/maestro_dataset.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Streaming audio segments from the MAESTRO dataset."""
9 | 
10 | import os
11 | import numpy as np
12 | import zipfile
13 | #import PIL.Image
14 | import json
15 | import torch
16 | import utils.dnnlib as dnnlib
17 | import random
18 | import pandas as pd
19 | import glob
20 | import soundfile as sf
21 | 
22 | #try:
23 | #    import pyspng
24 | #except ImportError:
25 | #    pyspng = None
26 | 
27 | #----------------------------------------------------------------------------
28 | # Dataset subclasses that stream random fixed-length segments from the audio
29 | # files listed in the MAESTRO metadata.
30 | 
31 | class MaestroDataset_fs(torch.utils.data.IterableDataset):
32 |     def __init__(self,
33 |                  dset_args,
34 |                  overfit=False, #set to True for an overfitting dataset (lightweight tests to check that the dataloading is not bottlenecking)
35 |                  seed=42):
36 | 
37 |         super().__init__()
38 |         self.overfit=overfit
39 |         random.seed(seed)
40 |         np.random.seed(seed)
41 |         path=dset_args.path
42 |         years=dset_args.years
43 | 
44 |         metadata_file=os.path.join(path,"maestro-v3.0.0.csv")
45 |         metadata=pd.read_csv(metadata_file)
46 | 
47 |         metadata=metadata[metadata["year"].isin(years)]
48 |         metadata=metadata[metadata["split"]=="train"]
49 |         filelist=metadata["audio_filename"]
50 | 
51 |         filelist=filelist.map(lambda x: os.path.join(path,x), na_action='ignore')
52 | 
53 | 
54 |         self.train_samples=filelist.to_list()
55 | 
56 |         self.seg_len=int(dset_args.load_len)
57 | 
58 | 
59 |     def __iter__(self):
60 |         if self.overfit:
61 |             data_clean=self.overfit_sample
62 |         while True:
63 |             if not self.overfit:
64 |                 num=random.randint(0,len(self.train_samples)-1)
65 |                 #for file in self.train_samples:
66 |                 file=self.train_samples[num]
67 |                 data, samplerate = sf.read(file)
68 |                 #print(file,samplerate)
69 | 
70 |                 data_clean=data
71 |                 #Stereo to mono
72 |                 if len(data.shape)>1:
73 |                     data_clean=np.mean(data_clean,axis=1)
74 | 
75 |             #normalize
76 |             #no normalization!!
77 |             #data_clean=data_clean/np.max(np.abs(data_clean))
78 | 
79 |             #framify data clean files
80 | 
81 |             num_frames=np.floor(len(data_clean)/self.seg_len)
82 | 
83 |             if num_frames>4:
84 |                 for i in range(8):
85 |                     #get 8 random segments per file to be a bit faster
86 |                     if not self.overfit:
87 |                         idx=np.random.randint(0,len(data_clean)-self.seg_len)
88 |                     else:
89 |                         idx=0
90 |                     segment=data_clean[idx:idx+self.seg_len]
91 |                     segment=segment.astype('float32')
92 |                     #b=np.mean(np.abs(segment))
93 |                     #segment= (10/(b*np.sqrt(2)))*segment #default rms of 0.1. Is this scaling correct??
94 | 
95 |                     #let's make this a bit robust to input scale
96 |                     #scale=np.random.uniform(1.75,2.25)
97 |                     #this way I estimate sigma_data (after pre_emph) to be around 1
98 | 
99 |                     #segment=10.0**(scale) *segment
100 |                     yield segment, samplerate
101 |             else:
102 |                 pass
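Note that `MaestroDataset_fs` yields each segment together with its native sampling rate (44.1 or 48 kHz, depending on the year), which is why the 22.05 kHz experiment configs carry `resample_factor: 2`: the resampling to the working rate presumably happens downstream, in the trainer. A hedged sketch of that step with torchaudio (function name and placement are assumptions, not the repo's `training.trainer.Trainer` code):

```python
import torch
import torchaudio

def to_working_rate(segment: torch.Tensor, fs: int, target_fs: int = 22050) -> torch.Tensor:
    """Resample a 1-D audio segment from its native rate to the model's working rate."""
    if fs == target_fs:
        return segment
    return torchaudio.functional.resample(segment, orig_freq=fs, new_freq=target_fs)
```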
153 |         return len(self.test_samples)
154 | 
--------------------------------------------------------------------------------
/datasets/maestro_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Streaming audio datasets (adapted from NVIDIA code for streaming images and labels)."""
9 | 
10 | import os
11 | import numpy as np
12 | import zipfile
13 | #import PIL.Image
14 | import json
15 | import torch
16 | import utils.dnnlib as dnnlib
17 | import random
18 | import pandas as pd
19 | import glob
20 | import soundfile as sf
21 | 
22 | #try:
23 | #    import pyspng
24 | #except ImportError:
25 | #    pyspng = None
26 | 
27 | #----------------------------------------------------------------------------
28 | # Dataset subclasses that load audio recursively from the specified
29 | # directory.
30 | 
31 | class MaestroDataset_fs(torch.utils.data.IterableDataset):
32 |     def __init__(self,
33 |                  dset_args,
34 |                  overfit=False, #set to True for a lightweight overfitting dataset (to check that dataloading is not the bottleneck)
35 |                  seed=42 ):
36 | 
37 |         super().__init__()
38 |         self.overfit=overfit
39 |         random.seed(seed)
40 |         np.random.seed(seed)
41 |         path=dset_args.path
42 |         years=dset_args.years
43 | 
44 |         metadata_file=os.path.join(path,"maestro-v3.0.0.csv")
45 |         metadata=pd.read_csv(metadata_file)
46 | 
47 |         metadata=metadata[metadata["year"].isin(years)]
48 |         metadata=metadata[metadata["split"]=="train"]
49 |         filelist=metadata["audio_filename"]
50 | 
51 |         filelist=filelist.map(lambda x: os.path.join(path,x) , na_action='ignore')
52 | 
53 | 
54 |         self.train_samples=filelist.to_list()
55 | 
56 |         self.seg_len=int(dset_args.load_len)
57 | 
58 | 
59 |     def __iter__(self):
60 |         if self.overfit:
61 |             data_clean=self.overfit_sample
62 |         while True:
63 |             if not self.overfit:
64 |                 num=random.randint(0,len(self.train_samples)-1)
65 |                 #for file in self.train_samples:
66 |                 file=self.train_samples[num]
67 |                 data, samplerate = sf.read(file)
68 |                 #print(file,samplerate)
69 | 
70 |                 data_clean=data
71 |                 #Stereo to mono
72 |                 if len(data.shape)>1 :
73 |                     data_clean=np.mean(data_clean,axis=1)
74 | 
75 |             #normalize
76 |             #no normalization!!
77 |             #data_clean=data_clean/np.max(np.abs(data_clean))
78 | 
79 |             #framify data clean files
80 | 
81 |             num_frames=np.floor(len(data_clean)/self.seg_len)
82 | 
83 |             if num_frames>4:
84 |                 for i in range(8):
85 |                     #draw 8 random segments per file to be a bit faster
86 |                     if not self.overfit:
87 |                         idx=np.random.randint(0,len(data_clean)-self.seg_len)
88 |                     else:
89 |                         idx=0
90 |                     segment=data_clean[idx:idx+self.seg_len]
91 |                     segment=segment.astype('float32')
92 |                     #b=np.mean(np.abs(segment))
93 |                     #segment= (10/(b*np.sqrt(2)))*segment #default rms of 0.1. Is this scaling correct??
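                    # (Editor's aside, not in the original source.) The commented-out
                    # rescaling experiments here and below interact with
                    # diff_params.sigma_data, which diff_params/edm.py documents as the
                    # precalculated standard deviation of the dataset. A hedged sketch
                    # of estimating it offline from this loader (illustrative only):
                    #
                    #   import itertools
                    #   stds = [np.std(seg) for seg, _ in itertools.islice(iter(dset), 64)]
                    #   print("sigma_data estimate:", float(np.mean(stds)))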
94 | 
95 |                     #make this a bit more robust to the input scale
96 |                     #scale=np.random.uniform(1.75,2.25)
97 |                     #this way I estimate sigma_data (after pre_emph) to be around 1
98 | 
99 |                     #segment=10.0**(scale) *segment
100 |                     yield segment, samplerate
101 |             else:
102 |                 pass
103 | class MaestroDataset(torch.utils.data.IterableDataset):
104 |     def __init__(self,
105 |                  dset_args,
106 |                  fs=44100,
107 |                  seg_len=131072,
108 |                  overfit=False, #set to True for a lightweight overfitting dataset (to check that dataloading is not the bottleneck)
109 |                  seed=42 ):
110 | 
111 |         super().__init__()
112 |         self.overfit=overfit
113 |         random.seed(seed)
114 |         np.random.seed(seed)
115 |         path=dset_args.path
116 |         years=dset_args.years
117 | 
118 |         metadata_file=os.path.join(path,"maestro-v3.0.0.csv")
119 |         metadata=pd.read_csv(metadata_file)
120 | 
121 |         metadata=metadata[metadata["year"].isin(years)]
122 |         metadata=metadata[metadata["split"]=="train"]
123 |         filelist=metadata["audio_filename"]
124 | 
125 |         filelist=filelist.map(lambda x: os.path.join(path,x) , na_action='ignore')
126 | 
127 | 
128 |         self.train_samples=filelist.to_list()
129 | 
130 |         self.seg_len=int(seg_len)
131 |         self.fs=fs
132 |         if self.overfit:
133 |             file=self.train_samples[0]
134 |             data, samplerate = sf.read(file)
135 |             assert samplerate==self.fs, "wrong sampling rate"
136 |             if len(data.shape)>1 :
137 |                 data=np.mean(data,axis=1)
138 |             self.overfit_sample=data[10*samplerate:60*samplerate] #use only 50s
139 | 
140 |     def __iter__(self):
141 |         if self.overfit:
142 |             data_clean=self.overfit_sample
143 |         while True:
144 |             if not self.overfit:
145 |                 num=random.randint(0,len(self.train_samples)-1)
146 |                 #for file in self.train_samples:
147 |                 file=self.train_samples[num]
148 |                 data, samplerate = sf.read(file)
149 |                 assert samplerate==self.fs, "wrong sampling rate"
150 |                 data_clean=data
151 |                 #Stereo to mono
152 |                 if len(data.shape)>1 :
153 |                     data_clean=np.mean(data_clean,axis=1)
154 | 
155 |             #normalize
156 |             #no normalization!!
157 |             #data_clean=data_clean/np.max(np.abs(data_clean))
158 | 
159 |             #framify data clean files
160 |             num_frames=np.floor(len(data_clean)/self.seg_len)
161 | 
162 |             if num_frames>4:
163 |                 for i in range(8):
164 |                     #draw 8 random segments per file to be a bit faster
165 |                     if not self.overfit:
166 |                         idx=np.random.randint(0,len(data_clean)-self.seg_len)
167 |                     else:
168 |                         idx=0
169 |                     segment=data_clean[idx:idx+self.seg_len]
170 |                     segment=segment.astype('float32')
171 |                     #b=np.mean(np.abs(segment))
172 |                     #segment= (10/(b*np.sqrt(2)))*segment #default rms of 0.1. Is this scaling correct??
173 | 
174 |                     #make this a bit more robust to the input scale
175 |                     #scale=np.random.uniform(1.75,2.25)
176 |                     #this way I estimate sigma_data (after pre_emph) to be around 1
177 | 
178 |                     #segment=10.0**(scale) *segment
179 |                     yield segment
180 |             else:
181 |                 pass
182 | 
183 | 
--------------------------------------------------------------------------------
/datasets/maestro_dataset_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work.
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Streaming images and labels from datasets created with dataset_tool.py.""" 9 | 10 | import os 11 | import numpy as np 12 | import zipfile 13 | #import PIL.Image 14 | import json 15 | import torch 16 | import utils.dnnlib as dnnlib 17 | import random 18 | import pandas as pd 19 | import glob 20 | import soundfile as sf 21 | 22 | #try: 23 | # import pyspng 24 | #except ImportError: 25 | # pyspng = None 26 | 27 | #---------------------------------------------------------------------------- 28 | # Dataset subclass that loads images recursively from the specified directory 29 | # or ZIP file. 30 | class MaestroDatasetTestChunks(torch.utils.data.Dataset): 31 | def __init__(self, 32 | dset_args, 33 | num_samples=4, 34 | seed=42 ): 35 | 36 | super().__init__() 37 | random.seed(seed) 38 | np.random.seed(seed) 39 | path=dset_args.path 40 | years=dset_args.years 41 | 42 | self.seg_len=int(dset_args.load_len) 43 | 44 | metadata_file=os.path.join(path,"maestro-v3.0.0.csv") 45 | metadata=pd.read_csv(metadata_file) 46 | 47 | metadata=metadata[metadata["year"].isin(years)] 48 | metadata=metadata[metadata["split"]=="test"] 49 | filelist=metadata["audio_filename"] 50 | 51 | filelist=filelist.map(lambda x: os.path.join(path,x) , na_action='ignore') 52 | 53 | 54 | self.filelist=filelist.to_list() 55 | 56 | self.test_samples=[] 57 | self.filenames=[] 58 | self.f_s=[] 59 | for i in range(num_samples): 60 | file=self.filelist[i] 61 | self.filenames.append(os.path.basename(file)) 62 | data, samplerate = sf.read(file) 63 | if len(data.shape)>1 : 64 | data=np.mean(data,axis=1) 65 | 66 | self.test_samples.append(data[10*samplerate:10*samplerate+self.seg_len]) #use only 50s 67 | self.f_s.append(samplerate) 68 | 69 | 70 | def __getitem__(self, idx): 71 | return self.test_samples[idx], self.f_s[idx], self.filenames[idx] 72 | 73 | def __len__(self): 74 | return len(self.test_samples) 75 | 76 | 77 | -------------------------------------------------------------------------------- /diff_params/edm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import utils.training_utils as utils 5 | 6 | 7 | class EDM(): 8 | """ 9 | Definition of most of the diffusion parameterization, following ( Karras et al., "Elucidating...", 2022) 10 | """ 11 | 12 | def __init__(self, args): 13 | """ 14 | Args: 15 | args (dictionary): hydra arguments 16 | sigma_data (float): 17 | """ 18 | self.args=args 19 | self.sigma_min = args.diff_params.sigma_min 20 | self.sigma_max =args.diff_params.sigma_max 21 | self.P_mean=args.diff_params.P_mean 22 | self.P_std=args.diff_params.P_std 23 | self.ro=args.diff_params.ro 24 | self.ro_train=args.diff_params.ro_train 25 | self.sigma_data=args.diff_params.sigma_data #depends on the training data!! 
precalculated variance of the dataset
26 |         #parameters for stochastic sampling
27 |         self.Schurn=args.diff_params.Schurn
28 |         self.Stmin=args.diff_params.Stmin
29 |         self.Stmax=args.diff_params.Stmax
30 |         self.Snoise=args.diff_params.Snoise
31 | 
32 |         #perceptual filter
33 |         if self.args.diff_params.aweighting.use_aweighting:
34 |             self.AW=utils.FIRFilter(filter_type="aw", fs=args.exp.sample_rate, ntaps=self.args.diff_params.aweighting.ntaps)
35 | 
36 | 
37 | 
38 |     def get_gamma(self, t):
39 |         """
40 |         Get the parameter gamma that defines the stochasticity of the sampler
41 |         Args
42 |             t (Tensor): shape: (N_steps, ) Tensor of timesteps, from which we will compute gamma
43 |         """
44 |         N=t.shape[0]
45 |         gamma=torch.zeros(t.shape).to(t.device)
46 | 
47 |         #If desired, only apply stochasticity within the noise range [Stmin, Stmax]. Stmin is 0 and Stmax is huge by default, so unless these parameters are specified this does nothing.
48 |         indexes=torch.logical_and(t>self.Stmin , t<self.Stmax)
[... the remainder of diff_params/edm.py and the beginning of testing/edm_sampler.py are missing from this dump ...]
108 |             if self.xi>0:
109 |                 #apply rec. guidance
110 |                 score=self.get_score_rec_guidance(x, y, t_i, degradation)
111 | 
112 |                 #optionally apply replacement or consistency step
113 |                 if self.data_consistency:
114 |                     #convert score to denoised estimate using Tweedie's formula
115 |                     x_hat=score*t_i**2+x
116 | 
117 |                     if self.args.inference.mode=="phase_retrieval":
118 |                         x_hat=self.data_consistency_step_phase_retrieval(x_hat,y)
119 |                     else:
120 |                         x_hat=self.data_consistency_step(x_hat,y, degradation)
121 | 
122 |                     #convert back to score
123 |                     score=(x_hat-x)/t_i**2
124 | 
125 |             else:
126 |                 #denoise with the replacement method
127 |                 with torch.no_grad():
128 |                     x_hat=self.diff_params.denoiser(x, self.model, t_i.unsqueeze(-1))
129 | 
130 |                 x_hat=self.data_consistency_step(x_hat,y, degradation)
131 | 
132 |                 score=(x_hat-x)/t_i**2
133 | 
134 |         return score
135 | 
136 |     def predict_unconditional(
137 |         self,
138 |         shape, #shape of the audio to generate
139 |         device
140 |     ):
141 |         self.y=None
142 |         self.degradation=None
143 |         return self.predict(shape, device)
144 | 
145 |     def predict_resample(
146 |         self,
147 |         y, #observations (lowpassed signal) Tensor with shape ??
148 |         shape,
149 |         degradation, #lambda function
150 |     ):
151 |         self.degradation=degradation
152 |         self.y=y
153 |         print(shape)
154 |         return self.predict(shape, y.device)
155 | 
156 | 
157 |     def predict_conditional(
158 |         self,
159 |         y, #observations (lowpassed signal) Tensor with shape ??
160 |         degradation, #lambda function
161 |     ):
162 |         self.degradation=degradation
163 |         self.y=y
164 |         return self.predict(y.shape, y.device)
165 | 
166 |     def predict(
167 |         self,
168 |         shape, #shape of the generated audio
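        # (Editor's note, not in the original file.) The loop in this method
        # implements the stochastic Heun sampler of Karras et al. (2022). Per step i:
        #   t_hat = t_i + gamma_i * t_i
        #   x_hat = x + sqrt(t_hat^2 - t_i^2) * Snoise * eps,   eps ~ N(0, I)
        #   d = -t_hat * score(x_hat, t_hat);  h = t_{i+1} - t_hat;  x = x_hat + h*d
        # and, unless t_{i+1} == 0 (or order==1), the second-order correction
        #   x = x_hat + h * (d + d') / 2,  with d' evaluated at (x_hat + h*d, t_{i+1}).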
169 | device, #lambda function 170 | ): 171 | 172 | if self.rid: 173 | data_denoised=torch.zeros((self.nb_steps,shape[0], shape[1])) 174 | 175 | #get the noise schedule 176 | t = self.diff_params.create_schedule(self.nb_steps).to(device) 177 | #sample from gaussian distribution with sigma_max variance 178 | x = self.diff_params.sample_prior(shape,t[0]).to(device) 179 | 180 | #parameter for langevin stochasticity, if Schurn is 0, gamma will be 0 to, so the sampler will be deterministic 181 | gamma=self.diff_params.get_gamma(t).to(device) 182 | 183 | 184 | for i in tqdm(range(0, self.nb_steps, 1)): 185 | #print("sampling step ",i," from ",self.nb_steps) 186 | 187 | if gamma[i]==0: 188 | #deterministic sampling, do nothing 189 | t_hat=t[i] 190 | x_hat=x 191 | else: 192 | #stochastic sampling 193 | #move timestep 194 | t_hat=t[i]+gamma[i]*t[i] 195 | #sample noise, Snoise is 1 by default 196 | epsilon=torch.randn(shape).to(device)*self.diff_params.Snoise 197 | #add extra noise 198 | x_hat=x+((t_hat**2 - t[i]**2)**(1/2))*epsilon 199 | 200 | score=self.get_score(x_hat, self.y, t_hat, self.degradation) 201 | 202 | #d=-t_hat*((denoised-x_hat)/t_hat**2) 203 | d=-t_hat*score 204 | 205 | #apply second order correction 206 | h=t[i+1]-t_hat 207 | 208 | 209 | if t[i+1]!=0 and self.order==2: #always except last step 210 | #second order correction2 211 | #h=t[i+1]-t_hat 212 | t_prime=t[i+1] 213 | x_prime=x_hat+h*d 214 | score=self.get_score(x_prime, self.y, t_prime, self.degradation) 215 | 216 | d_prime=-t_prime*score 217 | 218 | x=(x_hat+h*((1/2)*d +(1/2)*d_prime)) 219 | 220 | elif t[i+1]==0 or self.order==1: #first condition is to avoid dividing by 0 221 | #first order Euler step 222 | x=x_hat+h*d 223 | 224 | if self.rid: data_denoised[i]=x 225 | 226 | if self.rid: 227 | return x.detach(), data_denoised.detach(), t.detach() 228 | else: 229 | return x.detach() 230 | 231 | def apply_mask(self, x): 232 | return self.mask*x 233 | 234 | def predict_inpainting( 235 | self, 236 | y_masked, 237 | mask 238 | ): 239 | self.mask=mask.to(y_masked.device) 240 | 241 | degradation=lambda x: self.apply_mask(x) 242 | 243 | return self.predict_conditional(y_masked, degradation) 244 | 245 | def apply_FIR_filter(self,y): 246 | y=y.unsqueeze(1) 247 | 248 | #apply the filter with a convolution (it is an FIR) 249 | y_lpf=torch.nn.functional.conv1d(y,self.filt,padding="same") 250 | y_lpf=y_lpf.squeeze(1) 251 | 252 | return y_lpf 253 | def apply_IIR_filter(self,y): 254 | y_lpf=torchaudio.functional.lfilter(y, self.a,self.b, clamp=False) 255 | return y_lpf 256 | def apply_biquad(self,y): 257 | y_lpf=torchaudio.functional.biquad(y, self.b0, self.b1, self.b2, self.a0, self.a1, self.a2) 258 | return y_lpf 259 | def decimate(self,x): 260 | return x[...,0:-1:self.factor] 261 | 262 | def resample(self,x): 263 | N=100 264 | return torchaudio.functional.resample(x,orig_freq=int(N*self.factor), new_freq=N) 265 | 266 | def predict_bwe( 267 | self, 268 | ylpf, #observations (lowpssed signal) Tensor with shape (L,) 269 | filt, #filter Tensor with shape ?? 
270 | filt_type 271 | ): 272 | 273 | #define the degradation model as a lambda 274 | if filt_type=="firwin": 275 | self.filt=filt.to(ylpf.device) 276 | degradation=lambda x: self.apply_FIR_filter(x) 277 | elif filt_type=="firwin_hpf": 278 | self.filt=filt.to(ylpf.device) 279 | degradation=lambda x: self.apply_FIR_filter(x) 280 | elif filt_type=="cheby1": 281 | b,a=filt 282 | self.a=torch.Tensor(a).to(ylpf.device) 283 | self.b=torch.Tensor(b).to(ylpf.device) 284 | degradation=lambda x: self.apply_IIR_filter(x) 285 | elif filt_type=="biquad": 286 | b0, b1, b2, a0, a1, a2=filt 287 | self.b0=torch.Tensor(b0).to(ylpf.device) 288 | self.b1=torch.Tensor(b1).to(ylpf.device) 289 | self.b2=torch.Tensor(b2).to(ylpf.device) 290 | self.a0=torch.Tensor(a0).to(ylpf.device) 291 | self.a1=torch.Tensor(a1).to(ylpf.device) 292 | self.a2=torch.Tensor(a2).to(ylpf.device) 293 | degradation=lambda x: self.apply_biquad(x) 294 | elif filt_type=="resample": 295 | self.factor =filt 296 | degradation= lambda x: self.resample(x) 297 | return self.predict_resample(ylpf,(ylpf.shape[0], self.args.exp.audio_len), degradation) 298 | elif filt_type=="decimate": 299 | self.factor =filt 300 | degradation= lambda x: self.decimate(x) 301 | return self.predict_resample(ylpf,(ylpf.shape[0], self.args.exp.audio_len), degradation) 302 | else: 303 | raise NotImplementedError 304 | 305 | return self.predict_conditional(ylpf, degradation) 306 | 307 | 308 | class SamplerPhaseRetrieval(Sampler): 309 | 310 | def __init__(self, model, diff_params, args, xi=0, order=2, data_consistency=False, rid=False): 311 | super().__init__(model, diff_params, args, xi, order, data_consistency, rid) 312 | #assert data_consistency==False 313 | assert xi>0 314 | 315 | def apply_stft(self,x): 316 | 317 | x2=torch.cat((x,self.zeropad ),-1) 318 | X=torch.stft(x2, self.win_size, hop_length=self.hop_size,window=self.window,center=False,return_complex=False) 319 | Y=torch.sqrt(X[...,0]**2 + X[...,1]**2) 320 | return Y 321 | 322 | def predict_pr( 323 | self, 324 | y 325 | ): 326 | 327 | self.win_size=self.args.inference.phase_retrieval.win_size 328 | self.hop_size=self.args.inference.phase_retrieval.hop_size 329 | 330 | print(y.shape) 331 | self.zeropad=torch.zeros(y.shape[0],self.win_size ).to(y.device) 332 | self.window=torch.hamming_window(window_length=self.win_size).to(y.device) 333 | 334 | degradation=lambda x: self.apply_stft(x) 335 | 336 | return self.predict_resample(y, (y.shape[0], self.args.exp.audio_len), degradation) 337 | class SamplerCompSens(Sampler): 338 | 339 | def __init__(self, model, diff_params, args, xi=0, order=2, data_consistency=False, rid=False): 340 | super().__init__(model, diff_params, args, xi, order, data_consistency, rid) 341 | assert data_consistency==False 342 | assert xi>0 343 | 344 | def apply_mask(self, x): 345 | return self.mask*x 346 | 347 | def predict_compsens( 348 | self, 349 | y_masked, 350 | mask 351 | ): 352 | 353 | self.mask=mask.to(y_masked.device) 354 | 355 | degradation=lambda x: self.apply_mask(x) 356 | 357 | return self.predict_conditional(y_masked, degradation) 358 | 359 | class SamplerDeclipping(Sampler): 360 | 361 | def __init__(self, model, diff_params, args, xi=0, order=2, data_consistency=False, rid=False): 362 | super().__init__(model, diff_params, args, xi, order, data_consistency, rid) 363 | assert data_consistency==False 364 | assert xi>0 365 | 366 | def apply_clip(self,x): 367 | x_hat=torch.clip(x,min=-self.clip_value, max=self.clip_value) 368 | return x_hat 369 | 370 | def predict_declipping( 371 | 
self, 372 | y_clipped, 373 | clip_value 374 | ): 375 | self.clip_value=clip_value 376 | 377 | degradation=lambda x: self.apply_clip(x) 378 | 379 | if self.rid: 380 | res, denoised, t=self.predict_conditional(y_clipped, degradation) 381 | return res, denoised, t 382 | else: 383 | res=self.predict_conditional(y_clipped, degradation) 384 | return res 385 | 386 | class SamplerAutoregressive(Sampler): 387 | 388 | def __init__(self, model, diff_params, args, xi=0, order=2, data_consistency=False, rid=False): 389 | super().__init__(model, diff_params, args, xi, order, data_consistency, rid) 390 | assert rid==False 391 | self.ov=self.args.inference.autoregressive.overlap 392 | 393 | def apply_mask(self, x): 394 | return self.mask*x 395 | 396 | def predict_autoregressive( 397 | self, 398 | shape, 399 | N, 400 | device 401 | ): 402 | 403 | endmask=int(self.ov*shape[-1]) 404 | self.mask=torch.ones((1,self.args.exp.audio_len)).to(device) #assume between 5 and 6s of total length 405 | self.mask[:,endmask::]=0 406 | 407 | degradation=lambda x: self.apply_mask(x) 408 | 409 | x= self.predict_unconditional(shape, device) 410 | xcat=x 411 | x_masked=torch.zeros((1,self.args.exp.audio_len)).to(device) 412 | 413 | for i in range(N-1): 414 | x_masked[:,0:endmask]=x[:,-endmask::] 415 | x=self.predict_conditional(x_masked, degradation) 416 | xcat=torch.cat((xcat,x[...,endmask::]),-1) 417 | 418 | return xcat 419 | 420 | 421 | 422 | 423 | 424 | class SamplerInpainting(Sampler): 425 | 426 | def __init__(self, model, diff_params, args, xi=0, order=2, data_consistency=False, rid=False): 427 | super().__init__(model, diff_params, args, xi, order, data_consistency, rid) 428 | 429 | def apply_mask(self, x): 430 | return self.mask*x 431 | 432 | def predict_inpainting( 433 | self, 434 | y_masked, 435 | mask 436 | ): 437 | self.mask=mask.to(y_masked.device) 438 | 439 | degradation=lambda x: self.apply_mask(x) 440 | 441 | return self.predict_conditional(y_masked, degradation) 442 | 443 | class SamplerBWE(Sampler): 444 | 445 | def __init__(self, model, diff_params, args, xi=0, order=2, data_consistency=False, rid=False): 446 | super().__init__(model, diff_params, args, xi, order, data_consistency, rid) 447 | 448 | def apply_FIR_filter(self,y): 449 | y=y.unsqueeze(1) 450 | 451 | #apply the filter with a convolution (it is an FIR) 452 | y_lpf=torch.nn.functional.conv1d(y,self.filt,padding="same") 453 | y_lpf=y_lpf.squeeze(1) 454 | 455 | return y_lpf 456 | def apply_IIR_filter(self,y): 457 | y_lpf=torchaudio.functional.lfilter(y, self.a,self.b, clamp=False) 458 | return y_lpf 459 | def apply_biquad(self,y): 460 | y_lpf=torchaudio.functional.biquad(y, self.b0, self.b1, self.b2, self.a0, self.a1, self.a2) 461 | return y_lpf 462 | def decimate(self,x): 463 | return x[...,0:-1:self.factor] 464 | 465 | def resample(self,x): 466 | N=100 467 | return torchaudio.functional.resample(x,orig_freq=int(N*self.factor), new_freq=N) 468 | 469 | def predict_bwe( 470 | self, 471 | ylpf, #observations (lowpssed signal) Tensor with shape (L,) 472 | filt, #filter Tensor with shape ?? 
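        # (Illustrative aside, not part of the original file.) All the Sampler
        # subclasses in this module follow the same recipe: store the operator's
        # parameters on self, wrap the degradation in a lambda, and delegate to
        # predict_conditional (or predict_resample when the output length differs).
        # A hypothetical new task would only need, e.g.:
        #
        #   class SamplerGainDegradation(Sampler):
        #       def predict_gain(self, y, g):
        #           self.g = g
        #           degradation = lambda x: self.g * x  # the (linear) forward operator
        #           return self.predict_conditional(y, degradation)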
473 | filt_type 474 | ): 475 | 476 | #define the degradation model as a lambda 477 | if filt_type=="firwin": 478 | self.filt=filt.to(ylpf.device) 479 | degradation=lambda x: self.apply_FIR_filter(x) 480 | elif filt_type=="firwin_hpf": 481 | self.filt=filt.to(ylpf.device) 482 | degradation=lambda x: self.apply_FIR_filter(x) 483 | elif filt_type=="cheby1": 484 | b,a=filt 485 | self.a=torch.Tensor(a).to(ylpf.device) 486 | self.b=torch.Tensor(b).to(ylpf.device) 487 | degradation=lambda x: self.apply_IIR_filter(x) 488 | elif filt_type=="biquad": 489 | b0, b1, b2, a0, a1, a2=filt 490 | self.b0=torch.Tensor(b0).to(ylpf.device) 491 | self.b1=torch.Tensor(b1).to(ylpf.device) 492 | self.b2=torch.Tensor(b2).to(ylpf.device) 493 | self.a0=torch.Tensor(a0).to(ylpf.device) 494 | self.a1=torch.Tensor(a1).to(ylpf.device) 495 | self.a2=torch.Tensor(a2).to(ylpf.device) 496 | degradation=lambda x: self.apply_biquad(x) 497 | elif filt_type=="resample": 498 | self.factor =filt 499 | degradation= lambda x: self.resample(x) 500 | return self.predict_resample(ylpf,(ylpf.shape[0], self.args.audio_len), degradation) 501 | elif filt_type=="decimate": 502 | self.factor =filt 503 | degradation= lambda x: self.decimate(x) 504 | return self.predict_resample(ylpf,(ylpf.shape[0], self.args.audio_len), degradation) 505 | else: 506 | raise NotImplementedError 507 | 508 | return self.predict_conditional(ylpf, degradation) 509 | 510 | 511 | -------------------------------------------------------------------------------- /testing/edm_sampler_inpainting.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import torchaudio 4 | #import scipy.signal 5 | #import numpy as np 6 | #from utils.decSTN_pytorch import apply_STN_mask 7 | 8 | class Sampler(): 9 | 10 | def __init__(self, model, diff_params, args, rid=False): 11 | 12 | self.model = model 13 | self.diff_params = diff_params #same as training, useful if we need to apply a wrapper or something 14 | self.args=args 15 | if not(self.args.tester.diff_params.same_as_training): 16 | self.update_diff_params() 17 | 18 | 19 | self.order=self.args.tester.order 20 | self.xi=self.args.tester.posterior_sampling.xi 21 | #hyperparameter for the reconstruction guidance 22 | self.data_consistency=self.args.tester.data_consistency.use and self.args.tester.data_consistency.type=="always" 23 | 24 | self.data_consistency_end=self.args.tester.data_consistency.use and self.args.tester.data_consistency.type=="end" 25 | if self.data_consistency or self.data_consistency_end: 26 | if self.args.tester.data_consistency.smooth: 27 | self.smooth=True 28 | else: 29 | self.smooth=False 30 | 31 | #use reconstruction gudance without replacement 32 | self.nb_steps=self.args.tester.T 33 | 34 | #self.treshold_on_grads=args.tester.inference.max_thresh_grads 35 | self.rid=rid #this is for logging, ignore for now 36 | 37 | #try: 38 | # self.stereo=self.args.tester.stereo 39 | #except: 40 | # self.stereo=False 41 | 42 | 43 | def update_diff_params(self): 44 | #the parameters for testing might not be necesarily the same as the ones used for training 45 | self.diff_params.sigma_min=self.args.tester.diff_params.sigma_min 46 | self.diff_params.sigma_max =self.args.tester.diff_params.sigma_max 47 | self.diff_params.ro=self.args.tester.diff_params.ro 48 | self.diff_params.sigma_data=self.args.tester.diff_params.sigma_data 49 | #par.diff_params.meters stochastic sampling 50 | self.diff_params.Schurn=self.args.tester.diff_params.Schurn 51 | 
self.diff_params.Stmin=self.args.tester.diff_params.Stmin 52 | self.diff_params.Stmax=self.args.tester.diff_params.Stmax 53 | self.diff_params.Snoise=self.args.tester.diff_params.Snoise 54 | 55 | 56 | 57 | def get_score_rec_guidance(self, x, y, t_i, degradation): 58 | 59 | x.requires_grad_() 60 | x_hat=self.diff_params.denoiser(x, self.model, t_i.unsqueeze(-1)) 61 | 62 | if self.args.tester.filter_out_cqt_DC_Nyq: 63 | x_hat=self.model.CQTransform.apply_hpf_DC(x_hat) 64 | 65 | den_rec= degradation(x_hat) 66 | 67 | if len(y.shape)==3: 68 | dim=(1,2) 69 | elif len(y.shape)==2: 70 | dim=1 71 | 72 | if self.args.tester.posterior_sampling.norm=="smoothl1": 73 | norm=torch.nn.functional.smooth_l1_loss(y, den_rec, reduction='sum', beta=self.args.tester.posterior_sampling.smoothl1_beta) 74 | else: 75 | norm=torch.linalg.norm(y-den_rec,dim=dim, ord=self.args.tester.posterior_sampling.norm) 76 | 77 | 78 | rec_grads=torch.autograd.grad(outputs=norm, 79 | inputs=x) 80 | 81 | rec_grads=rec_grads[0] 82 | 83 | normguide=torch.linalg.norm(rec_grads)/self.args.exp.audio_len**0.5 84 | 85 | #normalize scaling 86 | #s=self.xi/(normguide*t_i+1e-6) 87 | s=t_i*self.xi/(normguide+1e-6) 88 | 89 | #optionally apply a treshold to the gradients 90 | if False: 91 | #pply tresholding to the gradients. It is a dirty trick but helps avoiding bad artifacts 92 | rec_grads=torch.clip(rec_grads, min=-self.treshold_on_grads, max=self.treshold_on_grads) 93 | 94 | if self.rid: 95 | x_hat_old=x_hat.detach().clone() 96 | 97 | x_hat=x_hat-s*rec_grads #apply gradients 98 | 99 | if self.rid: 100 | x_hat_old_2=x_hat.detach().clone() 101 | 102 | if self.data_consistency: 103 | x_hat=self.proj_convex_set(x_hat.detach()) 104 | 105 | score=(x_hat.detach()-x)/t_i**2 106 | 107 | #apply scaled guidance to the score 108 | #score=score-s*rec_grads 109 | if self.rid: 110 | #score, denoised esimate, s*rec_grads, denoised_estimate minus gradients, x_hat after pocs 111 | return score, x_hat_old, s*rec_grads, x_hat_old_2, x_hat 112 | else: 113 | return score 114 | 115 | def get_score(self,x, y, t_i, degradation): 116 | if y==None: 117 | assert degradation==None 118 | #unconditional sampling 119 | with torch.no_grad(): 120 | #print("In sampling", x.shape, t_i.shape) 121 | x_hat=self.diff_params.denoiser(x, self.model, t_i.unsqueeze(-1)) 122 | if self.args.tester.filter_out_cqt_DC_Nyq: 123 | x_hat=self.model.CQTransform.apply_hpf_DC(x_hat) 124 | score=(x_hat-x)/t_i**2 125 | return score 126 | else: 127 | if self.xi>0: 128 | #apply rec. guidance 129 | score=self.get_score_rec_guidance(x, y, t_i, degradation) 130 | 131 | #optionally apply replacement or consistency step 132 | #if self.data_consistency: 133 | # #convert score to denoised estimate using Tweedie's formula 134 | # x_hat=score*t_i**2+x 135 | 136 | # x_hat=self.proj_convex_set(x_hat) 137 | 138 | # #convert back to score 139 | # score=(x_hat-x)/t_i**2 140 | 141 | else: 142 | #denoised with replacement method 143 | with torch.no_grad(): 144 | x_hat=self.diff_params.denoiser(x, self.model, t_i.unsqueeze(-1)) 145 | x_hat=self.proj_convex_set(x_hat.detach()) 146 | 147 | score=(x_hat.detach()-x)/t_i**2 148 | 149 | #x_hat=self.data_consistency_step(x_hat,y, degradation) 150 | 151 | #score=(x_hat-x)/t_i**2 152 | 153 | return score 154 | 155 | def predict_unconditional( 156 | self, 157 | shape, #observations (lowpssed signal) Tensor with shape ?? 
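        # (Editor's note, not in the original file.) get_score_rec_guidance above
        # applies reconstruction guidance: with denoised estimate x_hat = D(x, t),
        #   norm = ||y - degradation(x_hat)||   (L2 by default, or smooth-L1),
        #   g = d norm / d x                    (via autograd),
        #   s = t * xi / (||g|| / sqrt(audio_len) + 1e-6),
        #   x_hat <- x_hat - s * g,   score = (x_hat - x) / t^2 .
        # The xi hyperparameter therefore scales how strongly the observations y
        # steer the otherwise unconditional score.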
158 | device 159 | ): 160 | self.y=None 161 | self.degradation=None 162 | return self.predict(shape, device) 163 | 164 | def predict_resample( 165 | self, 166 | y, #observations (lowpssed signal) Tensor with shape ?? 167 | shape, 168 | degradation, #lambda function 169 | ): 170 | self.degradation=degradation 171 | self.y=y 172 | #print(shape) 173 | return self.predict(shape, y.device) 174 | 175 | 176 | 177 | 178 | def predict( 179 | self, 180 | shape, #observations (lowpssed signal) Tensor with shape ?? 181 | device, #lambda function 182 | ): 183 | 184 | if self.rid: 185 | rid_xt=torch.zeros((self.nb_steps,shape[0], shape[1])) 186 | rid_grads=torch.zeros((self.nb_steps,shape[0], shape[1])) 187 | rid_denoised=torch.zeros((self.nb_steps,shape[0], shape[1])) 188 | rid_grad_update=torch.zeros((self.nb_steps,shape[0], shape[1])) 189 | rid_pocs=torch.zeros((self.nb_steps,shape[0], shape[1])) 190 | rid_xt2=torch.zeros((self.nb_steps,shape[0], shape[1])) 191 | 192 | #get the noise schedule 193 | t = self.diff_params.create_schedule(self.nb_steps).to(device) 194 | #sample from gaussian distribution with sigma_max variance 195 | x = self.diff_params.sample_prior(shape,t[0]).to(device) 196 | 197 | #parameter for langevin stochasticity, if Schurn is 0, gamma will be 0 to, so the sampler will be deterministic 198 | gamma=self.diff_params.get_gamma(t).to(device) 199 | 200 | 201 | for i in tqdm(range(0, self.nb_steps, 1)): 202 | #print("sampling step ",i," from ",self.nb_steps) 203 | 204 | if gamma[i]==0: 205 | #deterministic sampling, do nothing 206 | t_hat=t[i] 207 | else: 208 | #stochastic sampling 209 | #move timestep 210 | t_hat=t[i]+gamma[i]*t[i] 211 | #sample noise, Snoise is 1 by default 212 | epsilon=torch.randn(shape).to(device)*self.diff_params.Snoise 213 | #add extra noise 214 | x=x+((t_hat**2 - t[i]**2)**(1/2))*epsilon #x_hat 215 | del epsilon 216 | 217 | if self.rid: 218 | rid_xt[i]=x 219 | 220 | score=self.get_score(x, self.y, t_hat, self.degradation) 221 | if self.rid: 222 | score, x_hat1, grads, x_hat2, x_hat3=score 223 | rid_denoised[i]=x_hat1 224 | rid_grads[i]=grads 225 | rid_grad_update[i]=x_hat2 226 | rid_pocs[i]=x_hat3 227 | 228 | 229 | #d=-t_hat*((denoised-x_hat)/t_hat**2) 230 | d=-t_hat*score 231 | 232 | #apply second order correction 233 | h=t[i+1]-t_hat 234 | 235 | 236 | if t[i+1]!=0 and self.order==2: #always except last step 237 | #second order correction2 238 | #h=t[i+1]-t_hat 239 | t_prime=t[i+1] 240 | x_prime=x+h*d 241 | score=self.get_score(x_prime, self.y, t_prime, self.degradation) 242 | if self.rid: 243 | score, x_hat1, grads, x_hat2, x_hat3=score 244 | 245 | d_prime=-t_prime*score 246 | 247 | x=(x+h*((1/2)*d +(1/2)*d_prime)) 248 | 249 | elif t[i+1]==0 or self.order==1: #first condition is to avoid dividing by 0 250 | #first order Euler step 251 | x=x+h*d 252 | 253 | if self.rid: 254 | rid_xt2[i]=x 255 | 256 | if self.data_consistency_end: 257 | x=self.proj_convex_set(x) 258 | 259 | if self.rid: 260 | return x.detach(), rid_denoised.detach(), rid_grads.detach(), rid_grad_update.detach(), rid_pocs.detach(), rid_xt.detach(), rid_xt2.detach(), t.detach() 261 | else: 262 | return x.detach() 263 | 264 | def apply_mask(self, x, mask=None): 265 | 266 | if mask is None: 267 | mask=self.mask 268 | 269 | return mask*x 270 | 271 | def apply_spectral_mask(self, x): 272 | 273 | if self.args.tester.spectrogram_inpainting.stft.window=="hann": 274 | window=torch.hann_window(self.args.tester.spectrogram_inpainting.stft.win_length).to(x.device) 275 | else: 276 | raise 
NotImplementedError("Only hann window is implemented for now") 277 | 278 | n_fft=self.args.tester.spectrogram_inpainting.stft.n_fft 279 | #add padding to the signal 280 | input_shape=x.shape 281 | x=torch.nn.functional.pad(x, (0, n_fft-x.shape[-1]%n_fft), mode='constant', value=0) 282 | x=torch.stft(x, n_fft, self.args.tester.spectrogram_inpainting.stft.hop_length, self.args.tester.spectrogram_inpainting.stft.win_length, window, return_complex=True) 283 | #X.shape=(B, F, T) 284 | x=x*self.mask.unsqueeze(0) 285 | #apply the inverse stft 286 | x=torch.istft(x, n_fft, self.args.tester.spectrogram_inpainting.stft.hop_length, self.args.tester.spectrogram_inpainting.stft.win_length, window, return_complex=False) 287 | #why is the size different? Because of the padding? 288 | x=x[...,0:input_shape[-1]] 289 | 290 | return x 291 | 292 | 293 | #def proj_convex_set(self, x_hat, y, degradation): 294 | # """ 295 | # Simple replacement method, used for inpainting and FIR bwe 296 | # """ 297 | # #get reconstruction estimate 298 | # den_rec= degradation(x_hat) 299 | # #apply replacment (valid for linear degradations) 300 | # return y+x_hat-den_rec 301 | 302 | def prepare_smooth_mask(self, mask, size=10): 303 | hann=torch.hann_window(size*2) 304 | hann_left=hann[0:size] 305 | hann_right=hann[size::] 306 | B,N=mask.shape 307 | mask=mask[0] 308 | prev=1 309 | new_mask=mask.clone() 310 | #print(hann.shape) 311 | for i in range(len(mask)): 312 | if mask[i] != prev: 313 | #print(i, mask.shape, mask[i], prev) 314 | #transition 315 | if mask[i]==0: 316 | #gap encountered, apply hann right before 317 | new_mask[i-size:i]=hann_right 318 | if mask[i]==1: 319 | #gap encountered, apply hann left after 320 | new_mask[i:i+size]=hann_left 321 | #print(mask[i-2*size:i+2*size]) 322 | #print(new_mask[i-2*size:i+2*size]) 323 | 324 | prev=mask[i] 325 | return new_mask.unsqueeze(0).expand(B,-1) 326 | 327 | def predict_inpainting( 328 | self, 329 | y_masked, 330 | mask 331 | ): 332 | self.mask=mask.to(y_masked.device) 333 | 334 | self.y=y_masked 335 | 336 | self.degradation=lambda x: self.apply_mask(x) 337 | if self.data_consistency or self.data_consistency_end: 338 | if self.smooth: 339 | smooth_mask=self.prepare_smooth_mask(mask, self.args.tester.data_consistency.hann_size) 340 | else: 341 | smooth_mask=mask 342 | 343 | self.proj_convex_set= lambda x: smooth_mask*y_masked+(1-smooth_mask)*x #will this work? 
I am too scared about lambdas 344 | 345 | 346 | return self.predict(self.y.shape, self.y.device) 347 | 348 | def predict_spectrogram_inpainting( 349 | self, 350 | y_masked, 351 | mask 352 | ): 353 | self.mask=mask.to(y_masked.device) 354 | del mask 355 | 356 | self.y=y_masked 357 | del y_masked 358 | 359 | self.degradation=lambda x: self.apply_spectral_mask(x) 360 | if self.data_consistency or self.data_consistency_end: 361 | 362 | self.proj_convex_set= lambda x: self.y+x-self.apply_spectral_mask(x) #If this fails, consider using a more appropiate projection 363 | 364 | return self.predict(self.y.shape, self.y.device) 365 | 366 | 367 | 368 | -------------------------------------------------------------------------------- /testing/tester.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | import re 3 | import torch 4 | import torchaudio 5 | #from src.models.unet_cqt import Unet_CQT 6 | #from src.models.unet_stft import Unet_STFT 7 | #from src.models.unet_1d import Unet_1d 8 | #import src.utils.setup as utils_setup 9 | #from src.sde import VE_Sde_Elucidating 10 | import utils.dnnlib as dnnlib 11 | import os 12 | 13 | import utils.logging as utils_logging 14 | import wandb 15 | import copy 16 | 17 | from glob import glob 18 | from tqdm import tqdm 19 | 20 | import utils.bandwidth_extension as utils_bwe 21 | import utils.training_utils as t_utils 22 | import omegaconf 23 | 24 | 25 | class Tester(): 26 | def __init__( 27 | self, args, network, diff_params, test_set=None, device=None, it=None 28 | ): 29 | self.args=args 30 | self.network=network 31 | self.diff_params=copy.copy(diff_params) 32 | self.device=device 33 | #choose gpu as the device if possible 34 | if self.device is None: 35 | self.device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | 37 | self.network=network 38 | 39 | torch.backends.cudnn.benchmark = True 40 | 41 | today=date.today() 42 | if it is None: 43 | self.it=0 44 | 45 | mode='test' #this is hardcoded for now, I'll have to figure out how to deal with the subdirectories once I want to test conditional sampling 46 | self.path_sampling=os.path.join(args.model_dir,mode+today.strftime("%d_%m_%Y")+"_"+str(self.it)) 47 | if not os.path.exists(self.path_sampling): 48 | os.makedirs(self.path_sampling) 49 | 50 | 51 | #I have to rethink if I want to create the same sampler object to do conditional and unconditional sampling 52 | self.setup_sampler() 53 | 54 | self.use_wandb=False #hardcoded for now 55 | 56 | #S=2 57 | #if S>2.1 and S<2.2: 58 | # #resampling 48k to 22.05k 59 | # self.resample=torchaudio.transforms.Resample(160*2,147).to(self.device) 60 | #elif S!=1: 61 | # N=int(self.args.exp.audio_len*S) 62 | # self.resample=torchaudio.transforms.Resample(N,self.args.exp.audio_len).to(self.device) 63 | 64 | if test_set is not None: 65 | self.test_set=test_set 66 | self.do_inpainting=True 67 | self.do_bwe=True 68 | else: 69 | self.test_set=None 70 | self.do_inpainting=False 71 | self.do_bwe=False #these need to be set up in the config file 72 | 73 | self.paths={} 74 | if self.do_inpainting and ("inpainting" in self.args.tester.modes): 75 | self.do_inpainting=True 76 | mode="inpainting" 77 | self.paths[mode], self.paths[mode+"degraded"], self.paths[mode+"original"], self.paths[mode+"reconstructed"]=self.prepare_experiment("inpainting","masked","inpainted") 78 | #TODO add more information in the subirectory names 79 | else: 80 | self.do_inpainting=False 81 | 82 | if self.do_bwe and ("bwe" in 
self.args.tester.modes):
83 |             self.do_bwe=True
84 |             mode="bwe"
85 |             self.paths[mode], self.paths[mode+"degraded"], self.paths[mode+"original"], self.paths[mode+"reconstructed"]=self.prepare_experiment("bwe","lowpassed","bwe")
86 |             #TODO add more information in the subdirectory names
87 |         else:
88 |             self.do_bwe=False
89 | 
90 |         if ("unconditional" in self.args.tester.modes):
91 |             mode="unconditional"
92 |             self.paths[mode]=self.prepare_unc_experiment("unconditional")
93 | 
94 |         try:
95 |             self.stereo=self.args.tester.stereo
96 |         except:
97 |             self.stereo=False
98 | 
99 |     def prepare_unc_experiment(self, name):
100 |         path_exp=os.path.join(self.path_sampling,name)
101 |         if not os.path.exists(path_exp):
102 |             os.makedirs(path_exp)
103 |         return path_exp
104 | 
105 |     def prepare_experiment(self, name, str_degraded="degraded", str_reconstructed="reconstructed"):
106 |         path_exp=os.path.join(self.path_sampling,name)
107 |         if not os.path.exists(path_exp):
108 |             os.makedirs(path_exp)
109 | 
110 |         n=str_degraded
111 |         path_degraded=os.path.join(path_exp, n) #path for the degraded inputs (e.g. lowpassed or masked)
112 |         #ensure the path exists
113 |         if not os.path.exists(path_degraded):
114 |             os.makedirs(path_degraded)
115 | 
116 |         path_original=os.path.join(path_exp, "original") #this will need a better organization
117 |         #ensure the path exists
118 |         if not os.path.exists(path_original):
119 |             os.makedirs(path_original)
120 | 
121 |         n=str_reconstructed
122 |         path_reconstructed=os.path.join(path_exp, n) #path for the reconstructed outputs
123 |         #ensure the path exists
124 |         if not os.path.exists(path_reconstructed):
125 |             os.makedirs(path_reconstructed)
126 | 
127 |         return path_exp, path_degraded, path_original, path_reconstructed
128 | 
129 | 
130 |     def setup_wandb(self):
131 |         """
132 |         Configure wandb, open a new run and log the configuration.
133 |         """
134 |         config=omegaconf.OmegaConf.to_container(
135 |             self.args, resolve=True, throw_on_missing=True
136 |         )
137 |         self.wandb_run=wandb.init(project="testing"+self.args.exp.wandb.project, entity=self.args.exp.wandb.entity, config=config)
138 |         wandb.watch(self.network, log_freq=self.args.logging.heavy_log_interval) #wandb.watch logs the gradients and parameters of the model
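        # (Editor's aside, not in the original file.) A worked example of the
        # layout prepare_experiment builds: with a hypothetical model_dir of
        # "experiments/cqt", mode "inpainting" on 1 May 2024 at it=0 produces
        #
        #   experiments/cqt/test01_05_2024_0/inpainting/
        #       masked/      (degraded inputs)
        #       original/
        #       inpainted/   (reconstructions)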
139 | self.wandb_run.name=os.path.basename(self.args.model_dir)+"_"+self.args.exp.exp_name+"_"+self.wandb_run.id #adding the experiment number to the run name, bery important, I hope this does not crash 140 | self.use_wandb=True 141 | 142 | def setup_wandb_run(self, run): 143 | #get the wandb run object from outside (in trainer.py or somewhere else) 144 | self.wandb_run=run 145 | self.use_wandb=True 146 | 147 | def setup_sampler(self): 148 | self.sampler=dnnlib.call_func_by_name(func_name=self.args.tester.sampler_callable, model=self.network, diff_params=self.diff_params, args=self.args) 149 | 150 | def load_latest_checkpoint(self ): 151 | #load the latest checkpoint from self.args.model_dir 152 | try: 153 | # find latest checkpoint_id 154 | save_basename = f"{self.args.exp.exp_name}-*.pt" 155 | save_name = f"{self.args.model_dir}/{save_basename}" 156 | list_weights = glob(save_name) 157 | id_regex = re.compile(f"{self.args.exp.exp_name}-(\d*)\.pt") 158 | list_ids = [int(id_regex.search(weight_path).groups()[0]) 159 | for weight_path in list_weights] 160 | checkpoint_id = max(list_ids) 161 | 162 | state_dict = torch.load( 163 | f"{self.args.model_dir}/{self.args.exp.exp_name}-{checkpoint_id}.pt", map_location=self.device) 164 | try: 165 | self.network.load_state_dict(state_dict['ema']) 166 | except Exception as e: 167 | print(e) 168 | print("Failed to load in strict mode, trying again without strict mode") 169 | self.network.load_state_dict(state_dict['model'], strict=False) 170 | 171 | print(f"Loaded checkpoint {checkpoint_id}") 172 | return True 173 | except (FileNotFoundError, ValueError): 174 | raise ValueError("No checkpoint found") 175 | 176 | def load_checkpoint(self, path): 177 | state_dict = torch.load(path, map_location=self.device) 178 | try: 179 | self.it=state_dict['it'] 180 | except: 181 | self.it=0 182 | print("loading checkpoint") 183 | return t_utils.load_state_dict(state_dict, ema=self.network) 184 | 185 | def load_checkpoint_legacy(self, path): 186 | state_dict = torch.load(path, map_location=self.device) 187 | 188 | try: 189 | print("load try 1") 190 | self.network.load_state_dict(state_dict['ema']) 191 | except: 192 | #self.network.load_state_dict(state_dict['model']) 193 | try: 194 | print("load try 2") 195 | dic_ema = {} 196 | for (key, tensor) in zip(state_dict['model'].keys(), state_dict['ema_weights']): 197 | dic_ema[key] = tensor 198 | self.network.load_state_dict(dic_ema) 199 | except: 200 | print("load try 3") 201 | dic_ema = {} 202 | i=0 203 | for (key, tensor) in zip(state_dict['model'].keys(), state_dict['model'].values()): 204 | if tensor.requires_grad: 205 | dic_ema[key]=state_dict['ema_weights'][i] 206 | i=i+1 207 | else: 208 | dic_ema[key]=tensor 209 | self.network.load_state_dict(dic_ema) 210 | try: 211 | self.it=state_dict['it'] 212 | except: 213 | self.it=0 214 | 215 | def log_audio(self,preds, mode:str): 216 | string=mode+"_"+self.args.tester.name 217 | audio_path=utils_logging.write_audio_file(preds,self.args.exp.sample_rate, string,path=os.path.join(self.args.model_dir, self.paths[mode]),stereo=self.stereo) 218 | print(audio_path) 219 | if self.use_wandb: 220 | self.wandb_run.log({"audio_"+str(string): wandb.Audio(audio_path, sample_rate=self.args.exp.sample_rate)},step=self.it) 221 | #TODO: log spectrogram of the audio file to wandb 222 | spec_sample=utils_logging.plot_spectrogram_from_raw_audio(preds, self.args.logging.stft) 223 | if self.use_wandb: 224 | self.wandb_run.log({"spec_"+str(string): spec_sample}, step=self.it) 225 | 226 | def 
sample_unconditional(self):
227 |         #the audio length is specified in args.exp; it doesn't depend on the tester
228 |         if self.stereo:
229 |             shape=[self.args.tester.unconditional.num_samples,2, self.args.exp.audio_len]
230 |         else:
231 |             shape=[self.args.tester.unconditional.num_samples, self.args.exp.audio_len]
232 |         #TODO assert that the audio_len is consistent with the model
233 |         preds=self.sampler.predict_unconditional(shape, self.device)
234 |         if self.use_wandb:
235 |             self.log_audio(preds, "unconditional")
236 |         else:
237 |             #TODO do something else if wandb is not used, like saving the audio file to the model directory
238 |             pass
239 | 
240 |         return preds
241 | 
242 |     def test_inpainting(self):
243 |         if not self.do_inpainting or self.test_set is None:
244 |             print("No test set specified, skipping inpainting test")
245 |             return
246 | 
247 |         assert self.test_set is not None
248 | 
249 |         self.inpainting_mask=torch.ones((1,self.args.exp.audio_len)).to(self.device) #mask over the full audio length
250 |         gap=int(self.args.tester.inpainting.gap_length*self.args.exp.sample_rate/1000)
251 | 
252 |         if self.args.tester.inpainting.start_gap_idx =="None": #the config stores unset values as the string "None"
253 |             #the gap is placed at the center
254 |             start_gap_index=int(self.args.exp.audio_len//2 - gap//2)
255 |         else:
256 |             start_gap_index=int(self.args.tester.inpainting.start_gap_idx*self.args.exp.sample_rate/1000)
257 |         self.inpainting_mask[...,start_gap_index:(start_gap_index+gap)]=0
258 | 
259 |         if len(self.test_set) == 0:
260 |             print("No samples found in test set")
261 | 
262 |         res=torch.zeros((len(self.test_set),self.args.exp.audio_len))
263 |         #the conditional sampling uses batch_size=1, so we need to loop over the test set. This is done for simplicity, but it is not the most efficient way to do it.
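        # (Editor's aside, not in the original file.) A worked example of the mask
        # arithmetic above, with illustrative values: for sample_rate=22050 and
        # tester.inpainting.gap_length=1000 (in ms),
        #   gap = int(1000 * 22050 / 1000) = 22050 samples,
        # and with start_gap_idx unset the gap is centered at
        #   start_gap_index = audio_len//2 - gap//2,
        # so inpainting_mask is 1 everywhere except those 22050 zeroed samples.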
264 |         for i, (original, fs, filename) in enumerate(tqdm(self.test_set)):
265 |             n=os.path.splitext(filename[0])[0]
266 |             original=original.float().to(self.device)
267 |             seg=self.resample_audio(original, fs)
268 |             #seg=torchaudio.functional.resample(seg, self.args.exp.resample_factor, 1)
269 |             utils_logging.write_audio_file(seg, self.args.exp.sample_rate, n, path=self.paths["inpainting"+"original"])
270 |             masked=seg*self.inpainting_mask
271 |             utils_logging.write_audio_file(masked, self.args.exp.sample_rate, n, path=self.paths["inpainting"+"degraded"])
272 |             pred=self.sampler.predict_inpainting(masked, self.inpainting_mask)
273 |             utils_logging.write_audio_file(pred, self.args.exp.sample_rate, n, path=self.paths["inpainting"+"reconstructed"])
274 |             res[i,:]=pred
275 | 
276 |         if self.use_wandb:
277 |             self.log_audio(res, "inpainting")
278 | 
279 |         #TODO save the files in the subdirectory inpainting of the model directory
280 | 
281 |     def resample_audio(self, audio, fs):
282 |         #this has been reused from trainer.py
283 |         return t_utils.resample_batch(audio, fs, self.args.exp.sample_rate, self.args.exp.audio_len)
284 | 
285 |     def sample_inpainting(self, y, mask):
286 | 
287 |         y_masked=y*mask
288 |         #shape=[self.args.tester.unconditional.num_samples, self.args.tester.unconditional.audio_len]
289 |         #TODO assert that the audio_len is consistent with the model
290 |         preds=self.sampler.predict_inpainting(y_masked, mask)
291 | 
292 |         return preds
293 | 
294 |     def test_bwe(self, typefilter="whateverIignoreit"):
295 |         if not self.do_bwe or self.test_set is None:
296 |             print("No test set specified, skipping bwe test")
297 |             return
298 | 
299 |         assert self.test_set is not None
300 | 
301 |         if len(self.test_set) == 0:
302 |             print("No samples found in test set")
303 | 
304 |         #prepare lowpass filters
305 |         self.filter=utils_bwe.prepare_filter(self.args, self.args.exp.sample_rate)
306 | 
307 |         res=torch.zeros((len(self.test_set),self.args.exp.audio_len))
308 |         #the conditional sampling uses batch_size=1, so we need to loop over the test set. This is done for simplicity, but it is not the most efficient way to do it.
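        # (Editor's aside, not in the original file.) In the loop below, when
        # tester.noise_in_observations_SNR is set, noise is added to the lowpassed
        # observations at that SNR in dB:
        #   SNR_lin = 10**(SNR_dB/10),  sigma = sqrt(var(y) / SNR_lin).
        # E.g. for a hypothetical 30 dB setting, SNR_lin = 1000 and
        # sigma ~= std(y)/31.6, i.e. the noise power sits 1000x below the signal power.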
309 |         for i, (original, fs, filename) in enumerate(tqdm(self.test_set)):
310 |             n=os.path.splitext(filename[0])[0]
311 |             original=original.float().to(self.device)
312 |             seg=self.resample_audio(original, fs)
313 | 
314 |             utils_logging.write_audio_file(seg, self.args.exp.sample_rate, n, path=self.paths["bwe"+"original"])
315 | 
316 |             y=utils_bwe.apply_low_pass(seg, self.filter, self.args.tester.bandwidth_extension.filter.type)
317 | 
318 |             if self.args.tester.noise_in_observations_SNR != "None":
319 |                 SNR=10**(self.args.tester.noise_in_observations_SNR/10)
320 |                 sigma2_s=torch.var(y, -1)
321 |                 sigma=torch.sqrt(sigma2_s/SNR)
322 |                 y+=sigma*torch.randn(y.shape).to(y.device)
323 | 
324 |             utils_logging.write_audio_file(y, self.args.exp.sample_rate, n, path=self.paths["bwe"+"degraded"])
325 | 
326 |             pred=self.sampler.predict_bwe(y, self.filter, self.args.tester.bandwidth_extension.filter.type)
327 |             utils_logging.write_audio_file(pred, self.args.exp.sample_rate, n, path=self.paths["bwe"+"reconstructed"])
328 |             res[i,:]=pred
329 | 
330 |         if self.use_wandb:
331 |             self.log_audio(res, "bwe")
332 | 
333 |         #preprocess the audio file if necessary
334 | 
335 | 
336 |     def dodajob(self):
337 |         self.setup_wandb()
338 |         if "unconditional" in self.args.tester.modes:
339 |             print("testing unconditional")
340 |             self.sample_unconditional()
341 |             self.it+=1
342 |         if "blind_bwe" in self.args.tester.modes:
343 |             print("testing blind bwe")
344 |             #tester.test_blind_bwe(typefilter="whatever")
345 |             self.test_blind_bwe(typefilter="3rdoct")
346 |             self.it+=1
347 |         if "filter_bwe" in self.args.tester.modes:
348 |             print("testing filter bwe")
349 |             self.test_filter_bwe(typefilter="3rdoct")
350 |             self.it+=1
351 |         if "unconditional_operator" in self.args.tester.modes:
352 |             print("testing unconditional operator")
353 |             self.sample_unconditional_operator()
354 |             self.it+=1
355 |         if "bwe" in self.args.tester.modes:
356 |             print("testing bwe")
357 |             self.test_bwe(typefilter="3rdoct")
358 |             self.it+=1
359 |         if "inpainting" in self.args.tester.modes:
360 |             self.test_inpainting()
361 |             self.it+=1
362 | 
363 |         #do I want to save this audio file locally?
I think I do, but I'll have to figure out how to do it 364 | 365 | 366 | -------------------------------------------------------------------------------- /testing_shortgaps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --time=2:29:59 3 | #SBATCH --mem=30G 4 | #SBATCH --cpus-per-task=1 5 | #SBATCH --job-name=sgirtgaos 6 | ##SBATCH --gres=gpu:a100:1 7 | #SBATCH --gres=gpu:1 --constraint=volta 8 | #SBATCH --output=/scratch/work/%u/projects/ddpm/diffusion_autumn_2022/A-diffusion/experiments/inpainting_test_shorgapts_%j.out 9 | 10 | 11 | module load anaconda 12 | source activate /scratch/work/molinee2/conda_envs/cqtdiff 13 | module load gcc/8.4.0 14 | export TORCH_USE_RTLD_GLOBAL=YES 15 | #export HYDRA_FULL_ERROR=1 16 | #export CUDA_LAUNCH_BLOCKING=1 17 | n=$SLURM_ARRAY_TASK_ID 18 | 19 | #maestro 20 | 21 | #n=3 #original CQTDiff (with fast implementation) (22kHz) 22 | #n=49 #dance diffusion (16 kHz) 23 | #n=44 #cqtdiff+ no attention maestro 8s 24 | #n=45 #cqtdiff+ attention maestro 8s 25 | #n=50 #cqtdiff+ attention maestro 8s (alt version) 26 | 27 | #n=54 #cqtdiff+ maestro 8s (alt version) 28 | #n=50 #cqtdiff+ attention maestro 8s (alt version) 29 | 30 | #n=56 #ADP 31 | 32 | n=51 #musicnet 33 | 34 | if [[ $n -eq 54 ]] 35 | then 36 | ckpt="/scratch/work/molinee2/projects/ddpm/diffusion_autumn_2022/A-diffusion/experiments/54/22k_8s-850000.pt" 37 | exp=maestro22k_8s 38 | network=paper_1912_unet_cqt_oct_noattention_adaln 39 | tester=inpainting_tester 40 | dset=maestro_allyears 41 | CQT=True 42 | elif [[ $n -eq 3 ]] 43 | then 44 | ckpt="/scratch/work/molinee2/projects/ddpm/diffusion_autumn_2022/A-diffusion/experiments/3/weights-489999.pt" 45 | exp=test_cqtdiff_22k 46 | network=unet_cqtdiff_original 47 | dset=maestro_allyears 48 | CQT=False 49 | 50 | elif [[ $n -eq 56 ]] 51 | then 52 | ckpt="/scratch/work/molinee2/projects/ddpm/diffusion_autumn_2022/A-diffusion/experiments/56/22k_8s-510000.pt" 53 | exp=maestro22k_131072 54 | network=ADP_raw_patching 55 | tester=inpainting_tester 56 | dset=maestro_allyears 57 | CQT=False 58 | elif [[ $n -eq 50 ]] 59 | then 60 | ckpt="/scratch/work/molinee2/projects/ddpm/diffusion_autumn_2022/A-diffusion/experiments/50/22k_8s-750000.pt" 61 | exp=maestro22k_8s 62 | network=paper_1912_unet_cqt_oct_attention_adaLN_2 63 | dset=maestro_allyears 64 | tester=inpainting_tester 65 | CQT=True 66 | 67 | elif [[ $n -eq 51 ]] 68 | then 69 | ckpt="/scratch/work/molinee2/projects/ddpm/diffusion_autumn_2022/A-diffusion/experiments/51/44k_4s-560000.pt" 70 | exp=musicnet44k_4s 71 | network=paper_1912_unet_cqt_oct_attention_44k_2 72 | dset=inpainting_musicnet_50 73 | #dset=inpainting_musicnet 74 | tester=inpainting_tester_shortgaps 75 | CQT=True 76 | fi 77 | 78 | 79 | PATH_EXPERIMENT=experiments/inpainting_tests/$n 80 | #PATH_EXPERIMENT=experiments/cqtdiff_original 81 | mkdir $PATH_EXPERIMENT 82 | 83 | 84 | #python train_w_cqt.py path_experiment="$PATH_EXPERIMENT" $iteration 85 | python test.py model_dir="$PATH_EXPERIMENT" \ 86 | dset=$dset \ 87 | exp=$exp \ 88 | network=$network \ 89 | tester=$tester \ 90 | tester.checkpoint=$ckpt \ 91 | tester.filter_out_cqt_DC_Nyq=$CQT 92 | 93 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import hydra 5 | #import click 6 | import torch 7 | import utils.dnnlib as dnnlib 8 | from utils.torch_utils import 
distributed as dist 9 | import utils.setup as setup 10 | from training.trainer import Trainer 11 | 12 | import warnings 13 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. 14 | 15 | 16 | def parse_int_list(s): 17 | if isinstance(s, list): return s 18 | ranges = [] 19 | range_re = re.compile(r'^(\d+)-(\d+)$') 20 | for p in s.split(','): 21 | m = range_re.match(p) 22 | if m: 23 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 24 | else: 25 | ranges.append(int(p)) 26 | return ranges 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | 31 | def _main(args): 32 | 33 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | #assert torch.cuda.is_available() 35 | #device="cuda" 36 | 37 | global __file__ 38 | __file__ = hydra.utils.to_absolute_path(__file__) 39 | dirname = os.path.dirname(__file__) 40 | args.model_dir = os.path.join(dirname, str(args.model_dir)) 41 | if dist.get_rank() == 0: 42 | if not os.path.exists(args.model_dir): 43 | os.makedirs(args.model_dir) 44 | args.exp.model_dir=args.model_dir 45 | 46 | 47 | #opts = dnnlib.EasyDict(kwargs) 48 | torch.multiprocessing.set_start_method('spawn') 49 | 50 | #dist.init() 51 | dset=setup.setup_dataset(args) 52 | diff_params=setup.setup_diff_parameters(args) 53 | network=setup.setup_network(args, device) 54 | optimizer=setup.setup_optimizer(args, network) 55 | #try: 56 | test_set=setup.setup_dataset_test(args) 57 | #except: 58 | # test_set=None 59 | tester=setup.setup_tester(args, network=network, diff_params=diff_params, test_set=test_set, device=device) #this will be used for making demos during training 60 | trainer=setup.setup_trainer(args, dset=dset, network=network, optimizer=optimizer, diff_params=diff_params, tester=tester, device=device) #this will be used for making demos during training 61 | 62 | 63 | # Print options. 64 | dist.print0() 65 | dist.print0('Training options:') 66 | dist.print0() 67 | dist.print0(f'Output directory: {args.model_dir}') 68 | dist.print0(f'Network architecture: {args.network.callable}') 69 | dist.print0(f'Diffusion parameterization: {args.diff_params.callable}') 70 | dist.print0(f'Batch size: {args.exp.batch}') 71 | dist.print0(f'Number of GPUs: {dist.get_world_size()}') 72 | dist.print0(f'Mixed-precision: {args.exp.use_fp16}') 73 | dist.print0() 74 | 75 | # Train. 
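    # (Editor's aside, not in the original file.) Training is configured through
    # Hydra config groups (see conf/); a typical launch, mirroring training.sh
    # (the model_dir path is illustrative):
    #
    #   python train.py model_dir=experiments/my_run \
    #       dset=maestro_allyears exp=maestro22k_8s \
    #       network=paper_1912_unet_cqt_oct_attention_adaLN_2 \
    #       tester=inpainting_tester logging=huge_model_logging \
    #       exp.batch=1 exp.resume=False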
76 | #trainer=Trainer(args=args, dset=dset, network=network, optimizer=optimizer, diff_params=diff_params, tester=tester, device=device) 77 | trainer.training_loop() 78 | 79 | @hydra.main(config_path="conf", config_name="conf") 80 | def main(args): 81 | _main(args) 82 | 83 | if __name__ == "__main__": 84 | main() 85 | 86 | #---------------------------------------------------------------------------- 87 | -------------------------------------------------------------------------------- /training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --time=2-23:59:59 3 | ##SBATCH --time=03:59:59 4 | #SBATCH --mem=30G 5 | #SBATCH --cpus-per-task=4 6 | #SBATCH --job-name=filter_score_model 7 | #SBATCH --gres=gpu:a100:1 8 | ##SBATCH --gres=gpu:1 --constraint=volta 9 | #SBATCH --output=/scratch/work/%u/projects/ddpm/diffusion_autumn_2022/A-diffusion/experiments/%a/train_%j.out 10 | 11 | #SBATCH --array=[50] 12 | 13 | module load anaconda 14 | source activate /scratch/work/molinee2/conda_envs/cqtdiff 15 | module load gcc/8.4.0 16 | export TORCH_USE_RTLD_GLOBAL=YES 17 | #export HYDRA_FULL_ERROR=1 18 | #export CUDA_LAUNCH_BLOCKING=1 19 | n=$SLURM_ARRAY_TASK_ID 20 | 21 | n=cqtdiff+_MAESTRO #original CQTDiff (with fast implementation) (22kHz) 22 | 23 | if [[ $n -eq CQTdiff+_MAESTRO ]] 24 | then 25 | ckpt="/scratch/work/molinee2/projects/ddpm/audio-inpainting-diffusion/experiments/cqtdiff+_MAESTRO/22k_8s-750000.pt" 26 | exp=maestro22k_8s 27 | network=paper_1912_unet_cqt_oct_attention_adaLN_2 28 | dset=maestro_allyears 29 | tester=inpainting_tester 30 | CQT=True 31 | fi 32 | 33 | PATH_EXPERIMENT=experiments/$n 34 | mkdir $PATH_EXPERIMENT 35 | 36 | #python train_w_cqt.py path_experiment="$PATH_EXPERIMENT" $iteration 37 | python train.py model_dir="$PATH_EXPERIMENT" \ 38 | dset=$dset \ 39 | exp=$exp \ 40 | network=$network \ 41 | tester=$tester \ 42 | tester.checkpoint=$ckpt \ 43 | tester.filter_out_cqt_DC_Nyq=$CQT \ 44 | logging=huge_model_logging \ 45 | exp.batch=1 \ 46 | exp.resume=False 47 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | """Main training loop.""" 2 | 3 | import os 4 | import time 5 | import copy 6 | 7 | import numpy as np 8 | import torch 9 | import torchaudio 10 | 11 | from utils.torch_utils import training_stats 12 | from utils.torch_utils import misc 13 | 14 | import librosa 15 | from glob import glob 16 | import re 17 | 18 | import wandb 19 | 20 | import utils.logging as utils_logging 21 | from torch.profiler import tensorboard_trace_handler 22 | 23 | import omegaconf 24 | 25 | import utils.training_utils as t_utils 26 | 27 | #---------------------------------------------------------------------------- 28 | class Trainer(): 29 | def __init__(self, args, dset, network, optimizer, diff_params, tester=None, device='cpu'): 30 | self.args=args 31 | self.dset=dset 32 | self.network=network 33 | self.optimizer=optimizer 34 | self.diff_params=diff_params 35 | self.device=device 36 | 37 | #testing means generating demos by sampling from the model 38 | self.tester=tester 39 | if self.tester is None or not(self.args.tester.do_test): 40 | self.do_test=False 41 | else: 42 | self.do_test=True 43 | 44 | torch.manual_seed(np.random.randint(1 << 31)) 45 | torch.backends.cudnn.enabled = True 46 | torch.backends.cudnn.benchmark = True 47 | 48 | torch.backends.cudnn.allow_tf32 = True 49 | torch.backends.cuda.matmul.allow_tf32 = True 50 | torch.backends.cudnn.deterministic = False 51 | 52 | self.total_params = sum(p.numel() for p in self.network.parameters() if p.requires_grad) 53 | print("total_params: ",self.total_params/1e6, "M") 54 | 55 | self.ema = copy.deepcopy(self.network).eval().requires_grad_(False) 56 | 57 | #resume from checkpoint 58 | self.latest_checkpoint=None 59 | resuming=False 60 | if self.args.exp.resume: 61 | if self.args.exp.resume_checkpoint != "None": 62 | resuming =self.resume_from_checkpoint(checkpoint_path=self.args.exp.resume_checkpoint) 63 | else: 64 | resuming =self.resume_from_checkpoint() 65 | if not resuming: 66 | print("Could not resume from checkpoint") 67 | print("training from scratch") 68 | else: 69 | print("Resuming from iteration {}".format(self.it)) 70 | 71 | if not resuming: 72 | self.it=0 73 | self.latest_checkpoint=None 74 | 75 | if self.args.logging.print_model_summary: 76 | #if dist.get_rank() == 0: 77 | with torch.no_grad(): 78 | audio=torch.zeros([args.exp.batch,args.exp.audio_len], device=device) 79 | sigma = torch.ones([args.exp.batch], device=device).unsqueeze(-1) 80 | misc.print_module_summary(self.network, [audio, sigma ], max_nesting=2) 81 | 82 | 83 | if self.args.logging.log: 84 | self.setup_wandb() 85 | if self.do_test: 86 | self.tester.setup_wandb_run(self.wandb_run) 87 | self.setup_logging_variables() 88 | 89 | self.profile=False 90 | if self.args.logging.profiling.enabled: 91 | try: 92 | print("Profiling is being enabled") 93 | wait=self.args.logging.profiling.wait 94 | warmup=self.args.logging.profiling.warmup 95 | active=self.args.logging.profiling.active 96 | repeat=self.args.logging.profiling.repeat 97 | 98 | schedule = torch.profiler.schedule( 99 | wait=wait, warmup=warmup, active=active, repeat=repeat) 100 | self.profiler = torch.profiler.profile( 101 | schedule=schedule, on_trace_ready=tensorboard_trace_handler("wandb/latest-run/tbprofile"), profile_memory=True, with_stack=False) 102 | 
self.profile=True
103 | self.profile_total_steps = (wait + warmup + active) * (1 + repeat)
104 | except Exception as e:
105 |
106 | print("Could not set up profiler")
107 | print(e)
108 | self.profile=False
109 |
110 |
111 | def setup_wandb(self):
112 | """
113 | Configure wandb, open a new run and log the configuration.
114 | """
115 | config=omegaconf.OmegaConf.to_container(
116 | self.args, resolve=True, throw_on_missing=True
117 | )
118 | config["total_params"]=self.total_params
119 | self.wandb_run=wandb.init(project=self.args.exp.wandb.project, entity=self.args.exp.wandb.entity, config=config)
120 | wandb.watch(self.network, log="all", log_freq=self.args.logging.heavy_log_interval) #wandb.watch logs the gradients and parameters of the model to wandb
121 | self.wandb_run.name=os.path.basename(self.args.model_dir)+"_"+self.args.exp.exp_name+"_"+self.wandb_run.id #append the experiment name and the run id to the run name; very important for telling runs apart
122 |
123 |
124 | def setup_logging_variables(self):
125 |
126 | self.sigma_bins = np.logspace(np.log10(self.args.diff_params.sigma_min), np.log10(self.args.diff_params.sigma_max), num=self.args.logging.num_sigma_bins, base=10)
127 |
128 | #logarithmically spaced bins for the frequency logging
129 | self.freq_bins=np.logspace(np.log2(self.args.logging.cqt.fmin), np.log2(self.args.logging.cqt.fmin*2**(self.args.logging.cqt.num_octs)), num=self.args.logging.cqt.num_octs*self.args.logging.cqt.bins_per_oct, base=2)
130 | self.freq_bins=self.freq_bins.astype(int)
131 |
132 |
133 |
134 | def load_state_dict(self, state_dict):
135 | #print(state_dict)
136 | return t_utils.load_state_dict(state_dict, network=self.network, ema=self.ema, optimizer=self.optimizer)
137 |
138 |
139 | def resume_from_checkpoint(self, checkpoint_path=None, checkpoint_id=None):
140 | # Resume training from the latest checkpoint available in the output directory
141 | if checkpoint_path is not None:
142 | try:
143 | checkpoint=torch.load(checkpoint_path, map_location=self.device)
144 | print(checkpoint.keys())
145 | #if possible, retrieve the iteration number from the checkpoint
146 | try:
147 | self.it = checkpoint['it']
148 | except:
149 | self.it=157007 #arbitrary large number meaning that a checkpoint was loaded
150 | return self.load_state_dict(checkpoint)
151 | except Exception as e:
152 | print("Could not resume from checkpoint")
153 | print(e)
154 | print("training from scratch")
155 | self.it=0
156 | return False
157 | else:
158 | try:
159 | print("trying to load a project checkpoint")
160 | print("checkpoint_id", checkpoint_id)
161 | if checkpoint_id is None:
162 | # find latest checkpoint_id
163 | save_basename = f"{self.args.exp.exp_name}-*.pt"
164 | save_name = f"{self.args.model_dir}/{save_basename}"
165 | print(save_name)
166 | list_weights = glob(save_name)
167 | id_regex = re.compile(rf"{self.args.exp.exp_name}-(\d*)\.pt")
168 | list_ids = [int(id_regex.search(weight_path).groups()[0])
169 | for weight_path in list_weights]
170 | checkpoint_id = max(list_ids)
171 | print(checkpoint_id)
172 |
173 | checkpoint = torch.load(
174 | f"{self.args.model_dir}/{self.args.exp.exp_name}-{checkpoint_id}.pt", map_location=self.device)
175 | #if possible, retrieve the iteration number from the checkpoint
176 | try:
177 | self.it = checkpoint['it']
178 | except:
179 | self.it=159000 #arbitrary large number meaning that a checkpoint was loaded
180 | self.load_state_dict(checkpoint)
181 | return True
182 | except Exception as e:
183 | print(e)
184 | return False
185 |
186 |
187 | def state_dict(self):
188 | return {
189 | 'it': self.it,
190 | 'network': self.network.state_dict(),
191 | 'optimizer': self.optimizer.state_dict(),
192 | 'ema': self.ema.state_dict(),
193 | 'args': self.args,
194 | }
195 |
196 | def save_checkpoint(self):
197 | save_basename = f"{self.args.exp.exp_name}-{self.it}.pt"
198 | save_name = f"{self.args.model_dir}/{save_basename}"
199 | torch.save(self.state_dict(), save_name)
200 | print("saving",save_name)
201 | if self.args.logging.remove_last_checkpoint:
202 | try:
203 | os.remove(self.latest_checkpoint)
204 | print("removed last checkpoint", self.latest_checkpoint)
205 | except:
206 | print("could not remove last checkpoint", self.latest_checkpoint)
207 | self.latest_checkpoint=save_name
208 |
209 |
210 | def process_loss_for_logging(self, error: torch.Tensor, sigma: torch.Tensor):
211 | """
212 | Process the loss for logging: group the losses by the values of sigma and report them using training_stats.
213 | args:
214 | error: the error tensor with shape [batch, audio_len]
215 | sigma: the sigma tensor with shape [batch]
216 | """
217 | #sigma values range between self.args.diff_params.sigma_min and self.args.diff_params.sigma_max; we quantize them into num_sigma_bins logarithmically spaced bins between those two extremes
218 | error=torch.nan_to_num(error) #replace any NaNs in the error before logging (nan_to_num returns a new tensor, so the result must be assigned)
219 | error=error.detach().cpu().numpy()
220 |
221 | for i in range(len(self.sigma_bins)):
222 | if i == 0:
223 | mask = sigma <= self.sigma_bins[i]
224 | elif i == len(self.sigma_bins)-1:
225 | mask = (sigma <= self.sigma_bins[i]) & (sigma > self.sigma_bins[i-1])
226 |
227 | else:
228 | mask = (sigma <= self.sigma_bins[i]) & (sigma > self.sigma_bins[i-1])
229 | mask=mask.squeeze(-1).cpu()
230 | if mask.sum() > 0:
231 | #find the index of the first element of the mask
232 | idx = np.where(mask==True)[0][0]
233 |
234 | training_stats.report('error_sigma_'+str(self.sigma_bins[i]),error[idx].mean())
235 |
236 | def get_batch(self):
237 | #load the data batch
238 | if self.args.dset.name == "maestro_allyears":
239 | #this dataset has data sampled at different frequencies, so we need to resample it. The dataset returns a tuple (audio, fs), where fs is the sampling frequency of the given audio sample. Moreover, the size of the audio tensor is [B, dset.load_len], where dset.load_len is an arbitrary number designed to be sufficiently large so that the 48kHz audio samples can be loaded without any problem. We need to resample the audio tensor to the desired sampling frequency, and then crop it to the desired length.
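# A minimal sketch of the resample-and-crop step (hypothetical helper; the
# actual logic lives in t_utils.resample_batch, which additionally has to
# handle a batch that mixes several sampling frequencies):
#
#   def resample_and_crop(audio, fs, target_fs, target_len):
#       if fs != target_fs:
#           audio = torchaudio.functional.resample(audio, orig_freq=fs, new_freq=target_fs)
#       return audio[..., :target_len]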
240 |
241 | audio, fs = next(self.dset)
242 | audio=audio.to(self.device).to(torch.float32)
243 |
244 | return t_utils.resample_batch(audio, fs, self.args.exp.sample_rate, self.args.exp.audio_len)
245 | else:
246 | audio = next(self.dset)
247 | audio=audio.to(self.device).to(torch.float32)
248 | #do resampling if needed
249 | if self.args.exp.resample_factor != 1:
250 | audio=torchaudio.functional.resample(audio, self.args.exp.resample_factor, 1)
251 |
252 | return audio
253 | def train_step(self):
254 | # Train step
255 | it_start_time = time.time()
256 | #self.optimizer.zero_grad(set_to_none=True)
257 | self.optimizer.zero_grad()
258 | st_time=time.time()
259 | for round_idx in range(self.args.exp.num_accumulation_rounds):
260 | #with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
261 | audio=self.get_batch()
262 |
263 | #print(audio.shape, self.args.exp.audio_len)
264 | error, sigma = self.diff_params.loss_fn(self.network, audio)
265 | loss=error.mean()
266 | loss.backward() #TODO: take care of the loss scaling if using mixed precision
267 | #do I want to call this at every round? It will slow down the training. I will try it and see what happens
268 |
269 |
270 |
271 | if self.it <= self.args.exp.lr_rampup_it:
272 | for g in self.optimizer.param_groups:
273 | #learning rate ramp-up
274 | g['lr'] = self.args.exp.lr * min(self.it / max(self.args.exp.lr_rampup_it, 1e-8), 1)
275 |
276 |
277 | if self.args.exp.use_grad_clip:
278 | torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.args.exp.max_grad_norm)
279 |
280 | # Update weights.
281 | self.optimizer.step()
282 |
283 | end_time=time.time()
284 | if self.args.logging.log:
285 | self.process_loss_for_logging(error, sigma)
286 |
287 | it_end_time = time.time()
288 | print("it:", self.it, "time:", end_time-st_time, "total_time:", training_stats.report('it_time',it_end_time-it_start_time), "loss:", training_stats.report('loss', loss.item())) #TODO: take care of the logging
289 |
290 |
291 | def update_ema(self):
292 | """Update exponential moving average of self.network weights."""
293 |
294 | ema_rampup = self.args.exp.ema_rampup #ema_rampup should be set to 10000 in the config file
295 | ema_rate=self.args.exp.ema_rate #ema_rate should be set to 0.9999 in the config file
296 | t = self.it * self.args.exp.batch
297 | with torch.no_grad():
298 | if t < ema_rampup:
299 | s = np.clip(t / ema_rampup, 0.0, ema_rate)
300 | for dst, src in zip(self.ema.parameters(), self.network.parameters()):
301 | dst.copy_(dst * s + src * (1-s))
302 | else:
303 | for dst, src in zip(self.ema.parameters(), self.network.parameters()):
304 | dst.copy_(dst * ema_rate + src * (1-ema_rate))
305 |
306 | def easy_logging(self):
307 | """
308 | Do the simplest logging here. This will be called every 1000 iterations or so.
309 | I will use the training_stats.report function for this, and aim to report the means and stds of the losses in wandb.
310 | """
311 | training_stats.default_collector.update()
312 | #Is it a good idea to log the stds of the losses? I think it is not.
313 | loss_mean=training_stats.default_collector.mean('loss')
314 | self.wandb_run.log({'loss':loss_mean}, step=self.it)
315 | loss_std=training_stats.default_collector.std('loss')
316 | self.wandb_run.log({'loss_std':loss_std}, step=self.it)
317 |
318 | it_time_mean=training_stats.default_collector.mean('it_time')
319 | self.wandb_run.log({'it_time_mean':it_time_mean}, step=self.it)
320 | it_time_std=training_stats.default_collector.std('it_time')
321 | self.wandb_run.log({'it_time_std':it_time_std}, step=self.it)
322 |
323 | #report the error with respect to sigma; the figure below plots mean and std per sigma bin
324 | sigma_means=[]
325 | sigma_stds=[]
326 | for i in range(len(self.sigma_bins)):
327 | a=training_stats.default_collector.mean('error_sigma_'+str(self.sigma_bins[i]))
328 | sigma_means.append(a)
329 | self.wandb_run.log({'error_sigma_'+str(self.sigma_bins[i]):a}, step=self.it)
330 | a=training_stats.default_collector.std('error_sigma_'+str(self.sigma_bins[i]))
331 | sigma_stds.append(a)
332 |
333 |
334 | figure=utils_logging.plot_loss_by_sigma(sigma_means,sigma_stds, self.sigma_bins)
335 | wandb.log({"loss_dependent_on_sigma": figure}, step=self.it, commit=True)
336 |
337 |
338 | def heavy_logging(self):
339 | """
340 | Do the heavy logging here. This will be called every 10000 iterations or so.
341 | """
342 | if self.do_test:
343 |
344 | if self.latest_checkpoint is not None:
345 | self.tester.load_checkpoint(self.latest_checkpoint)
346 |
347 | preds=self.tester.sample_unconditional()
348 | preds=self.tester.test_inpainting()
349 |
350 | def log_audio(self,x, name):
351 | string=name+"_"+self.args.tester.name
352 | audio_path=utils_logging.write_audio_file(x,self.args.exp.sample_rate, string,path=self.args.model_dir)
353 | self.wandb_run.log({"audio_"+str(string): wandb.Audio(audio_path, sample_rate=self.args.exp.sample_rate)},step=self.it)
354 | #log the spectrogram of the audio file to wandb as well
355 | spec_sample=utils_logging.plot_spectrogram_from_raw_audio(x, self.args.logging.stft)
356 | self.wandb_run.log({"spec_"+str(string): spec_sample}, step=self.it)
357 |
358 |
359 | def training_loop(self):
360 |
361 | # Initialize.
362 |
363 | while True:
364 | # Accumulate gradients.
365 |
366 | self.train_step()
367 |
368 | self.update_ema()
369 |
370 | if self.profile and self.args.logging.log:
371 | print(self.profile, self.profile_total_steps, self.it)
372 | if self.it < self.profile_total_steps:
373 | self.profiler.step()
374 | elif self.it == self.profile_total_steps + 1:
375 | #the scheduled profiling steps are done: stop the profiler and flush the trace
376 | self.profiler.stop()
377 | elif self.it > self.profile_total_steps + 1:
378 | self.profile=False
382 |
383 |
384 |
385 | if self.it>0 and self.it%self.args.logging.save_interval==0 and self.args.logging.save_model:
386 | #self.save_snapshot() #are the snapshots necessary? I think they are not.
387 | self.save_checkpoint()
388 |
389 |
390 | if self.it>0 and self.it%self.args.logging.heavy_log_interval==0 and self.args.logging.log:
391 | self.heavy_logging()
392 | #self.conditional_demos()
393 |
394 | if self.it>0 and self.it%self.args.logging.log_interval==0 and self.args.logging.log:
395 | self.easy_logging()
396 |
397 |
398 | # Update state.
399 | self.it += 1
400 |
401 |
402 | #----------------------------------------------------------------------------
403 |
-------------------------------------------------------------------------------- /utils/dnnlib/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path, call_func_by_name 9 | -------------------------------------------------------------------------------- /utils/dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import tempfile 27 | import urllib 28 | import urllib.request 29 | import uuid 30 | 31 | from distutils.util import strtobool 32 | from typing import Any, List, Tuple, Union, Optional 33 | 34 | 35 | # Util classes 36 | # ------------------------------------------------------------------------------------------ 37 | 38 | 39 | class EasyDict(dict): 40 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 41 | 42 | def __getattr__(self, name: str) -> Any: 43 | try: 44 | return self[name] 45 | except KeyError: 46 | raise AttributeError(name) 47 | 48 | def __setattr__(self, name: str, value: Any) -> None: 49 | self[name] = value 50 | 51 | def __delattr__(self, name: str) -> None: 52 | del self[name] 53 | 54 | 55 | class Logger(object): 56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 57 | 58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): 59 | self.file = None 60 | 61 | if file_name is not None: 62 | self.file = open(file_name, file_mode) 63 | 64 | self.should_flush = should_flush 65 | self.stdout = sys.stdout 66 | self.stderr = sys.stderr 67 | 68 | sys.stdout = self 69 | sys.stderr = self 70 | 71 | def __enter__(self) -> "Logger": 72 | return self 73 | 74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 75 | self.close() 76 | 77 | def write(self, text: Union[str, bytes]) -> None: 78 | """Write text to stdout (and a file) and optionally flush.""" 79 | if isinstance(text, bytes): 80 | text = text.decode() 81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 82 | return 83 | 84 | if self.file is not None: 85 | self.file.write(text) 86 | 87 | self.stdout.write(text) 88 | 89 | if self.should_flush: 90 | self.flush() 91 | 92 | def flush(self) -> None: 93 | """Flush written text to both stdout and a file, if open.""" 94 | if self.file is not None: 95 | self.file.flush() 96 | 97 | self.stdout.flush() 98 | 99 | def close(self) -> None: 100 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 101 | self.flush() 102 | 103 | # if using multiple loggers, prevent closing in wrong order 104 | 
if sys.stdout is self: 105 | sys.stdout = self.stdout 106 | if sys.stderr is self: 107 | sys.stderr = self.stderr 108 | 109 | if self.file is not None: 110 | self.file.close() 111 | self.file = None 112 | 113 | 114 | # Cache directories 115 | # ------------------------------------------------------------------------------------------ 116 | 117 | _dnnlib_cache_dir = None 118 | 119 | def set_cache_dir(path: str) -> None: 120 | global _dnnlib_cache_dir 121 | _dnnlib_cache_dir = path 122 | 123 | def make_cache_dir_path(*paths: str) -> str: 124 | if _dnnlib_cache_dir is not None: 125 | return os.path.join(_dnnlib_cache_dir, *paths) 126 | if 'DNNLIB_CACHE_DIR' in os.environ: 127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 128 | if 'HOME' in os.environ: 129 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 130 | if 'USERPROFILE' in os.environ: 131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 133 | 134 | # Small util functions 135 | # ------------------------------------------------------------------------------------------ 136 | 137 | 138 | def format_time(seconds: Union[int, float]) -> str: 139 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 140 | s = int(np.rint(seconds)) 141 | 142 | if s < 60: 143 | return "{0}s".format(s) 144 | elif s < 60 * 60: 145 | return "{0}m {1:02}s".format(s // 60, s % 60) 146 | elif s < 24 * 60 * 60: 147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 148 | else: 149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 150 | 151 | 152 | def format_time_brief(seconds: Union[int, float]) -> str: 153 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 154 | s = int(np.rint(seconds)) 155 | 156 | if s < 60: 157 | return "{0}s".format(s) 158 | elif s < 60 * 60: 159 | return "{0}m {1:02}s".format(s // 60, s % 60) 160 | elif s < 24 * 60 * 60: 161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 162 | else: 163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 164 | 165 | 166 | def ask_yes_no(question: str) -> bool: 167 | """Ask the user the question until the user inputs a valid answer.""" 168 | while True: 169 | try: 170 | print("{0} [y/n]".format(question)) 171 | return strtobool(input().lower()) 172 | except ValueError: 173 | pass 174 | 175 | 176 | def tuple_product(t: Tuple) -> Any: 177 | """Calculate the product of the tuple elements.""" 178 | result = 1 179 | 180 | for v in t: 181 | result *= v 182 | 183 | return result 184 | 185 | 186 | _str_to_ctype = { 187 | "uint8": ctypes.c_ubyte, 188 | "uint16": ctypes.c_uint16, 189 | "uint32": ctypes.c_uint32, 190 | "uint64": ctypes.c_uint64, 191 | "int8": ctypes.c_byte, 192 | "int16": ctypes.c_int16, 193 | "int32": ctypes.c_int32, 194 | "int64": ctypes.c_int64, 195 | "float32": ctypes.c_float, 196 | "float64": ctypes.c_double 197 | } 198 | 199 | 200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 202 | type_str = None 203 | 204 | if isinstance(type_obj, str): 205 | type_str = type_obj 206 | elif hasattr(type_obj, "__name__"): 207 | type_str = type_obj.__name__ 208 | elif hasattr(type_obj, "name"): 209 | type_str = 
type_obj.name 210 | else: 211 | raise RuntimeError("Cannot infer type name from input") 212 | 213 | assert type_str in _str_to_ctype.keys() 214 | 215 | my_dtype = np.dtype(type_str) 216 | my_ctype = _str_to_ctype[type_str] 217 | 218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 219 | 220 | return my_dtype, my_ctype 221 | 222 | 223 | def is_pickleable(obj: Any) -> bool: 224 | try: 225 | with io.BytesIO() as stream: 226 | pickle.dump(obj, stream) 227 | return True 228 | except: 229 | return False 230 | 231 | 232 | # Functionality to import modules/objects by name, and call functions by name 233 | # ------------------------------------------------------------------------------------------ 234 | 235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 236 | """Searches for the underlying module behind the name to some python object. 237 | Returns the module and the object name (original name with module part removed).""" 238 | 239 | # allow convenience shorthands, substitute them by full names 240 | obj_name = re.sub("^np.", "numpy.", obj_name) 241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 242 | 243 | # list alternatives for (module_name, local_obj_name) 244 | parts = obj_name.split(".") 245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 246 | 247 | # try each alternative in turn 248 | for module_name, local_obj_name in name_pairs: 249 | try: 250 | module = importlib.import_module(module_name) # may raise ImportError 251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 252 | return module, local_obj_name 253 | except: 254 | pass 255 | 256 | # maybe some of the modules themselves contain errors? 257 | for module_name, _local_obj_name in name_pairs: 258 | try: 259 | importlib.import_module(module_name) # may raise ImportError 260 | except ImportError: 261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 262 | raise 263 | 264 | # maybe the requested attribute is missing? 
265 | for module_name, local_obj_name in name_pairs: 266 | try: 267 | module = importlib.import_module(module_name) # may raise ImportError 268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 269 | except ImportError: 270 | pass 271 | 272 | # we are out of luck, but we have no idea why 273 | raise ImportError(obj_name) 274 | 275 | 276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 277 | """Traverses the object name and returns the last (rightmost) python object.""" 278 | if obj_name == '': 279 | return module 280 | obj = module 281 | for part in obj_name.split("."): 282 | obj = getattr(obj, part) 283 | return obj 284 | 285 | 286 | def get_obj_by_name(name: str) -> Any: 287 | """Finds the python object with the given name.""" 288 | module, obj_name = get_module_from_obj_name(name) 289 | return get_obj_from_module(module, obj_name) 290 | 291 | 292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 293 | """Finds the python object with the given name and calls it as a function.""" 294 | assert func_name is not None 295 | func_obj = get_obj_by_name(func_name) 296 | assert callable(func_obj) 297 | return func_obj(*args, **kwargs) 298 | 299 | 300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 301 | """Finds the python class with the given name and constructs it with the given arguments.""" 302 | return call_func_by_name(*args, func_name=class_name, **kwargs) 303 | 304 | 305 | def get_module_dir_by_obj_name(obj_name: str) -> str: 306 | """Get the directory path of the module containing the given object name.""" 307 | module, _ = get_module_from_obj_name(obj_name) 308 | return os.path.dirname(inspect.getfile(module)) 309 | 310 | 311 | def is_top_level_function(obj: Any) -> bool: 312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 314 | 315 | 316 | def get_top_level_function_name(obj: Any) -> str: 317 | """Return the fully-qualified name of a top-level function.""" 318 | assert is_top_level_function(obj) 319 | module = obj.__module__ 320 | if module == '__main__': 321 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 322 | return module + "." + obj.__name__ 323 | 324 | 325 | # File system helpers 326 | # ------------------------------------------------------------------------------------------ 327 | 328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 329 | """List all files recursively in a given directory while ignoring given file and directory names. 
330 | Returns list of tuples containing both absolute and relative paths.""" 331 | assert os.path.isdir(dir_path) 332 | base_name = os.path.basename(os.path.normpath(dir_path)) 333 | 334 | if ignores is None: 335 | ignores = [] 336 | 337 | result = [] 338 | 339 | for root, dirs, files in os.walk(dir_path, topdown=True): 340 | for ignore_ in ignores: 341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 342 | 343 | # dirs need to be edited in-place 344 | for d in dirs_to_remove: 345 | dirs.remove(d) 346 | 347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 348 | 349 | absolute_paths = [os.path.join(root, f) for f in files] 350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 351 | 352 | if add_base_to_relative: 353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 354 | 355 | assert len(absolute_paths) == len(relative_paths) 356 | result += zip(absolute_paths, relative_paths) 357 | 358 | return result 359 | 360 | 361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 362 | """Takes in a list of tuples of (src, dst) paths and copies files. 363 | Will create all necessary directories.""" 364 | for file in files: 365 | target_dir_name = os.path.dirname(file[1]) 366 | 367 | # will create all intermediate-level directories 368 | if not os.path.exists(target_dir_name): 369 | os.makedirs(target_dir_name) 370 | 371 | shutil.copyfile(file[0], file[1]) 372 | 373 | 374 | # URL helpers 375 | # ------------------------------------------------------------------------------------------ 376 | 377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 378 | """Determine whether the given object is a valid URL string.""" 379 | if not isinstance(obj, str) or not "://" in obj: 380 | return False 381 | if allow_file_urls and obj.startswith('file://'): 382 | return True 383 | try: 384 | res = requests.compat.urlparse(obj) 385 | if not res.scheme or not res.netloc or not "." in res.netloc: 386 | return False 387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 388 | if not res.scheme or not res.netloc or not "." in res.netloc: 389 | return False 390 | except: 391 | return False 392 | return True 393 | 394 | 395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 396 | """Download the given URL and return a binary-mode file object to access the data.""" 397 | assert num_attempts >= 1 398 | assert not (return_filename and (not cache)) 399 | 400 | # Doesn't look like an URL scheme so interpret it as a local filename. 401 | if not re.match('^[a-z]+://', url): 402 | return url if return_filename else open(url, "rb") 403 | 404 | # Handle file URLs. This code handles unusual file:// patterns that 405 | # arise on Windows: 406 | # 407 | # file:///c:/foo.txt 408 | # 409 | # which would translate to a local '/c:/foo.txt' filename that's 410 | # invalid. Drop the forward slash for such pathnames. 411 | # 412 | # If you touch this code path, you should test it on both Linux and 413 | # Windows. 414 | # 415 | # Some internet resources suggest using urllib.request.url2pathname() but 416 | # but that converts forward slashes to backslashes and this causes 417 | # its own set of problems. 
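# For example, 'file:///c:/foo.txt' parses to the local path '/c:/foo.txt';
# the check below strips the leading slash to recover a usable 'c:/foo.txt'.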
418 | if url.startswith('file://'): 419 | filename = urllib.parse.urlparse(url).path 420 | if re.match(r'^/[a-zA-Z]:', filename): 421 | filename = filename[1:] 422 | return filename if return_filename else open(filename, "rb") 423 | 424 | assert is_url(url) 425 | 426 | # Lookup from cache. 427 | if cache_dir is None: 428 | cache_dir = make_cache_dir_path('downloads') 429 | 430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 431 | if cache: 432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 433 | if len(cache_files) == 1: 434 | filename = cache_files[0] 435 | return filename if return_filename else open(filename, "rb") 436 | 437 | # Download. 438 | url_name = None 439 | url_data = None 440 | with requests.Session() as session: 441 | if verbose: 442 | print("Downloading %s ..." % url, end="", flush=True) 443 | for attempts_left in reversed(range(num_attempts)): 444 | try: 445 | with session.get(url) as res: 446 | res.raise_for_status() 447 | if len(res.content) == 0: 448 | raise IOError("No data received") 449 | 450 | if len(res.content) < 8192: 451 | content_str = res.content.decode("utf-8") 452 | if "download_warning" in res.headers.get("Set-Cookie", ""): 453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 454 | if len(links) == 1: 455 | url = requests.compat.urljoin(url, links[0]) 456 | raise IOError("Google Drive virus checker nag") 457 | if "Google Drive - Quota exceeded" in content_str: 458 | raise IOError("Google Drive download quota exceeded -- please try again later") 459 | 460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 461 | url_name = match[1] if match else url 462 | url_data = res.content 463 | if verbose: 464 | print(" done") 465 | break 466 | except KeyboardInterrupt: 467 | raise 468 | except: 469 | if not attempts_left: 470 | if verbose: 471 | print(" failed") 472 | raise 473 | if verbose: 474 | print(".", end="", flush=True) 475 | 476 | # Save to cache. 477 | if cache: 478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 479 | safe_name = safe_name[:min(len(safe_name), 128)] 480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 482 | os.makedirs(cache_dir, exist_ok=True) 483 | with open(temp_file, "wb") as f: 484 | f.write(url_data) 485 | os.replace(temp_file, cache_file) # atomic 486 | if return_filename: 487 | return cache_file 488 | 489 | # Return data as file object. 490 | assert not return_filename 491 | return io.BytesIO(url_data) 492 | -------------------------------------------------------------------------------- /utils/setup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import utils.dnnlib as dnnlib 4 | 5 | def worker_init_fn(worker_id): 6 | st=np.random.get_state()[2] 7 | np.random.seed( st+ worker_id) 8 | 9 | 10 | def setup_dataset(args): 11 | try: 12 | overfit=args.dset.overfit 13 | except: 14 | overfit=False 15 | 16 | #the dataloader loads audio at the original sampling rate, then in the training loop we resample it to the target sampling rate. 
The mismatch between sampling rates is indicated by the resample_factor
17 | #if resample_factor=1, then the audio is not resampled, and everything is normal
18 | #try:
19 | if args.dset.name=="maestro_allyears" or args.dset.name=="maestro_fs":
20 | dataset_obj=dnnlib.call_func_by_name(func_name=args.dset.callable, dset_args=args.dset, overfit=overfit)
21 | else:
22 | dataset_obj=dnnlib.call_func_by_name(func_name=args.dset.callable, dset_args=args.dset, fs=args.exp.sample_rate*args.exp.resample_factor, seg_len=args.exp.audio_len*args.exp.resample_factor, overfit=overfit)
23 |
24 |
25 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, batch_size=args.exp.batch, num_workers=args.exp.num_workers, pin_memory=True, worker_init_fn=worker_init_fn))
26 |
27 | return dataset_iterator
28 |
29 | def setup_dataset_test(args):
30 |
31 | if args.dset.name=="maestro_allyears" or args.dset.name=="maestro_fs":
32 | dataset_obj=dnnlib.call_func_by_name(func_name=args.dset.test.callable, dset_args=args.dset, num_samples=args.dset.test.num_samples)
33 | else:
34 | dataset_obj=dnnlib.call_func_by_name(func_name=args.dset.test.callable, dset_args=args.dset, fs=args.exp.sample_rate*args.exp.resample_factor,seg_len=args.exp.audio_len*args.exp.resample_factor, num_samples=args.dset.test.num_samples)
35 |
36 | dataset = torch.utils.data.DataLoader(dataset=dataset_obj, batch_size=args.dset.test.batch_size, num_workers=args.exp.num_workers, pin_memory=True, worker_init_fn=worker_init_fn)
37 |
38 | return dataset
39 |
40 | def setup_diff_parameters(args):
41 |
42 | diff_params_obj=dnnlib.call_func_by_name(func_name=args.diff_params.callable, args=args)
43 |
44 | return diff_params_obj
45 |
46 | def setup_network(args, device, operator=False):
47 |
48 | try:
49 | network_obj=dnnlib.call_func_by_name(func_name=args.network.callable, args=args, device=device)
50 | except Exception as e:
51 | print(e)
52 | network_obj=dnnlib.call_func_by_name(func_name=args.network.callable, args=args.network, device=device)
53 | return network_obj.to(device)
54 |
55 | def setup_optimizer(args, network):
56 | # set up the optimizer for training
57 | optimizer = torch.optim.Adam(network.parameters(), lr=args.exp.lr, betas=(args.exp.optimizer.beta1, args.exp.optimizer.beta2), eps=args.exp.optimizer.eps)
58 | return optimizer
59 |
60 | def setup_tester(args, network=None, diff_params=None, test_set=None, device="cpu"):
61 | assert network is not None
62 | assert diff_params is not None
63 | if args.tester.do_test:
64 | # set up the sampler for making demos during training
65 | sampler = dnnlib.call_func_by_name(func_name=args.tester.callable, args=args, network=network, test_set=test_set, diff_params=diff_params, device=device)
66 | return sampler
67 | else:
68 | return None
69 |
70 | def setup_trainer(args, dset=None, network=None, optimizer=None, diff_params=None, tester=None, device="cpu"):
71 | assert network is not None
72 | assert diff_params is not None
73 | assert optimizer is not None
74 | assert tester is not None
75 | trainer = dnnlib.call_func_by_name(func_name=args.exp.trainer_callable, args=args, dset=dset, network=network, optimizer=optimizer, diff_params=diff_params, tester=tester, device=device)
76 | return trainer
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
-------------------------------------------------------------------------------- /utils/torch_utils/__init__.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /utils/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /utils/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import numpy as np 11 | import torch 12 | import warnings 13 | import utils.dnnlib as dnnlib 14 | 15 | #---------------------------------------------------------------------------- 16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 17 | # same constant is used multiple times. 18 | 19 | _constant_cache = dict() 20 | 21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 22 | value = np.asarray(value) 23 | if shape is not None: 24 | shape = tuple(shape) 25 | if dtype is None: 26 | dtype = torch.get_default_dtype() 27 | if device is None: 28 | device = torch.device('cpu') 29 | if memory_format is None: 30 | memory_format = torch.contiguous_format 31 | 32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 33 | tensor = _constant_cache.get(key, None) 34 | if tensor is None: 35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 36 | if shape is not None: 37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 38 | tensor = tensor.contiguous(memory_format=memory_format) 39 | _constant_cache[key] = tensor 40 | return tensor 41 | 42 | #---------------------------------------------------------------------------- 43 | # Replace NaN/Inf with specified numerical values. 44 | 45 | try: 46 | nan_to_num = torch.nan_to_num # 1.8.0a0 47 | except AttributeError: 48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 49 | assert isinstance(input, torch.Tensor) 50 | if posinf is None: 51 | posinf = torch.finfo(input.dtype).max 52 | if neginf is None: 53 | neginf = torch.finfo(input.dtype).min 54 | assert nan == 0 55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Symbolic assert. 59 | 60 | try: 61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 62 | except AttributeError: 63 | symbolic_assert = torch.Assert # 1.7.0 64 | 65 | #---------------------------------------------------------------------------- 66 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 68 | 69 | @contextlib.contextmanager 70 | def suppress_tracer_warnings(): 71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 72 | warnings.filters.insert(0, flt) 73 | yield 74 | warnings.filters.remove(flt) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Assert that the shape of a tensor matches the given list of integers. 78 | # None indicates that the size of a dimension is allowed to vary. 79 | # Performs symbolic assertion when used in torch.jit.trace(). 
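# Usage sketch (hypothetical tensor): assert_shape(x, [None, 2, 8, 8])
# accepts any size in the first dimension but raises an AssertionError if
# the number of dimensions or any of the remaining sizes differ.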
80 | 81 | def assert_shape(tensor, ref_shape): 82 | if tensor.ndim != len(ref_shape): 83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 85 | if ref_size is None: 86 | pass 87 | elif isinstance(ref_size, torch.Tensor): 88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 90 | elif isinstance(size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 93 | elif size != ref_size: 94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 95 | 96 | #---------------------------------------------------------------------------- 97 | # Function decorator that calls torch.autograd.profiler.record_function(). 98 | 99 | def profiled_function(fn): 100 | def decorator(*args, **kwargs): 101 | with torch.autograd.profiler.record_function(fn.__name__): 102 | return fn(*args, **kwargs) 103 | decorator.__name__ = fn.__name__ 104 | return decorator 105 | 106 | #---------------------------------------------------------------------------- 107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 108 | # indefinitely, shuffling items as it goes. 109 | 110 | class InfiniteSampler(torch.utils.data.Sampler): 111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 112 | assert len(dataset) > 0 113 | assert num_replicas > 0 114 | assert 0 <= rank < num_replicas 115 | assert 0 <= window_size <= 1 116 | super().__init__(dataset) 117 | self.dataset = dataset 118 | self.rank = rank 119 | self.num_replicas = num_replicas 120 | self.shuffle = shuffle 121 | self.seed = seed 122 | self.window_size = window_size 123 | 124 | def __iter__(self): 125 | order = np.arange(len(self.dataset)) 126 | rnd = None 127 | window = 0 128 | if self.shuffle: 129 | rnd = np.random.RandomState(self.seed) 130 | rnd.shuffle(order) 131 | window = int(np.rint(order.size * self.window_size)) 132 | 133 | idx = 0 134 | while True: 135 | i = idx % order.size 136 | if idx % self.num_replicas == self.rank: 137 | yield order[i] 138 | if window >= 2: 139 | j = (i - rnd.randint(window)) % order.size 140 | order[i], order[j] = order[j], order[i] 141 | idx += 1 142 | 143 | #---------------------------------------------------------------------------- 144 | # Utilities for operating with torch.nn.Module parameters and buffers. 
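# Usage sketch (hypothetical modules): the helpers below make it easy to
# mirror weights between two networks of the same architecture, e.g.
#   copy_params_and_buffers(src_module=network, dst_module=ema_network, require_all=True)
# copies every parameter and buffer of `network` into `ema_network` in place.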
145 | 146 | def params_and_buffers(module): 147 | assert isinstance(module, torch.nn.Module) 148 | return list(module.parameters()) + list(module.buffers()) 149 | 150 | def named_params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.named_parameters()) + list(module.named_buffers()) 153 | 154 | @torch.no_grad() 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name]) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 
221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /utils/torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import utils.dnnlib as dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet? 
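# Typical usage pattern (mirroring how training/trainer.py uses this module):
#   training_stats.report('loss', loss.item())          # cheap, called every step
#   training_stats.default_collector.update()           # aggregate periodically
#   loss_mean = training_stats.default_collector.mean('loss')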
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. 58 | 59 | This function is expected to be extremely cheap and can be safely 60 | called from anywhere in the training loop, loss function, or inside a 61 | `torch.nn.Module`. 62 | 63 | Warning: The current implementation expects the set of unique names to 64 | be consistent across processes. Please make sure that `report()` is 65 | called at least once for each unique name by each process, and in the 66 | same order. If a given process has no scalars to broadcast, it can do 67 | `report(name, [])` (empty list). 68 | 69 | Args: 70 | name: Arbitrary string specifying the name of the statistic. 71 | Averages are accumulated separately for each unique name. 72 | value: Arbitrary set of scalars. Can be a list, tuple, 73 | NumPy array, PyTorch tensor, or Python scalar. 74 | 75 | Returns: 76 | The same `value` that was passed in. 77 | """ 78 | if name not in _counters: 79 | _counters[name] = dict() 80 | 81 | elems = torch.as_tensor(value) 82 | if elems.numel() == 0: 83 | return value 84 | 85 | elems = elems.detach().flatten().to(_reduce_dtype) 86 | moments = torch.stack([ 87 | torch.ones_like(elems).sum(), 88 | elems.sum(), 89 | elems.square().sum(), 90 | ]) 91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 92 | moments = moments.to(_counter_dtype) 93 | 94 | device = moments.device 95 | if device not in _counters[name]: 96 | _counters[name][device] = torch.zeros_like(moments) 97 | _counters[name][device].add_(moments) 98 | return value 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | def report0(name, value): 103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 104 | but ignores any scalars provided by the other processes. 105 | See `report()` for further details. 106 | """ 107 | report(name, value if _rank == 0 else []) 108 | return value 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | class Collector: 113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 114 | computes their long-term averages (mean and standard deviation) over 115 | user-defined periods of time. 
116 | 117 | The averages are first collected into internal counters that are not 118 | directly visible to the user. They are then copied to the user-visible 119 | state as a result of calling `update()` and can then be queried using 120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 121 | internal counters for the next round, so that the user-visible state 122 | effectively reflects averages collected between the last two calls to 123 | `update()`. 124 | 125 | Args: 126 | regex: Regular expression defining which statistics to 127 | collect. The default is to collect everything. 128 | keep_previous: Whether to retain the previous averages if no 129 | scalars were collected on a given round 130 | (default: True). 131 | """ 132 | def __init__(self, regex='.*', keep_previous=True): 133 | self._regex = re.compile(regex) 134 | self._keep_previous = keep_previous 135 | self._cumulative = dict() 136 | self._moments = dict() 137 | self.update() 138 | self._moments.clear() 139 | 140 | def names(self): 141 | r"""Returns the names of all statistics broadcasted so far that 142 | match the regular expression specified at construction time. 143 | """ 144 | return [name for name in _counters if self._regex.fullmatch(name)] 145 | 146 | def update(self): 147 | r"""Copies current values of the internal counters to the 148 | user-visible state and resets them for the next round. 149 | 150 | If `keep_previous=True` was specified at construction time, the 151 | operation is skipped for statistics that have received no scalars 152 | since the last update, retaining their previous averages. 153 | 154 | This method performs a number of GPU-to-CPU transfers and one 155 | `torch.distributed.all_reduce()`. It is intended to be called 156 | periodically in the main training loop, typically once every 157 | N training steps. 158 | """ 159 | if not self._keep_previous: 160 | self._moments.clear() 161 | for name, cumulative in _sync(self.names()): 162 | if name not in self._cumulative: 163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 164 | delta = cumulative - self._cumulative[name] 165 | self._cumulative[name].copy_(cumulative) 166 | if float(delta[0]) != 0: 167 | self._moments[name] = delta 168 | 169 | def _get_delta(self, name): 170 | r"""Returns the raw moments that were accumulated for the given 171 | statistic between the last two calls to `update()`, or zero if 172 | no scalars were collected. 173 | """ 174 | assert self._regex.fullmatch(name) 175 | if name not in self._moments: 176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 177 | return self._moments[name] 178 | 179 | def num(self, name): 180 | r"""Returns the number of scalars that were accumulated for the given 181 | statistic between the last two calls to `update()`, or zero if 182 | no scalars were collected. 183 | """ 184 | delta = self._get_delta(name) 185 | return int(delta[0]) 186 | 187 | def mean(self, name): 188 | r"""Returns the mean of the scalars that were accumulated for the 189 | given statistic between the last two calls to `update()`, or NaN if 190 | no scalars were collected. 191 | """ 192 | delta = self._get_delta(name) 193 | if int(delta[0]) == 0: 194 | return float('nan') 195 | return float(delta[1] / delta[0]) 196 | 197 | def std(self, name): 198 | r"""Returns the standard deviation of the scalars that were 199 | accumulated for the given statistic between the last two calls to 200 | `update()`, or NaN if no scalars were collected. 
201 | """ 202 | delta = self._get_delta(name) 203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 204 | return float('nan') 205 | if int(delta[0]) == 1: 206 | return float(0) 207 | mean = float(delta[1] / delta[0]) 208 | raw_var = float(delta[2] / delta[0]) 209 | return np.sqrt(max(raw_var - np.square(mean), 0)) 210 | 211 | def as_dict(self): 212 | r"""Returns the averages accumulated between the last two calls to 213 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 214 | 215 | dnnlib.EasyDict( 216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 217 | ... 218 | ) 219 | """ 220 | stats = dnnlib.EasyDict() 221 | for name in self.names(): 222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 223 | return stats 224 | 225 | def __getitem__(self, name): 226 | r"""Convenience getter. 227 | `collector[name]` is a synonym for `collector.mean(name)`. 228 | """ 229 | return self.mean(name) 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def _sync(names): 234 | r"""Synchronize the global cumulative counters across devices and 235 | processes. Called internally by `Collector.update()`. 236 | """ 237 | if len(names) == 0: 238 | return [] 239 | global _sync_called 240 | _sync_called = True 241 | 242 | # Collect deltas within current rank. 243 | deltas = [] 244 | device = _sync_device if _sync_device is not None else torch.device('cpu') 245 | for name in names: 246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 247 | for counter in _counters[name].values(): 248 | delta.add_(counter.to(device)) 249 | counter.copy_(torch.zeros_like(counter)) 250 | deltas.append(delta) 251 | deltas = torch.stack(deltas) 252 | 253 | # Sum deltas across ranks. 254 | if _sync_device is not None: 255 | torch.distributed.all_reduce(deltas) 256 | 257 | # Update cumulative values. 258 | deltas = deltas.cpu() 259 | for idx, name in enumerate(names): 260 | if name not in _cumulative: 261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 262 | _cumulative[name].add_(deltas[idx]) 263 | 264 | # Return name-value pairs. 265 | return [(name, _cumulative[name]) for name in names] 266 | 267 | #---------------------------------------------------------------------------- 268 | # Convenience. 269 | 270 | default_collector = Collector() 271 | 272 | #---------------------------------------------------------------------------- 273 | -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torchaudio 4 | import numpy as np 5 | import scipy.signal 6 | class EMAWarmup: 7 | """Implements an EMA warmup using an inverse decay schedule. 8 | If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are 9 | good values for models you plan to train for a million or more steps (reaches decay 10 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models 11 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 12 | 215.4k steps). 13 | Args: 14 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 15 | power (float): Exponential factor of EMA warmup. Default: 1. 16 | min_value (float): The minimum EMA decay rate. Default: 0. 17 | max_value (float): The maximum EMA decay rate. Default: 1. 
--------------------------------------------------------------------------------
/utils/training_utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torchaudio
4 | import numpy as np
5 | import scipy.signal
6 | class EMAWarmup:
7 |     """Implements an EMA warmup using an inverse decay schedule.
8 |     If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
9 |     good values for models you plan to train for a million or more steps (reaches decay
10 |     factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
11 |     you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
12 |     215.4k steps).
13 |     Args:
14 |         inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
15 |         power (float): Exponential factor of EMA warmup. Default: 1.
16 |         min_value (float): The minimum EMA decay rate. Default: 0.
17 |         max_value (float): The maximum EMA decay rate. Default: 1.
18 |         start_at (int): The epoch to start averaging at. Default: 0.
19 |         last_epoch (int): The index of the last epoch. Default: 0.
20 |     """
21 | 
22 |     def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
23 |                  last_epoch=0):
24 |         self.inv_gamma = inv_gamma
25 |         self.power = power
26 |         self.min_value = min_value
27 |         self.max_value = max_value
28 |         self.start_at = start_at
29 |         self.last_epoch = last_epoch
30 | 
31 |     def state_dict(self):
32 |         """Returns the state of the class as a :class:`dict`."""
33 |         return dict(self.__dict__.items())
34 | 
35 |     def load_state_dict(self, state_dict):
36 |         """Loads the class's state.
37 |         Args:
38 |             state_dict (dict): EMA warmup state. Should be an object returned
39 |                 from a call to :meth:`state_dict`.
40 |         """
41 |         self.__dict__.update(state_dict)
42 | 
43 |     def get_value(self):
44 |         """Gets the current EMA decay rate."""
45 |         epoch = max(0, self.last_epoch - self.start_at)
46 |         value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
47 |         return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
48 | 
49 |     def step(self):
50 |         """Updates the step count."""
51 |         self.last_epoch += 1
52 | 
53 | 
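`EMAWarmup` only schedules the decay rate; the caller applies the EMA update itself. A minimal sketch of the intended pattern (an editorial addition, assuming the class above is importable):

    import copy
    import torch

    model = torch.nn.Linear(4, 4)
    ema_model = copy.deepcopy(model)
    ema_warmup = EMAWarmup(inv_gamma=1., power=0.75)  # reaches decay 0.999 near 10k steps

    for step in range(100):
        # ... optimizer step on `model` would happen here ...
        decay = ema_warmup.get_value()
        with torch.no_grad():
            for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                p_ema.lerp_(p, 1 - decay)  # p_ema = decay * p_ema + (1 - decay) * p
        ema_warmup.step()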
54 | #from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/perceptual.py
55 | class FIRFilter(torch.nn.Module):
56 |     """FIR pre-emphasis filtering module.
57 |     Args:
58 |         filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
59 |         coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
60 |         ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
61 |         plot (bool): Plot the magnitude response of the filter. Default: False
62 |     Based upon the perceptual loss pre-emphasis filters proposed by
63 |     [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).
64 |     A-weighting filter - "aw"
65 |     First-order highpass - "hp"
66 |     Folded differentiator - "fd"
67 |     Note that the default coefficient value of 0.85 is optimized for
68 |     a sampling rate of 44.1 kHz; consider adjusting this value at different sampling rates.
69 |     """
70 | 
71 |     def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
72 |         """Initialize the FIR pre-emphasis filtering module."""
73 |         super(FIRFilter, self).__init__()
74 |         self.filter_type = filter_type
75 |         self.coef = coef
76 |         self.fs = fs
77 |         self.ntaps = ntaps
78 |         self.plot = plot
79 | 
80 |         if ntaps % 2 == 0:
81 |             raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")
82 | 
83 |         if filter_type == "hp":
84 |             self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
85 |             self.fir.weight.requires_grad = False
86 |             self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
87 |         elif filter_type == "fd":
88 |             self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
89 |             self.fir.weight.requires_grad = False
90 |             self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
91 |         elif filter_type == "aw":
92 |             # Definition of analog A-weighting filter according to IEC/CD 1672.
93 |             f1 = 20.598997
94 |             f2 = 107.65265
95 |             f3 = 737.86223
96 |             f4 = 12194.217
97 |             A1000 = 1.9997
98 | 
99 |             NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
100 |             DENs = np.polymul(
101 |                 [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
102 |                 [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
103 |             )
104 |             DENs = np.polymul(
105 |                 np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
106 |             )
107 | 
108 |             # convert analog filter to digital filter
109 |             b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)
110 | 
111 |             # compute the digital filter frequency response
112 |             w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)
113 | 
114 |             # then fit an ntaps-tap FIR filter with least squares
115 |             taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)
116 | 
117 |             # now implement this digital FIR filter as a Conv1d layer
118 |             self.fir = torch.nn.Conv1d(
119 |                 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
120 |             )
121 |             self.fir.weight.requires_grad = False
122 |             self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)
123 | 
124 |     def forward(self, error):
125 |         """Calculate forward propagation.
126 |         Args:
127 |             error (Tensor): Signal to be filtered (B, #samples).
128 |         Returns:
129 |             Tensor: Filtered signal (B, #samples).
130 |         """
131 |         self.fir.weight.data = self.fir.weight.data.to(error.device)
132 |         error = error.unsqueeze(1)
133 |         error = torch.nn.functional.conv1d(
134 |             # pad by half the actual kernel length: "hp"/"fd" use 3 taps, "aw" uses ntaps
135 |             error, self.fir.weight.data, padding=self.fir.weight.data.shape[-1] // 2
136 |         )
137 |         error = error.squeeze(1)
138 |         return error
139 | 
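A hedged sketch of how `FIRFilter` is typically used as a perceptual pre-emphasis stage before a loss (an editorial addition; the dummy tensors are placeholders, not repository data):

    import torch

    aw = FIRFilter(filter_type="aw", fs=44100, ntaps=101)
    pred = torch.randn(4, 16384)    # (B, #samples) dummy predictions
    target = torch.randn(4, 16384)  # (B, #samples) dummy references
    loss = torch.nn.functional.mse_loss(aw(pred), aw(target))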
140 | def resample_batch(audio, fs, fs_target, length_target):
141 | 
142 |     device = audio.device
143 |     dtype = audio.dtype
144 |     B = audio.shape[0]
145 |     # resample the whole batch at once when possible:
146 |     # check whether all the sampling rates in the batch are the same
147 |     if fs_target == 22050:
148 |         if (fs == 44100).all():
149 |             audio = torchaudio.functional.resample(audio, 2, 1)
150 |             return audio[:, 0:length_target]  # throw away the last samples
151 |         elif (fs == 48000).all():
152 |             # approximate rational resampling: 48000 * 147 / 320 = 22050
153 |             audio = torchaudio.functional.resample(audio, 160 * 2, 147)
154 |             return audio[:, 0:length_target]
155 |         else:
156 |             # the batch mixes examples at 44100 and 48000 Hz, so iterate over the batch
157 |             proc_batch = torch.zeros((B, length_target), device=device)
158 |             for i, (a, f_s) in enumerate(zip(audio, fs)):  # per-example resampling; slower than the batched path
159 |                 if f_s == 44100:
160 |                     # resample by a factor of 2
161 |                     a = torchaudio.functional.resample(a, 2, 1)
162 |                 elif f_s == 48000:
163 |                     a = torchaudio.functional.resample(a, 160 * 2, 147)
164 |                 else:
165 |                     print("WARNING, strange fs", f_s)
166 | 
167 |                 proc_batch[i] = a[0:length_target]
168 |             return proc_batch
169 |     elif fs_target == 44100:
170 |         if (fs == 44100).all():
171 |             return audio[:, 0:length_target]  # throw away the last samples
172 |         elif (fs == 48000).all():
173 |             # approximate rational resampling: 48000 * 147 / 160 = 44100
174 |             audio = torchaudio.functional.resample(audio, 160, 147)
175 |             return audio[:, 0:length_target]
176 |         else:
177 |             # mixed sampling rates again, so iterate over the batch
178 |             proc_batch = torch.zeros((B, length_target), device=device)
179 |             for i, (a, f_s) in enumerate(zip(audio, fs)):
180 |                 if f_s == 44100:
181 |                     # already at the target rate
182 |                     pass
183 |                 elif f_s == 48000:
184 |                     a = torchaudio.functional.resample(a, 160, 147)
185 |                 else:
186 |                     print("WARNING, strange fs", f_s)
187 | 
188 |                 proc_batch[i] = a[0:length_target]
189 |             return proc_batch
190 |     else:
191 |         print("resampling to fs_target", fs_target)
192 |         if (fs == 44100).all():
193 |             audio = torchaudio.functional.resample(audio, 44100, fs_target)
194 |             return audio[:, 0:length_target]  # throw away the last samples
195 |         elif (fs == 48000).all():
196 |             # generic rational resampling from 48 kHz
197 |             audio = torchaudio.functional.resample(audio, 48000, fs_target)
198 |             return audio[:, 0:length_target]
199 |         else:
200 |             # mixed sampling rates again, so iterate over the batch
201 |             proc_batch = torch.zeros((B, length_target), device=device)
202 |             for i, (a, f_s) in enumerate(zip(audio, fs)):
203 |                 if f_s == 44100:
204 |                     a = torchaudio.functional.resample(a, 44100, fs_target)
205 |                 elif f_s == 48000:
206 |                     a = torchaudio.functional.resample(a, 48000, fs_target)
207 |                 else:
208 |                     print("WARNING, strange fs", f_s)
209 | 
210 |                 proc_batch[i] = a[0:length_target]
211 |             return proc_batch
212 | 
213 | 
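As a quick orientation (an editorial addition), `resample_batch` can be exercised on a dummy mixed-rate batch; with mixed sampling rates it falls back to the per-example loop:

    import torch

    audio = torch.randn(3, 96000)             # (B, #samples) dummy batch
    fs = torch.tensor([44100, 48000, 44100])  # per-example sampling rates
    out = resample_batch(audio, fs, fs_target=22050, length_target=44100)
    print(out.shape)  # torch.Size([3, 44100])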
214 | def load_state_dict(state_dict, network=None, ema=None, optimizer=None, log=True):
215 |     '''
216 |     Utility for loading state dicts for different models. This function sequentially tries different strategies.
217 |     args:
218 |         state_dict: the state dict to load
219 |     returns:
220 |         True if the state dict was loaded, False otherwise
221 |     Assuming the operations are done in place, this function will not create a copy of the network and optimizer.
222 |     '''
223 |     if log: print("Loading state dict")
224 |     if log:
225 |         print(state_dict.keys())
226 |     try:
227 |         if log: print("Attempt 1: trying with strict=True")
228 |         if network is not None:
229 |             network.load_state_dict(state_dict['network'])
230 |         if optimizer is not None:
231 |             optimizer.load_state_dict(state_dict['optimizer'])
232 |         if ema is not None:
233 |             ema.load_state_dict(state_dict['ema'])
234 |         return True
235 |     except Exception as e:
236 |         if log:
237 |             print("Could not load state dict")
238 |             print(e)
239 |     try:
240 |         if log: print("Attempt 2: trying with strict=False")
241 |         if network is not None:
242 |             network.load_state_dict(state_dict['network'], strict=False)
243 |         # the optimizer cannot be loaded in this setting
244 |         if ema is not None:
245 |             ema.load_state_dict(state_dict['ema'], strict=False)
246 |         return True
247 |     except Exception as e:
248 |         if log:
249 |             print("Could not load state dict")
250 |             print(e)
251 |             print("training from scratch")
252 |     try:
253 |         if log: print("Attempt 3: trying with strict=False, but making sure that the shapes are fine")
254 |         if ema is not None:
255 |             ema_state_dict = ema.state_dict()
256 |         if network is not None:
257 |             network_state_dict = network.state_dict()
258 |         i = 0
259 |         if network is not None:
260 |             for name, param in state_dict['network'].items():
261 |                 if log: print("checking", name)
262 |                 if name in network_state_dict.keys():
263 |                     if network_state_dict[name].shape == param.shape:
264 |                         network_state_dict[name] = param
265 |                         if log:
266 |                             print("assigning", name)
267 |                         i += 1
268 |             network.load_state_dict(network_state_dict)
269 |         if ema is not None:
270 |             for name, param in state_dict['ema'].items():
271 |                 if log: print("checking", name)
272 |                 if name in ema_state_dict.keys():
273 |                     if ema_state_dict[name].shape == param.shape:
274 |                         ema_state_dict[name] = param
275 |                         if log:
276 |                             print("assigning", name)
277 |                         i += 1
278 |             ema.load_state_dict(ema_state_dict)
279 | 
280 |         if i == 0:
281 |             if log: print("WARNING, no parameters were loaded")
282 |             raise Exception("No parameters were loaded")
283 |         elif i > 0:
284 |             if log: print("loaded", i, "parameters")
285 |             return True
286 |     except Exception as e:
287 |         print(e)
288 |         print("the shape-checked strict=False attempt failed as well")
289 | 
290 |     try:
291 |         if log: print("Attempt 4: assuming the naming is different, with the network and ema stored under 'state_dict'")
292 |         if network is not None:
293 |             network.load_state_dict(state_dict['state_dict'])
294 |         if ema is not None:
295 |             ema.load_state_dict(state_dict['state_dict'])
296 |         return True  # without this return, the function fell through to Attempt 5 even on success
297 |     except Exception as e:
298 |         if log:
299 |             print("Could not load state dict")
300 |             print(e)
301 |             print("training from scratch")
302 |             print("Several attempts failed, but not giving up yet")
303 | 
304 |     try:
305 |         if log: print("Attempt 5: trying to load with different names, now model='model' and ema='ema_weights'")
306 |         if ema is not None:
307 |             dic_ema = {}
308 |             for (key, tensor) in zip(state_dict['model'].keys(), state_dict['ema_weights']):
309 |                 dic_ema[key] = tensor
310 |             ema.load_state_dict(dic_ema)
311 |             return True
312 |     except Exception as e:
313 |         if log:
314 |             print(e)
315 | 
316 |     try:
317 |         if log: print("Attempt 6: if there is something wrong with the names of the ema parameters, try to load them using the names of the parameters in the model")
318 |         if ema is not None:
319 |             dic_ema = {}
320 |             i = 0
321 |             for (key, tensor) in zip(state_dict['model'].keys(), state_dict['model'].values()):
322 |                 if tensor.requires_grad:
323 |                     dic_ema[key] = state_dict['ema_weights'][i]
324 |                     i = i + 1
325 |                 else:
326 |                     dic_ema[key] = tensor
327 |             ema.load_state_dict(dic_ema)
328 |             return True
329 |     except Exception as e:
330 |         if log:
331 |             print(e)
332 | 
333 |     # assign the parameters in state_dict to the network one by one
334 |     print("Attempt 7: trying to load the parameters one by one. This is for the dance diffusion model, looking for parameters starting with 'diffusion.' or 'diffusion_ema.'")
335 |     if ema is not None:
336 |         ema_state_dict = ema.state_dict()
337 |     if network is not None:
338 |         network_state_dict = network.state_dict()  # the original mistakenly called ema.state_dict() here
339 |     i = 0
340 |     if network is not None:
341 |         for name, param in state_dict['state_dict'].items():
342 |             print("checking", name)
343 |             if name.startswith("diffusion."):
344 |                 i += 1
345 |                 name = name.replace("diffusion.", "")
346 |                 if network_state_dict[name].shape == param.shape:
347 |                     network_state_dict[name] = param
348 | 
349 |         network.load_state_dict(network_state_dict, strict=False)
350 | 
351 |     if ema is not None:
352 |         for name, param in state_dict['state_dict'].items():
353 |             if name.startswith("diffusion_ema."):
354 |                 i += 1
355 |                 name = name.replace("diffusion_ema.", "")
356 |                 if ema_state_dict[name].shape == param.shape:
357 |                     if log:
358 |                         print(param.shape, ema.state_dict()[name].shape)
359 |                     ema_state_dict[name] = param
360 | 
361 |         ema.load_state_dict(ema_state_dict, strict=False)
362 | 
363 |     if i == 0:
364 |         print("WARNING, no parameters were loaded")
365 |         raise Exception("No parameters were loaded")
366 |     elif i > 0:
367 |         print("loaded", i, "parameters")
368 |         return True
369 | 
370 |     return False
371 | 
372 | 
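A hedged sketch of calling this fallback loader from a training script (an editorial addition; the checkpoint path is hypothetical, and `network`, `ema`, and `optimizer` are assumed to be constructed elsewhere):

    import torch

    checkpoint = torch.load("experiments/cqt/checkpoint.pt", map_location="cpu")
    ok = load_state_dict(checkpoint, network=network, ema=ema, optimizer=optimizer, log=True)
    if not ok:
        print("all loading strategies failed; training from scratch")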
--------------------------------------------------------------------------------
/utils/utils_notebook.py:
--------------------------------------------------------------------------------
1 | import soundfile as sf
2 | import torch
3 | import numpy as np
4 | import plotly.express as px  # needed by plot_stft; missing in the original
5 | 
6 | def load_audio(name="test_dir/0.wav", Ls=65536):
7 |     x, fs = sf.read(name)
8 |     x = torch.Tensor(x)
9 |     #x = x[1*fs:2*fs]
10 |     x = x[0:Ls]  # keep only the first Ls samples (the original overwrote Ls with 65536, ignoring the argument)
11 |     return x, fs
12 | 
13 | def save_wav(x, fs=22050, filename="test.wav"):
14 |     x = x.numpy()
15 |     sf.write(filename, x, fs)
16 | 
17 | def plot_stft(x, fs=22050):  # fs was undefined in the original; it is now passed explicitly
18 |     NFFT = 1024
19 |     # Hann window
20 |     window = torch.hann_window(NFFT)
21 |     # apply the STFT to x
22 |     X = torch.stft(x, NFFT, hop_length=NFFT//2, win_length=NFFT, window=window, center=False, normalized=False, onesided=True, return_complex=True)
23 | 
24 |     freqs = np.fft.rfftfreq(NFFT, 1/fs)
25 | 
26 |     X_abs = X.abs()
27 |     # plot the magnitude of the STFT in dB using plotly express
28 |     fig = px.imshow(20*np.log10(X_abs.numpy()+1e-8), labels=dict(x="Time", y="Frequency", color="Magnitude"))
29 | 
30 |     fig.show()
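A short usage sketch of these notebook helpers (an editorial addition; the input path is hypothetical):

    x, fs = load_audio("test_dir/0.wav", Ls=65536)  # load and trim a test file
    plot_stft(x, fs=fs)                             # inspect the spectrogram
    save_wav(x, fs=fs, filename="test.wav")         # write the excerpt back to disk

--------------------------------------------------------------------------------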