├── Arguments.md ├── Dockerfile_torch1p11_Dockerfile_torch1p11_diffiner ├── LICENSE ├── README.md ├── figures └── overview_SE_refiner.png ├── guided_diffusion ├── __init__.py ├── diffiner_util.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── logger.py ├── losses.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── run_example.sh ├── scripts ├── informed_denoiser.py ├── run_refiner.py └── train.py ├── speech_dataloader ├── data.py └── utils.py └── start-container-torch1p11_diffiner.sh /Arguments.md: -------------------------------------------------------------------------------- 1 | # Parameters for running scripts 2 | 3 | ## Common Parameters 4 | The following parameters are used by both `scripts/run_refiner.py` and `scripts/train.py`. 5 | | Command line Argument | Description | Default | 6 | |----------------------------|---------------------------------------------------------------------------------|-----------------| 7 | | `--image-size ` | Size of the complex spectrogram, i.e., the number of frequency bins and time frames. | `256` | 8 | | `--num-channels ` | Base number of channels of the U-Net; must match the pre-trained diffusion model | `128` | 9 | | `--num-res-blocks ` | Number of residual blocks per resolution level of the U-Net; must match the pre-trained diffusion model | `3` | 10 | | `--diffusion-steps ` | Number of diffusion steps used when training the model.
| `4000` | 11 | | `--noise-scheduler ` | Type of noise schedule used in model training | `linear` | 12 | | `--use-fp16` | Run inference in fp16 instead of fp32 | `False` | 13 | 14 | ## Parameters for `run_refiner.py` (Diffiner/Diffiner+) 15 | 16 | | Command line Argument | Description | Default | 17 | |----------------------------|---------------------------------------------------------------------------------|-----------------| 18 | |`--simple-diffiner` | Switch the running mode from `Diffiner+` (default) to `Diffiner` | `False` | 19 | |`--root-noisy ` | Path to the directory containing the noisy (unprocessed) speech files | | 20 | |`--root-proc ` | Path to the directory containing the speech files pre-processed by the preceding SE method (i.e., the files to be refined) | | 21 | |`--max-dur ` | Expected maximum duration of the input speech in seconds. Longer inputs are truncated automatically | `10.0` | 22 | |`--model-path ` | Path to the pre-trained model used to run the refiner | | 23 | | `--etas list(str)` | List of all etas (Diffiner: $\eta_a$ and $\eta_b$; Diffiner+: $\eta_a$, $\eta_b$, $\eta_c$, and $\eta_d$). The defaults are the values used in the experiments of our paper | Diffiner: $\eta_a=0.9$, $\eta_b=0.9$; Diffiner+: $\eta_a=0.4$, $\eta_b=0.6$, $\eta_c=0.0$, $\eta_d=0.6$ | 24 | | `--clip-denoised` | When `--clip-denoised=True`, a clipping function is applied at each diffusion step. This technique is common in the image domain but should be set to `False` for complex spectrograms. | `False` | 25 | | `--timestep-respacing ` | Selects which of the training timesteps are executed during sampling. See `guided_diffusion/respace.py` for details. | `ddim200` | 26 | | `--batch-size ` | Number of wav files to refine simultaneously.
Choose this according to the available GPU memory | `8` | 27 | | `--no-gpu` | Run the refiner without a GPU | `False` | 28 | 29 | ## Parameters for `train.py` 30 | | Command line Argument | Description | Default | 31 | |----------------------------|---------------------------------------------------------------------------------|-----------------| 32 | |`--weight-decay `| Weight decay coefficient; penalizes large weights to reduce overfitting | `0.0` | 33 | |`--lr ` | Learning rate of the optimizer (Adam)| `1e-4`| 34 | |`--batch-size `|Number of samples processed in one iteration of training|`1`| 35 | |`--ema-rate `| Decay rate of the exponential moving average (EMA) of the model weights| `0.9999`| 36 | |`--log-interval `| Interval (in steps) at which training progress is logged | `10`| 37 | |`--save-interval `| Interval (in steps) at which model checkpoints are saved| `10000`| 38 | |`--resume-checkpoint `| Path to a previous checkpoint to resume training from, if any | | 39 | |`--seq-dur `| Duration (in seconds) of the input speech sequences used for training | `4.2`| -------------------------------------------------------------------------------- /Dockerfile_torch1p11_Dockerfile_torch1p11_diffiner: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel 2 | 3 | # Update GPG key of NVIDIA Docker Images 4 | # (See https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/ for more detail) 5 | RUN rm -f /etc/apt/sources.list.d/cuda.list \ 6 | && apt-get update && apt-get install -y --no-install-recommends \ 7 | wget \ 8 | && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb \ 9 | && dpkg -i cuda-keyring_1.0-1_all.deb \ 10 | && rm -f cuda-keyring_1.0-1_all.deb 11 | 12 | # Install basic utilities 13 | RUN apt-get clean && \ 14 | apt-get -y update && \ 15 | apt-get -y install nano wget curl git zip unzip \ 16 | && apt-get -y install ca-certificates sudo bzip2 libx11-6 emacs htop docker.io sox \ 17 | && apt-get clean \ 18 | && rm -rf /var/lib/apt/lists/* 19 | 20 | # Install python packages 21 | RUN pip install --upgrade pip 22 | 23 | RUN pip install matplotlib seaborn 24 | RUN pip install ipywidgets jupyter ipykernel 25 | 26 | # mpi4py is built against an MPI implementation, so install OpenMPI first 27 | RUN apt-get -y update 28 | RUN apt-get -y install libopenmpi-dev 29 | 30 | RUN pip install mpi4py 31 | 32 | RUN pip install 'blobfile>=1.0.5' tqdm 33 | 34 | # Install torch audio 35 | RUN pip3 install torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 36 | 37 | # Install soundfile 38 | RUN apt-get -y update 39 | RUN apt-get -y install libsndfile1 40 | RUN pip install soundfile librosa pypesq wavio 41 | 42 | # Set the filesystem encoding to UTF-8 43 | # so that the soundfile module can read non-ASCII filenames. 44 | ENV LC_CTYPE="C.UTF-8" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sony Group Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffiner: A Versatile Diffusion-based Generative Refiner for Speech Enhancement (INTERSPEECH 2023) 2 | 3 | ![cover-img](./figures/overview_SE_refiner.png) 4 | 5 | This repository includes the code implementation for the paper titled "Diffiner: A Versatile Diffusion-based Generative Refiner for Speech Enhancement" presented at INTERSPEECH 2023. 6 | 7 | ## Abstract 8 | Although deep neural network (DNN)-based speech enhancement (SE) methods outperform the previous non-DNN-based ones, they often degrade the perceptual quality of generated outputs. To tackle this problem, we introduce a DNN-based generative refiner, Diffiner, aiming to improve perceptual speech quality pre-processed by an SE method. We train a diffusion-based generative model by utilizing a dataset consisting of clean speech only. Then, our refiner effectively mixes clean parts newly generated via denoising diffusion restoration into the degraded and distorted parts caused by a preceding SE method, resulting in refined speech. Once our refiner is trained on a set of clean speech, it can be applied to various SE methods without additional training specialized for each SE module. Therefore, our refiner can be a versatile post-processing module w.r.t. SE methods and has high potential in terms of modularity.
9 | 10 | ## Paper 11 | 12 | [Paper on arXiv](https://arxiv.org/abs/2210.17287v2) 13 | 14 | [Paper on archive of Interspeech 2023](https://www.isca-speech.org/archive/interspeech_2023/sawata23_interspeech.html) 15 | 16 | [Creative AI: Demo and Papers, Sony Group Corporation](https://sony.github.io/creativeai/) 17 | 18 | ## Authors 19 | Ryosuke Sawata<sup>1</sup>, Naoki Murata<sup>2</sup>, Yuhta Takida<sup>2</sup>, Toshimitsu Uesaka<sup>2</sup>, Takashi Shibuya<sup>2</sup>, Shusuke Takahashi<sup>1</sup>, Yuki Mitsufuji<sup>1, 2</sup> 20 | 21 | <sup>1</sup> Sony Group Corporation, <sup>2</sup> Sony AI 22 | 23 | # Getting started 24 | ## Installation using Docker 25 | To set up the environment, we provide a Dockerfile. 26 | ### Building the Docker Image 27 | In the working directory (e.g., `diffiner/`), run the following command: 28 | ``` 29 | docker build -t diffiner -f Dockerfile_torch1p11_diffiner . 30 | ``` 31 | ### Starting a Docker Container 32 | Run the following command: 33 | ``` 34 | bash start-container-torch1p11_diffiner.sh 35 | ``` 36 | Please modify the `start-container-torch1p11_diffiner.sh` script to match your environment. 37 | 38 | # Training 39 | You can train a diffusion model on complex spectrograms using `scripts/train.py`: 40 | ``` 41 | python scripts/train.py --root dir_train_wav --image_size 256 --num_channels 128 --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule linear --lr 1e-4 --batch_size 2 42 | ``` 43 | The number of frequency bins and time frames in the complex spectrogram is determined by the `--image_size` parameter. 44 | 45 | ### Training data 46 | Place the `.wav` files required for training in the `dir_train_wav` directory (passed via `--root`). 47 | 48 | ### Parameters 49 | For detailed information regarding the parameters to specify during training, please refer to [Arguments.md](./Arguments.md). 50 | 51 | ## Pre-trained model 52 | We offer a pre-trained model that can be downloaded from [here](https://zenodo.org/record/7988790).
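For reference, the `--diffusion_steps 4000 --noise_schedule linear` options in the training command select a linear noise schedule. The sketch below reproduces, as a standalone function, the linear schedule of openai/guided-diffusion (`get_named_beta_schedule("linear", ...)`). It assumes that this repository's `guided_diffusion/gaussian_diffusion.py` keeps that definition unchanged, and the helper name `linear_beta_schedule` is hypothetical. It shows how the endpoints defined for 1000 steps are rescaled for 4000 steps, and how little signal remains at the final step.

```python
import numpy as np


def linear_beta_schedule(num_diffusion_timesteps: int) -> np.ndarray:
    """Hypothetical standalone copy of guided-diffusion's "linear" beta schedule.

    The endpoints (1e-4, 0.02) are defined for 1000 steps; scaling them by
    1000 / T keeps the sum of the betas the same for any number of steps T.
    """
    scale = 1000 / num_diffusion_timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)


# Settings from the training command: --diffusion_steps 4000 --noise_schedule linear
betas = linear_beta_schedule(4000)
# alpha_bar_t = prod_{s<=t} (1 - beta_s): squared scale of the clean signal left at step t
alphas_cumprod = np.cumprod(1.0 - betas)

print(f"beta range: {betas[0]:.2e} .. {betas[-1]:.2e}")
print(f"alpha_bar at the last step: {alphas_cumprod[-1]:.2e}")
```

With 4000 steps each beta is a quarter of its 1000-step counterpart, so every individual step adds less noise while the last step is still almost pure noise.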
53 | 54 | # Run Diffiner/Diffiner+ (Inference) 55 | After training the model and running the preceding SE method, you can run Diffiner/Diffiner+ as follows: 56 | ``` 57 | python scripts/run_refiner.py --root-noisy [dir to noisy] --root-proc [dir to pre-processed by preceding SE] --model-path [path to the trained *.pt] 58 | ``` 59 | By default, the script runs `Diffiner+`; add the option `--simple-diffiner` to switch to `Diffiner`. The refined results are stored in `[dir to pre-processed by preceding SE]/[diffiner or diffiner+]_etaA=***_etaB=***_etaC=***_etaD=***`, which is created automatically when the script runs. Each `***` in the directory name is replaced with the value actually used in that run. All necessary parameters, e.g., $\eta_a$ and $\eta_b$, default to the values used in our experiments. See [Arguments.md](./Arguments.md) for details of all parameters for running Diffiner/Diffiner+. 60 | 61 | # Examples of refinement on VoiceBank+Demand 62 | You can easily run the pre-trained Diffiner+ on the VoiceBank+Demand corpus pre-processed by Deep Complex U-net (DCUnet): 63 | ``` 64 | ./run_example.sh 65 | ``` 66 | The script first downloads the VoiceBank+Demand corpus, its outputs pre-processed by DCUnet, and our pre-trained Diffiner. By default, all audio files are stored in `./data`, and the pre-trained Diffiner weights are stored in your current directory. Note that both DCUnet and Diffiner were trained on the VoiceBank+Demand corpus only, which is the setting adopted in our paper's experiments. Next, the script runs the pre-trained Diffiner+ on the VoiceBank+Demand corpus pre-processed by DCUnet. Finally, the refined examples are available in `./data/proc_dcunet/[results of Diffiner+]`.
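To illustrate the `--timestep-respacing` option (default `ddim200`, see Arguments.md), the sketch below re-implements only the `"ddimN"` branch of `space_timesteps` as found in openai/guided-diffusion's `respace.py`. It assumes that this repository's copy behaves the same way, and `ddim_timesteps` is a hypothetical helper name. The logic searches for the smallest integer stride that yields exactly N evenly spaced timesteps out of the training steps, so a model trained with 4000 steps is sampled with 200.

```python
def ddim_timesteps(num_timesteps: int, spec: str) -> list:
    """Hypothetical re-implementation of the "ddimN" case of space_timesteps.

    Returns the sorted training timesteps that are kept for sampling.
    """
    if not spec.startswith("ddim"):
        raise ValueError("this sketch only handles 'ddimN' specs")
    desired_count = int(spec[len("ddim"):])
    # Try strides 1, 2, 3, ... until range(0, T, stride) has exactly N entries.
    for stride in range(1, num_timesteps):
        steps = range(0, num_timesteps, stride)
        if len(steps) == desired_count:
            return list(steps)
    raise ValueError(
        f"cannot create exactly {desired_count} steps with an integer stride"
    )


# --diffusion_steps 4000 (training) with the default --timestep-respacing ddim200
kept = ddim_timesteps(4000, "ddim200")
print(len(kept), kept[:3], kept[-1])
```

A count that no integer stride can produce exactly (for example `ddim300` with 4000 training steps) raises an error instead of silently using a different number of sampling steps.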
67 | 68 | ## Citation 69 | If you find this work useful, please consider citing the paper: 70 | ``` 71 | @inproceedings{sawata23_interspeech, 72 | author={Ryosuke Sawata and Naoki Murata and Yuhta Takida and Toshimitsu Uesaka and Takashi Shibuya and Shusuke Takahashi and Yuki Mitsufuji}, 73 | title={{Diffiner: A Versatile Diffusion-based Generative Refiner for Speech Enhancement}}, 74 | year=2023, 75 | booktitle={Proc. INTERSPEECH 2023}, 76 | pages={3824--3828}, 77 | doi={10.21437/Interspeech.2023-1547} 78 | } 79 | ``` 80 | 81 | This repository is based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion). 82 | -------------------------------------------------------------------------------- /figures/overview_SE_refiner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/diffiner/b8ee5962fe17be0b70b923b0236e51db5311b29a/figures/overview_SE_refiner.png -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /guided_diffusion/diffiner_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 
14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | ) 64 | res.update(diffusion_defaults()) 65 | return res 66 | 67 | 68 | def classifier_and_diffusion_defaults(): 69 | res = classifier_defaults() 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | 74 | def create_model_and_diffusion( 75 | image_size, 76 | class_cond, 77 | learn_sigma, 78 | num_channels, 79 | num_res_blocks, 80 | channel_mult, 81 | num_heads, 82 | num_head_channels, 83 | num_heads_upsample, 84 | attention_resolutions, 85 | dropout, 86 | diffusion_steps, 87 | noise_schedule, 88 | timestep_respacing, 89 | use_kl, 90 | predict_xstart, 91 | rescale_timesteps, 92 | rescale_learned_sigmas, 93 | use_checkpoint, 94 | use_scale_shift_norm, 95 | resblock_updown, 96 | use_fp16, 97 | use_new_attention_order, 98 | image_channels=3, 99 | complex_conv=False, 100 | normalize=False, 101 | direct_out=False, 
102 | ): 103 | model = create_model( 104 | image_size, 105 | num_channels, 106 | num_res_blocks, 107 | channel_mult=channel_mult, 108 | learn_sigma=learn_sigma, 109 | class_cond=class_cond, 110 | use_checkpoint=use_checkpoint, 111 | attention_resolutions=attention_resolutions, 112 | num_heads=num_heads, 113 | num_head_channels=num_head_channels, 114 | num_heads_upsample=num_heads_upsample, 115 | use_scale_shift_norm=use_scale_shift_norm, 116 | dropout=dropout, 117 | resblock_updown=resblock_updown, 118 | use_fp16=use_fp16, 119 | use_new_attention_order=use_new_attention_order, 120 | image_channels=image_channels, 121 | complex_conv=complex_conv, 122 | normalize=normalize, 123 | direct_out=direct_out, 124 | ) 125 | diffusion = create_gaussian_diffusion( 126 | steps=diffusion_steps, 127 | learn_sigma=learn_sigma, 128 | noise_schedule=noise_schedule, 129 | use_kl=use_kl, 130 | predict_xstart=predict_xstart, 131 | rescale_timesteps=rescale_timesteps, 132 | rescale_learned_sigmas=rescale_learned_sigmas, 133 | timestep_respacing=timestep_respacing, 134 | ) 135 | return model, diffusion 136 | 137 | 138 | def create_model( 139 | image_size, 140 | num_channels, 141 | num_res_blocks, 142 | channel_mult="", 143 | learn_sigma=False, 144 | class_cond=False, 145 | use_checkpoint=False, 146 | attention_resolutions="16", 147 | num_heads=1, 148 | num_head_channels=-1, 149 | num_heads_upsample=-1, 150 | use_scale_shift_norm=False, 151 | dropout=0, 152 | resblock_updown=False, 153 | use_fp16=False, 154 | use_new_attention_order=False, 155 | image_channels=3, 156 | complex_conv=False, 157 | normalize=False, 158 | direct_out=False, 159 | ): 160 | 161 | if channel_mult == "": 162 | if image_size == 512: 163 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 164 | elif image_size == 256: 165 | channel_mult = (1, 1, 2, 2, 4, 4) 166 | elif image_size == 128: 167 | channel_mult = (1, 1, 2, 3, 4) 168 | elif image_size == 64: 169 | channel_mult = (1, 2, 3, 4) 170 | else: 171 | raise 
ValueError(f"unsupported image size: {image_size}") 172 | else: 173 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 174 | 175 | attention_ds = [] 176 | for res in attention_resolutions.split(","): 177 | attention_ds.append(image_size // int(res)) 178 | 179 | return UNetModel( 180 | image_size=image_size, 181 | in_channels=image_channels, 182 | model_channels=num_channels, 183 | out_channels=(image_channels if not learn_sigma else 2 * image_channels), 184 | num_res_blocks=num_res_blocks, 185 | attention_resolutions=tuple(attention_ds), 186 | dropout=dropout, 187 | channel_mult=channel_mult, 188 | num_classes=(NUM_CLASSES if class_cond else None), 189 | use_checkpoint=use_checkpoint, 190 | use_fp16=use_fp16, 191 | num_heads=num_heads, 192 | num_head_channels=num_head_channels, 193 | num_heads_upsample=num_heads_upsample, 194 | use_scale_shift_norm=use_scale_shift_norm, 195 | resblock_updown=resblock_updown, 196 | use_new_attention_order=use_new_attention_order, 197 | complex_conv=complex_conv, 198 | ) 199 | 200 | 201 | def create_classifier_and_diffusion( 202 | image_size, 203 | classifier_use_fp16, 204 | classifier_width, 205 | classifier_depth, 206 | classifier_attention_resolutions, 207 | classifier_use_scale_shift_norm, 208 | classifier_resblock_updown, 209 | classifier_pool, 210 | learn_sigma, 211 | diffusion_steps, 212 | noise_schedule, 213 | timestep_respacing, 214 | use_kl, 215 | predict_xstart, 216 | rescale_timesteps, 217 | rescale_learned_sigmas, 218 | ): 219 | classifier = create_classifier( 220 | image_size, 221 | classifier_use_fp16, 222 | classifier_width, 223 | classifier_depth, 224 | classifier_attention_resolutions, 225 | classifier_use_scale_shift_norm, 226 | classifier_resblock_updown, 227 | classifier_pool, 228 | ) 229 | diffusion = create_gaussian_diffusion( 230 | steps=diffusion_steps, 231 | learn_sigma=learn_sigma, 232 | noise_schedule=noise_schedule, 233 | use_kl=use_kl, 234 | predict_xstart=predict_xstart, 235 |
rescale_timesteps=rescale_timesteps, 236 | rescale_learned_sigmas=rescale_learned_sigmas, 237 | timestep_respacing=timestep_respacing, 238 | ) 239 | return classifier, diffusion 240 | 241 | 242 | def create_classifier( 243 | image_size, 244 | classifier_use_fp16, 245 | classifier_width, 246 | classifier_depth, 247 | classifier_attention_resolutions, 248 | classifier_use_scale_shift_norm, 249 | classifier_resblock_updown, 250 | classifier_pool, 251 | ): 252 | if image_size == 512: 253 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 254 | elif image_size == 256: 255 | channel_mult = (1, 1, 2, 2, 4, 4) 256 | elif image_size == 128: 257 | channel_mult = (1, 1, 2, 3, 4) 258 | elif image_size == 64: 259 | channel_mult = (1, 2, 3, 4) 260 | else: 261 | raise ValueError(f"unsupported image size: {image_size}") 262 | 263 | attention_ds = [] 264 | for res in classifier_attention_resolutions.split(","): 265 | attention_ds.append(image_size // int(res)) 266 | 267 | return EncoderUNetModel( 268 | image_size=image_size, 269 | in_channels=3, 270 | model_channels=classifier_width, 271 | out_channels=1000, 272 | num_res_blocks=classifier_depth, 273 | attention_resolutions=tuple(attention_ds), 274 | channel_mult=channel_mult, 275 | use_fp16=classifier_use_fp16, 276 | num_head_channels=64, 277 | use_scale_shift_norm=classifier_use_scale_shift_norm, 278 | resblock_updown=classifier_resblock_updown, 279 | pool=classifier_pool, 280 | ) 281 | 282 | 283 | def sr_model_and_diffusion_defaults(): 284 | res = model_and_diffusion_defaults() 285 | res["large_size"] = 256 286 | res["small_size"] = 64 287 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 288 | for k in res.copy().keys(): 289 | if k not in arg_names: 290 | del res[k] 291 | return res 292 | 293 | 294 | def sr_create_model_and_diffusion( 295 | large_size, 296 | small_size, 297 | class_cond, 298 | learn_sigma, 299 | num_channels, 300 | num_res_blocks, 301 | num_heads, 302 | num_head_channels, 303 | num_heads_upsample, 
304 | attention_resolutions, 305 | dropout, 306 | diffusion_steps, 307 | noise_schedule, 308 | timestep_respacing, 309 | use_kl, 310 | predict_xstart, 311 | rescale_timesteps, 312 | rescale_learned_sigmas, 313 | use_checkpoint, 314 | use_scale_shift_norm, 315 | resblock_updown, 316 | use_fp16, 317 | ): 318 | model = sr_create_model( 319 | large_size, 320 | small_size, 321 | num_channels, 322 | num_res_blocks, 323 | learn_sigma=learn_sigma, 324 | class_cond=class_cond, 325 | use_checkpoint=use_checkpoint, 326 | attention_resolutions=attention_resolutions, 327 | num_heads=num_heads, 328 | num_head_channels=num_head_channels, 329 | num_heads_upsample=num_heads_upsample, 330 | use_scale_shift_norm=use_scale_shift_norm, 331 | dropout=dropout, 332 | resblock_updown=resblock_updown, 333 | use_fp16=use_fp16, 334 | ) 335 | diffusion = create_gaussian_diffusion( 336 | steps=diffusion_steps, 337 | learn_sigma=learn_sigma, 338 | noise_schedule=noise_schedule, 339 | use_kl=use_kl, 340 | predict_xstart=predict_xstart, 341 | rescale_timesteps=rescale_timesteps, 342 | rescale_learned_sigmas=rescale_learned_sigmas, 343 | timestep_respacing=timestep_respacing, 344 | ) 345 | return model, diffusion 346 | 347 | 348 | def sr_create_model( 349 | large_size, 350 | small_size, 351 | num_channels, 352 | num_res_blocks, 353 | learn_sigma, 354 | class_cond, 355 | use_checkpoint, 356 | attention_resolutions, 357 | num_heads, 358 | num_head_channels, 359 | num_heads_upsample, 360 | use_scale_shift_norm, 361 | dropout, 362 | resblock_updown, 363 | use_fp16, 364 | ): 365 | _ = small_size # hack to prevent unused variable 366 | 367 | if large_size == 512: 368 | channel_mult = (1, 1, 2, 2, 4, 4) 369 | elif large_size == 256: 370 | channel_mult = (1, 1, 2, 2, 4, 4) 371 | elif large_size == 64: 372 | channel_mult = (1, 2, 3, 4) 373 | else: 374 | raise ValueError(f"unsupported large size: {large_size}") 375 | 376 | attention_ds = [] 377 | for res in attention_resolutions.split(","): 378 | 
attention_ds.append(large_size // int(res)) 379 | 380 | return SuperResModel( 381 | image_size=large_size, 382 | in_channels=3, 383 | model_channels=num_channels, 384 | out_channels=(3 if not learn_sigma else 6), 385 | num_res_blocks=num_res_blocks, 386 | attention_resolutions=tuple(attention_ds), 387 | dropout=dropout, 388 | channel_mult=channel_mult, 389 | num_classes=(NUM_CLASSES if class_cond else None), 390 | use_checkpoint=use_checkpoint, 391 | num_heads=num_heads, 392 | num_head_channels=num_head_channels, 393 | num_heads_upsample=num_heads_upsample, 394 | use_scale_shift_norm=use_scale_shift_norm, 395 | resblock_updown=resblock_updown, 396 | use_fp16=use_fp16, 397 | ) 398 | 399 | 400 | def create_gaussian_diffusion( 401 | *, 402 | steps=1000, 403 | learn_sigma=False, 404 | sigma_small=False, 405 | noise_schedule="linear", 406 | use_kl=False, 407 | predict_xstart=False, 408 | rescale_timesteps=False, 409 | rescale_learned_sigmas=False, 410 | timestep_respacing="", 411 | ): 412 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 413 | if use_kl: 414 | loss_type = gd.LossType.RESCALED_KL 415 | elif rescale_learned_sigmas: 416 | loss_type = gd.LossType.RESCALED_MSE 417 | else: 418 | loss_type = gd.LossType.MSE 419 | if not timestep_respacing: 420 | timestep_respacing = [steps] 421 | return SpacedDiffusion( 422 | use_timesteps=space_timesteps(steps, timestep_respacing), 423 | betas=betas, 424 | model_mean_type=( 425 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 426 | ), 427 | model_var_type=( 428 | ( 429 | gd.ModelVarType.FIXED_LARGE 430 | if not sigma_small 431 | else gd.ModelVarType.FIXED_SMALL 432 | ) 433 | if not learn_sigma 434 | else gd.ModelVarType.LEARNED_RANGE 435 | ), 436 | loss_type=loss_type, 437 | rescale_timesteps=rescale_timesteps, 438 | ) 439 | 440 | 441 | def add_dict_to_argparser(parser, default_dict): 442 | for k, v in default_dict.items(): 443 | v_type = type(v) 444 | if v is None: 445 | v_type = str 
446 | elif isinstance(v, bool): 447 | v_type = str2bool 448 | parser.add_argument(f"--{k}", default=v, type=v_type) 449 | 450 | 451 | def args_to_dict(args, keys): 452 | return {k: getattr(args, k) for k in keys} 453 | 454 | 455 | def str2bool(v): 456 | """ 457 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 458 | """ 459 | if isinstance(v, bool): 460 | return v 461 | if v.lower() in ("yes", "true", "t", "y", "1"): 462 | return True 463 | elif v.lower() in ("no", "false", "f", "n", "0"): 464 | return False 465 | else: 466 | raise argparse.ArgumentTypeError("boolean value expected") 467 | -------------------------------------------------------------------------------- /guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 
24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 
80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 
39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | 
return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, 
param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /guided_diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 |
| This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | 14 | from .nn import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | 17 | 18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 19 | """ 20 | Get a pre-defined beta schedule for the given name. 21 | 22 | The beta schedule library consists of beta schedules which remain similar 23 | in the limit of num_diffusion_timesteps. 24 | Beta schedules may be added, but should not be removed or changed once 25 | they are committed to maintain backwards compatibility. 26 | """ 27 | if schedule_name == "linear": 28 | # Linear schedule from Ho et al, extended to work for any number of 29 | # diffusion steps. 30 | scale = 1000 / num_diffusion_timesteps 31 | beta_start = scale * 0.0001 32 | beta_end = scale * 0.02 33 | return np.linspace( 34 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 35 | ) 36 | elif schedule_name == "cosine": 37 | return betas_for_alpha_bar( 38 | num_diffusion_timesteps, 39 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 40 | ) 41 | else: 42 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 43 | 44 | 45 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 46 | """ 47 | Create a beta schedule that discretizes the given alpha_t_bar function, 48 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 49 | 50 | :param num_diffusion_timesteps: the number of betas to produce. 
51 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 52 | produces the cumulative product of (1-beta) up to that 53 | part of the diffusion process. 54 | :param max_beta: the maximum beta to use; use values lower than 1 to 55 | prevent singularities. 56 | """ 57 | betas = [] 58 | for i in range(num_diffusion_timesteps): 59 | t1 = i / num_diffusion_timesteps 60 | t2 = (i + 1) / num_diffusion_timesteps 61 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 62 | return np.array(betas) 63 | 64 | 65 | class ModelMeanType(enum.Enum): 66 | """ 67 | Which type of output the model predicts. 68 | """ 69 | 70 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 71 | START_X = enum.auto() # the model predicts x_0 72 | EPSILON = enum.auto() # the model predicts epsilon 73 | 74 | 75 | class ModelVarType(enum.Enum): 76 | """ 77 | What is used as the model's output variance. 78 | 79 | The LEARNED_RANGE option has been added to allow the model to predict 80 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 81 | """ 82 | 83 | LEARNED = enum.auto() 84 | FIXED_SMALL = enum.auto() 85 | FIXED_LARGE = enum.auto() 86 | LEARNED_RANGE = enum.auto() 87 | 88 | 89 | class LossType(enum.Enum): 90 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 91 | RESCALED_MSE = ( 92 | enum.auto() 93 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 94 | KL = enum.auto() # use the variational lower-bound 95 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 96 | 97 | def is_vb(self): 98 | return self == LossType.KL or self == LossType.RESCALED_KL 99 | 100 | 101 | class GaussianDiffusion: 102 | """ 103 | Utilities for training and sampling diffusion models. 104 | 105 | Ported directly from here, and then adapted over time to further experimentation. 
106 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 107 | 108 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 109 | starting at T and going to 1. 110 | :param model_mean_type: a ModelMeanType determining what the model outputs. 111 | :param model_var_type: a ModelVarType determining how variance is output. 112 | :param loss_type: a LossType determining the loss function to use. 113 | :param rescale_timesteps: if True, pass floating point timesteps into the 114 | model so that they are always scaled like in the 115 | original paper (0 to 1000). 116 | """ 117 | 118 | def __init__( 119 | self, 120 | *, 121 | betas, 122 | model_mean_type, 123 | model_var_type, 124 | loss_type, 125 | rescale_timesteps=False, 126 | ): 127 | self.model_mean_type = model_mean_type 128 | self.model_var_type = model_var_type 129 | self.loss_type = loss_type 130 | self.rescale_timesteps = rescale_timesteps 131 | 132 | # Use float64 for accuracy. 
133 | betas = np.array(betas, dtype=np.float64) 134 | self.betas = betas 135 | assert len(betas.shape) == 1, "betas must be 1-D" 136 | assert (betas > 0).all() and (betas <= 1).all() 137 | 138 | self.num_timesteps = int(betas.shape[0]) 139 | 140 | alphas = 1.0 - betas 141 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 142 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 143 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 144 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 145 | 146 | # calculations for diffusion q(x_t | x_{t-1}) and others 147 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 148 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 149 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 150 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 151 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 152 | 153 | # calculations for posterior q(x_{t-1} | x_t, x_0) 154 | self.posterior_variance = ( 155 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 156 | ) 157 | # log calculation clipped because the posterior variance is 0 at the 158 | # beginning of the diffusion chain. 159 | self.posterior_log_variance_clipped = np.log( 160 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 161 | ) 162 | self.posterior_mean_coef1 = ( 163 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 164 | ) 165 | self.posterior_mean_coef2 = ( 166 | (1.0 - self.alphas_cumprod_prev) 167 | * np.sqrt(alphas) 168 | / (1.0 - self.alphas_cumprod) 169 | ) 170 | 171 | def q_mean_variance(self, x_start, t): 172 | """ 173 | Get the distribution q(x_t | x_0). 174 | 175 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 176 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 
177 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 178 | """ 179 | mean = ( 180 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 181 | ) 182 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 183 | log_variance = _extract_into_tensor( 184 | self.log_one_minus_alphas_cumprod, t, x_start.shape 185 | ) 186 | return mean, variance, log_variance 187 | 188 | def q_sample(self, x_start, t, noise=None): 189 | """ 190 | Diffuse the data for a given number of diffusion steps. 191 | 192 | In other words, sample from q(x_t | x_0). 193 | 194 | :param x_start: the initial data batch. 195 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 196 | :param noise: if specified, the split-out normal noise. 197 | :return: A noisy version of x_start. 198 | """ 199 | if noise is None: 200 | noise = th.randn_like(x_start) 201 | assert noise.shape == x_start.shape 202 | return ( 203 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 204 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 205 | * noise 206 | ) 207 | 208 | def q_posterior_mean_variance(self, x_start, x_t, t): 209 | """ 210 | Compute the mean and variance of the diffusion posterior: 211 | 212 | q(x_{t-1} | x_t, x_0) 213 | 214 | """ 215 | assert x_start.shape == x_t.shape 216 | posterior_mean = ( 217 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 218 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 219 | ) 220 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 221 | posterior_log_variance_clipped = _extract_into_tensor( 222 | self.posterior_log_variance_clipped, t, x_t.shape 223 | ) 224 | assert ( 225 | posterior_mean.shape[0] 226 | == posterior_variance.shape[0] 227 | == posterior_log_variance_clipped.shape[0] 228 | == x_start.shape[0] 229 | ) 230 | return posterior_mean, 
posterior_variance, posterior_log_variance_clipped 231 | 232 | def p_mean_variance( 233 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 234 | ): 235 | """ 236 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 237 | the initial x, x_0. 238 | 239 | :param model: the model, which takes a signal and a batch of timesteps 240 | as input. 241 | :param x: the [N x C x ...] tensor at time t. 242 | :param t: a 1-D Tensor of timesteps. 243 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 244 | :param denoised_fn: if not None, a function which applies to the 245 | x_start prediction before it is used to sample. Applies before 246 | clip_denoised. 247 | :param model_kwargs: if not None, a dict of extra keyword arguments to 248 | pass to the model. This can be used for conditioning. 249 | :return: a dict with the following keys: 250 | - 'mean': the model mean output. 251 | - 'variance': the model variance output. 252 | - 'log_variance': the log of 'variance'. 253 | - 'pred_xstart': the prediction for x_0. 254 | """ 255 | if model_kwargs is None: 256 | model_kwargs = {} 257 | 258 | B, C = x.shape[:2] 259 | assert t.shape == (B,) 260 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 261 | 262 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 263 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 264 | model_output, model_var_values = th.split(model_output, C, dim=1) 265 | if self.model_var_type == ModelVarType.LEARNED: 266 | model_log_variance = model_var_values 267 | model_variance = th.exp(model_log_variance) 268 | else: 269 | min_log = _extract_into_tensor( 270 | self.posterior_log_variance_clipped, t, x.shape 271 | ) 272 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 273 | # The model_var_values is [-1, 1] for [min_var, max_var]. 
274 | frac = (model_var_values + 1) / 2 275 | model_log_variance = frac * max_log + (1 - frac) * min_log 276 | model_variance = th.exp(model_log_variance) 277 | else: 278 | model_variance, model_log_variance = { 279 | # for fixedlarge, we set the initial (log-)variance like so 280 | # to get a better decoder log likelihood. 281 | ModelVarType.FIXED_LARGE: ( 282 | np.append(self.posterior_variance[1], self.betas[1:]), 283 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 284 | ), 285 | ModelVarType.FIXED_SMALL: ( 286 | self.posterior_variance, 287 | self.posterior_log_variance_clipped, 288 | ), 289 | }[self.model_var_type] 290 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 291 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 292 | 293 | def process_xstart(x): 294 | if denoised_fn is not None: 295 | x = denoised_fn(x) 296 | if clip_denoised: 297 | return x.clamp(-1, 1) 298 | return x 299 | 300 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 301 | pred_xstart = process_xstart( 302 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 303 | ) 304 | model_mean = model_output 305 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 306 | if self.model_mean_type == ModelMeanType.START_X: 307 | pred_xstart = process_xstart(model_output) 308 | else: 309 | pred_xstart = process_xstart( 310 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 311 | ) 312 | model_mean, _, _ = self.q_posterior_mean_variance( 313 | x_start=pred_xstart, x_t=x, t=t 314 | ) 315 | else: 316 | raise NotImplementedError(self.model_mean_type) 317 | 318 | assert ( 319 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 320 | ) 321 | return { 322 | "mean": model_mean, 323 | "variance": model_variance, 324 | "log_variance": model_log_variance, 325 | "pred_xstart": pred_xstart, 326 | } 327 | 328 | def _predict_xstart_from_eps(self, x_t, t, eps): 329 | assert 
x_t.shape == eps.shape 330 | return ( 331 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 332 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 333 | ) 334 | 335 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 336 | assert x_t.shape == xprev.shape 337 | return ( # (xprev - coef2*x_t) / coef1 338 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 339 | - _extract_into_tensor( 340 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 341 | ) 342 | * x_t 343 | ) 344 | 345 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 346 | return ( 347 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 348 | - pred_xstart 349 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 350 | 351 | def _scale_timesteps(self, t): 352 | if self.rescale_timesteps: 353 | return t.float() * (1000.0 / self.num_timesteps) 354 | return t 355 | 356 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 357 | """ 358 | Compute the mean for the previous step, given a function cond_fn that 359 | computes the gradient of a conditional log probability with respect to 360 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 361 | condition on y. 362 | 363 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 364 | """ 365 | gradient = cond_fn(x, self._scale_timesteps(t), **(model_kwargs or {})) 366 | new_mean = ( 367 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 368 | ) 369 | return new_mean 370 | 371 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 372 | """ 373 | Compute what the p_mean_variance output would have been, should the 374 | model's score function be conditioned by cond_fn. 375 | 376 | See condition_mean() for details on cond_fn.
377 | 378 | Unlike condition_mean(), this instead uses the conditioning strategy 379 | from Song et al. (2020). 380 | """ 381 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 382 | 383 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 384 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn( 385 | x, self._scale_timesteps(t), **(model_kwargs or {}) 386 | ) 387 | 388 | out = p_mean_var.copy() 389 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 390 | out["mean"], _, _ = self.q_posterior_mean_variance( 391 | x_start=out["pred_xstart"], x_t=x, t=t 392 | ) 393 | return out 394 | 395 | def p_sample( 396 | self, 397 | model, 398 | x, 399 | t, 400 | clip_denoised=True, 401 | denoised_fn=None, 402 | cond_fn=None, 403 | model_kwargs=None, 404 | ): 405 | """ 406 | Sample x_{t-1} from the model at the given timestep. 407 | 408 | :param model: the model to sample from. 409 | :param x: the current tensor at x_t. 410 | :param t: the value of t, starting at 0 for the first diffusion step. 411 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 412 | :param denoised_fn: if not None, a function which applies to the 413 | x_start prediction before it is used to sample. 414 | :param cond_fn: if not None, this is a gradient function that acts 415 | similarly to the model. 416 | :param model_kwargs: if not None, a dict of extra keyword arguments to 417 | pass to the model. This can be used for conditioning. 418 | :return: a dict containing the following keys: 419 | - 'sample': a random sample from the model. 420 | - 'pred_xstart': a prediction of x_0.
421 | """ 422 | out = self.p_mean_variance( 423 | model, 424 | x, 425 | t, 426 | clip_denoised=clip_denoised, 427 | denoised_fn=denoised_fn, 428 | model_kwargs=model_kwargs, 429 | ) 430 | noise = th.randn_like(x) 431 | nonzero_mask = ( 432 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 433 | ) # no noise when t == 0 434 | if cond_fn is not None: 435 | out["mean"] = self.condition_mean( 436 | cond_fn, out, x, t, model_kwargs=model_kwargs 437 | ) 438 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 439 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 440 | 441 | def p_sample_loop( 442 | self, 443 | model, 444 | shape, 445 | noise=None, 446 | clip_denoised=True, 447 | denoised_fn=None, 448 | cond_fn=None, 449 | model_kwargs=None, 450 | device=None, 451 | progress=False, 452 | ): 453 | """ 454 | Generate samples from the model. 455 | 456 | :param model: the model module. 457 | :param shape: the shape of the samples, (N, C, H, W). 458 | :param noise: if specified, the noise from the encoder to sample. 459 | Should be of the same shape as `shape`. 460 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 461 | :param denoised_fn: if not None, a function which applies to the 462 | x_start prediction before it is used to sample. 463 | :param cond_fn: if not None, this is a gradient function that acts 464 | similarly to the model. 465 | :param model_kwargs: if not None, a dict of extra keyword arguments to 466 | pass to the model. This can be used for conditioning. 467 | :param device: if specified, the device to create the samples on. 468 | If not specified, use a model parameter's device. 469 | :param progress: if True, show a tqdm progress bar. 470 | :return: a non-differentiable batch of samples. 
471 | """ 472 | final = None 473 | for sample in self.p_sample_loop_progressive( 474 | model, 475 | shape, 476 | noise=noise, 477 | clip_denoised=clip_denoised, 478 | denoised_fn=denoised_fn, 479 | cond_fn=cond_fn, 480 | model_kwargs=model_kwargs, 481 | device=device, 482 | progress=progress, 483 | ): 484 | final = sample 485 | return final["sample"] 486 | 487 | def p_sample_loop_progressive( 488 | self, 489 | model, 490 | shape, 491 | noise=None, 492 | clip_denoised=True, 493 | denoised_fn=None, 494 | cond_fn=None, 495 | model_kwargs=None, 496 | device=None, 497 | progress=False, 498 | ): 499 | """ 500 | Generate samples from the model and yield intermediate samples from 501 | each timestep of diffusion. 502 | 503 | Arguments are the same as p_sample_loop(). 504 | Returns a generator over dicts, where each dict is the return value of 505 | p_sample(). 506 | """ 507 | if device is None: 508 | device = next(model.parameters()).device 509 | assert isinstance(shape, (tuple, list)) 510 | if noise is not None: 511 | img = noise 512 | else: 513 | img = th.randn(*shape, device=device) 514 | indices = list(range(self.num_timesteps))[::-1] 515 | 516 | if progress: 517 | # Lazy import so that we don't depend on tqdm. 518 | from tqdm.auto import tqdm 519 | 520 | indices = tqdm(indices) 521 | 522 | for i in indices: 523 | t = th.tensor([i] * shape[0], device=device) 524 | with th.no_grad(): 525 | out = self.p_sample( 526 | model, 527 | img, 528 | t, 529 | clip_denoised=clip_denoised, 530 | denoised_fn=denoised_fn, 531 | cond_fn=cond_fn, 532 | model_kwargs=model_kwargs, 533 | ) 534 | yield out 535 | img = out["sample"] 536 | 537 | def ddim_sample( 538 | self, 539 | model, 540 | x, 541 | t, 542 | clip_denoised=True, 543 | denoised_fn=None, 544 | cond_fn=None, 545 | model_kwargs=None, 546 | eta=0.0, 547 | ): 548 | """ 549 | Sample x_{t-1} from the model using DDIM. 550 | 551 | Same usage as p_sample(). 
552 | """ 553 | out = self.p_mean_variance( 554 | model, 555 | x, 556 | t, 557 | clip_denoised=clip_denoised, 558 | denoised_fn=denoised_fn, 559 | model_kwargs=model_kwargs, 560 | ) 561 | if cond_fn is not None: 562 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 563 | 564 | # Usually our model outputs epsilon, but we re-derive it 565 | # in case we used x_start or x_prev prediction. 566 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 567 | 568 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 569 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 570 | sigma = ( 571 | eta 572 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 573 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 574 | ) 575 | # Equation 12. 576 | noise = th.randn_like(x) 577 | mean_pred = ( 578 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 579 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 580 | ) 581 | nonzero_mask = ( 582 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 583 | ) # no noise when t == 0 584 | sample = mean_pred + nonzero_mask * sigma * noise 585 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 586 | 587 | def ddim_reverse_sample( 588 | self, 589 | model, 590 | x, 591 | t, 592 | clip_denoised=True, 593 | denoised_fn=None, 594 | model_kwargs=None, 595 | eta=0.0, 596 | ): 597 | """ 598 | Sample x_{t+1} from the model using DDIM reverse ODE. 599 | """ 600 | assert eta == 0.0, "Reverse ODE only for deterministic path" 601 | out = self.p_mean_variance( 602 | model, 603 | x, 604 | t, 605 | clip_denoised=clip_denoised, 606 | denoised_fn=denoised_fn, 607 | model_kwargs=model_kwargs, 608 | ) 609 | # Usually our model outputs epsilon, but we re-derive it 610 | # in case we used x_start or x_prev prediction. 
611 | eps = ( 612 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 613 | - out["pred_xstart"] 614 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 615 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 616 | 617 | # Equation 12. reversed 618 | mean_pred = ( 619 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 620 | + th.sqrt(1 - alpha_bar_next) * eps 621 | ) 622 | 623 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 624 | 625 | def ddim_sample_loop( 626 | self, 627 | model, 628 | shape, 629 | noise=None, 630 | clip_denoised=True, 631 | denoised_fn=None, 632 | cond_fn=None, 633 | model_kwargs=None, 634 | device=None, 635 | progress=False, 636 | eta=0.0, 637 | ): 638 | """ 639 | Generate samples from the model using DDIM. 640 | 641 | Same usage as p_sample_loop(). 642 | """ 643 | final = None 644 | for sample in self.ddim_sample_loop_progressive( 645 | model, 646 | shape, 647 | noise=noise, 648 | clip_denoised=clip_denoised, 649 | denoised_fn=denoised_fn, 650 | cond_fn=cond_fn, 651 | model_kwargs=model_kwargs, 652 | device=device, 653 | progress=progress, 654 | eta=eta, 655 | ): 656 | final = sample 657 | return final["sample"] 658 | 659 | def ddim_sample_loop_progressive( 660 | self, 661 | model, 662 | shape, 663 | noise=None, 664 | clip_denoised=True, 665 | denoised_fn=None, 666 | cond_fn=None, 667 | model_kwargs=None, 668 | device=None, 669 | progress=False, 670 | eta=0.0, 671 | ): 672 | """ 673 | Use DDIM to sample from the model and yield intermediate samples from 674 | each timestep of DDIM. 675 | 676 | Same usage as p_sample_loop_progressive(). 
677 | """ 678 | if device is None: 679 | device = next(model.parameters()).device 680 | assert isinstance(shape, (tuple, list)) 681 | if noise is not None: 682 | img = noise 683 | else: 684 | img = th.randn(*shape, device=device) 685 | indices = list(range(self.num_timesteps))[::-1] 686 | 687 | if progress: 688 | # Lazy import so that we don't depend on tqdm. 689 | from tqdm.auto import tqdm 690 | 691 | indices = tqdm(indices) 692 | 693 | for i in indices: 694 | t = th.tensor([i] * shape[0], device=device) 695 | with th.no_grad(): 696 | out = self.ddim_sample( 697 | model, 698 | img, 699 | t, 700 | clip_denoised=clip_denoised, 701 | denoised_fn=denoised_fn, 702 | cond_fn=cond_fn, 703 | model_kwargs=model_kwargs, 704 | eta=eta, 705 | ) 706 | yield out 707 | img = out["sample"] 708 | 709 | def _vb_terms_bpd( 710 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 711 | ): 712 | """ 713 | Get a term for the variational lower-bound. 714 | 715 | The resulting units are bits (rather than nats, as one might expect). 716 | This allows for comparison to other papers. 717 | 718 | :return: a dict with the following keys: 719 | - 'output': a shape [N] tensor of NLLs or KLs. 720 | - 'pred_xstart': the x_0 predictions. 
721 | """ 722 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 723 | x_start=x_start, x_t=x_t, t=t 724 | ) 725 | out = self.p_mean_variance( 726 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 727 | ) 728 | kl = normal_kl( 729 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 730 | ) 731 | kl = mean_flat(kl) / np.log(2.0) 732 | 733 | decoder_nll = -discretized_gaussian_log_likelihood( 734 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 735 | ) 736 | assert decoder_nll.shape == x_start.shape 737 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 738 | 739 | # At the first timestep return the decoder NLL, 740 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 741 | output = th.where((t == 0), decoder_nll, kl) 742 | return {"output": output, "pred_xstart": out["pred_xstart"]} 743 | 744 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 745 | """ 746 | Compute training losses for a single timestep. 747 | 748 | :param model: the model to evaluate loss on. 749 | :param x_start: the [N x C x ...] tensor of inputs. 750 | :param t: a batch of timestep indices. 751 | :param model_kwargs: if not None, a dict of extra keyword arguments to 752 | pass to the model. This can be used for conditioning. 753 | :param noise: if specified, the specific Gaussian noise to try to remove. 754 | :return: a dict with the key "loss" containing a tensor of shape [N]. 755 | Some mean or variance settings may also have other keys. 
756 | """ 757 | if model_kwargs is None: 758 | model_kwargs = {} 759 | if noise is None: 760 | noise = th.randn_like(x_start) 761 | x_t = self.q_sample(x_start, t, noise=noise) 762 | 763 | terms = {} 764 | 765 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 766 | terms["loss"] = self._vb_terms_bpd( 767 | model=model, 768 | x_start=x_start, 769 | x_t=x_t, 770 | t=t, 771 | clip_denoised=False, 772 | model_kwargs=model_kwargs, 773 | )["output"] 774 | if self.loss_type == LossType.RESCALED_KL: 775 | terms["loss"] *= self.num_timesteps 776 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 777 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) 778 | 779 | if self.model_var_type in [ 780 | ModelVarType.LEARNED, 781 | ModelVarType.LEARNED_RANGE, 782 | ]: 783 | B, C = x_t.shape[:2] 784 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 785 | model_output, model_var_values = th.split(model_output, C, dim=1) 786 | # Learn the variance using the variational bound, but don't let 787 | # it affect our mean prediction. 788 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 789 | terms["vb"] = self._vb_terms_bpd( 790 | model=lambda *args, r=frozen_out: r, 791 | x_start=x_start, 792 | x_t=x_t, 793 | t=t, 794 | clip_denoised=False, 795 | )["output"] 796 | if self.loss_type == LossType.RESCALED_MSE: 797 | # Divide by 1000 for equivalence with initial implementation. 798 | # Without a factor of 1/1000, the VB term hurts the MSE term. 
799 | terms["vb"] *= self.num_timesteps / 1000.0 800 | 801 | target = { 802 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 803 | x_start=x_start, x_t=x_t, t=t 804 | )[0], 805 | ModelMeanType.START_X: x_start, 806 | ModelMeanType.EPSILON: noise, 807 | }[self.model_mean_type] 808 | assert model_output.shape == target.shape == x_start.shape 809 | terms["mse"] = mean_flat((target - model_output) ** 2) 810 | if "vb" in terms: 811 | terms["loss"] = terms["mse"] + terms["vb"] 812 | else: 813 | terms["loss"] = terms["mse"] 814 | else: 815 | raise NotImplementedError(self.loss_type) 816 | 817 | return terms 818 | 819 | def _prior_bpd(self, x_start): 820 | """ 821 | Get the prior KL term for the variational lower-bound, measured in 822 | bits-per-dim. 823 | 824 | This term can't be optimized, as it only depends on the encoder. 825 | 826 | :param x_start: the [N x C x ...] tensor of inputs. 827 | :return: a batch of [N] KL values (in bits), one per batch element. 828 | """ 829 | batch_size = x_start.shape[0] 830 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 831 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 832 | kl_prior = normal_kl( 833 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 834 | ) 835 | return mean_flat(kl_prior) / np.log(2.0) 836 | 837 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 838 | """ 839 | Compute the entire variational lower-bound, measured in bits-per-dim, 840 | as well as other related quantities. 841 | 842 | :param model: the model to evaluate loss on. 843 | :param x_start: the [N x C x ...] tensor of inputs. 844 | :param clip_denoised: if True, clip denoised samples. 845 | :param model_kwargs: if not None, a dict of extra keyword arguments to 846 | pass to the model. This can be used for conditioning. 847 | 848 | :return: a dict containing the following keys: 849 | - total_bpd: the total variational lower-bound, per batch element. 
850 | - prior_bpd: the prior term in the lower-bound. 851 | - vb: an [N x T] tensor of terms in the lower-bound. 852 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 853 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 854 | """ 855 | device = x_start.device 856 | batch_size = x_start.shape[0] 857 | 858 | vb = [] 859 | xstart_mse = [] 860 | mse = [] 861 | for t in list(range(self.num_timesteps))[::-1]: 862 | t_batch = th.tensor([t] * batch_size, device=device) 863 | noise = th.randn_like(x_start) 864 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 865 | # Calculate VLB term at the current timestep 866 | with th.no_grad(): 867 | out = self._vb_terms_bpd( 868 | model, 869 | x_start=x_start, 870 | x_t=x_t, 871 | t=t_batch, 872 | clip_denoised=clip_denoised, 873 | model_kwargs=model_kwargs, 874 | ) 875 | vb.append(out["output"]) 876 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 877 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 878 | mse.append(mean_flat((eps - noise) ** 2)) 879 | 880 | vb = th.stack(vb, dim=1) 881 | xstart_mse = th.stack(xstart_mse, dim=1) 882 | mse = th.stack(mse, dim=1) 883 | 884 | prior_bpd = self._prior_bpd(x_start) 885 | total_bpd = vb.sum(dim=1) + prior_bpd 886 | return { 887 | "total_bpd": total_bpd, 888 | "prior_bpd": prior_bpd, 889 | "vb": vb, 890 | "xstart_mse": xstart_mse, 891 | "mse": mse, 892 | } 893 | 894 | 895 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 896 | """ 897 | Extract values from a 1-D numpy array for a batch of indices. 898 | 899 | :param arr: the 1-D numpy array. 900 | :param timesteps: a tensor of indices into the array to extract. 901 | :param broadcast_shape: a larger shape of K dimensions with the batch 902 | dimension equal to the length of timesteps. 903 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
904 | """ 905 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 906 | while len(res.shape) < len(broadcast_shape): 907 | res = res[..., None] 908 | return res.expand(broadcast_shape) 909 | -------------------------------------------------------------------------------- /guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each `images` batch is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which maps to a batched Tensor of its own. 26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation.
38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 
130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 
154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for 
printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = 
list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log the value of a diagnostic. 215 | Call this once for each diagnostic quantity, each iteration. 216 | If called many times, the last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, the values are averaged.
224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, separated by single spaces, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to.
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. 
(See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 
399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("sony-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 
458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import numpy as np 10 | from torch.nn.parameter import Parameter 11 | 12 | 13 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
14 | class ComplexConv2D(nn.Module): 15 | def __init__( 16 | self, in_channel, out_channel, ksize=None, nobias=False, init=False, **kwargs 17 | ): 18 | super().__init__() 19 | in_channels_r = in_channel // 2 20 | in_channels_i = in_channel // 2 21 | if in_channels_r == 0: 22 | in_channels_r = in_channel 23 | in_channels_i = in_channel 24 | out_channels_r = out_channel // 2 25 | out_channels_i = out_channel // 2 26 | 27 | self.conv_r = nn.Conv2d( 28 | in_channels=in_channels_r, 29 | out_channels=out_channels_r, 30 | kernel_size=ksize, 31 | bias=not (nobias), 32 | **kwargs, 33 | ) 34 | self.conv_i = nn.Conv2d( 35 | in_channels=in_channels_i, 36 | out_channels=out_channels_i, 37 | kernel_size=ksize, 38 | bias=not (nobias), 39 | **kwargs, 40 | ) 41 | if init and not (nobias): 42 | nn.init.kaiming_normal_(self.conv_r.weight) 43 | nn.init.kaiming_normal_(self.conv_i.weight) 44 | nn.init.zeros_(self.conv_r.bias) 45 | nn.init.zeros_(self.conv_i.bias) 46 | elif init and nobias: 47 | nn.init.kaiming_normal_(self.conv_r.weight) 48 | nn.init.kaiming_normal_(self.conv_i.weight) 49 | 50 | def forward(self, x): 51 | x_r = x[:, : x.shape[1] // 2, ...] 52 | x_i = x[:, x.shape[1] // 2 :, ...] 
53 | 54 | mr_kr = self.conv_r(x_r) 55 | mi_ki = self.conv_i(x_i) 56 | mi_kr = self.conv_r(x_i) 57 | mr_ki = self.conv_i(x_r) 58 | 59 | ret = th.cat(((mr_kr - mi_ki), (mr_ki + mi_kr)), dim=1) 60 | return ret 61 | 62 | 63 | class ComplexDeconv2D(nn.Module): 64 | def __init__( 65 | self, 66 | in_channel, 67 | out_channel, 68 | ksize=None, 69 | stride=1, 70 | pad=0, 71 | output_pad=0, 72 | nobias=False, 73 | outsize=None, 74 | init=False, 75 | **kwargs, 76 | ): 77 | super().__init__() 78 | in_channels_r = in_channel // 2 79 | in_channels_i = in_channel // 2 80 | out_channels_r = out_channel // 2 81 | out_channels_i = out_channel // 2 82 | if in_channels_r == 0: 83 | in_channels_r = in_channel 84 | in_channels_i = in_channel 85 | elif out_channels_r == 0: 86 | out_channels_r = out_channel 87 | out_channels_i = out_channel 88 | 89 | self.deconv_r = nn.ConvTranspose2d( 90 | in_channels=in_channels_r, 91 | out_channels=out_channels_r, 92 | kernel_size=ksize, 93 | stride=stride, 94 | padding=pad, 95 | output_padding=output_pad, 96 | bias=not (nobias), 97 | ) 98 | self.deconv_i = nn.ConvTranspose2d( 99 | in_channels=in_channels_i, 100 | out_channels=out_channels_i, 101 | kernel_size=ksize, 102 | stride=stride, 103 | padding=pad, 104 | output_padding=output_pad, 105 | bias=not (nobias), 106 | ) 107 | 108 | if init and not (nobias): 109 | nn.init.kaiming_normal_(self.deconv_r.weight) 110 | nn.init.kaiming_normal_(self.deconv_i.weight) 111 | nn.init.zeros_(self.deconv_r.bias) 112 | nn.init.zeros_(self.deconv_i.bias) 113 | elif init and nobias: 114 | nn.init.kaiming_normal_(self.deconv_r.weight) 115 | nn.init.kaiming_normal_(self.deconv_i.weight) 116 | 117 | def forward(self, x): 118 | x_r = x[:, : x.shape[1] // 2, ...] 119 | x_i = x[:, x.shape[1] // 2 :, ...] 
120 | 121 | mr_kr = self.deconv_r(x_r) 122 | mi_ki = self.deconv_i(x_i) 123 | mi_kr = self.deconv_r(x_i) 124 | mr_ki = self.deconv_i(x_r) 125 | 126 | ret = th.cat(((mr_kr - mi_ki), (mr_ki + mi_kr)), dim=1) 127 | return ret 128 | 129 | 130 | class ComplexBatchNormalization(nn.Module): 131 | def __init__( 132 | self, 133 | d_r, 134 | eps=1e-4, 135 | decay=0.9, 136 | initial_gamma_rr=None, 137 | initial_gamma_ii=None, 138 | initial_gamma_ri=None, 139 | initial_beta=None, 140 | initial_avg_mean=None, 141 | initial_avg_vrr=None, 142 | initial_avg_vii=None, 143 | initial_avg_vri=None, 144 | ): 145 | super().__init__() 146 | if d_r == 0: 147 | d_r = 1 # For Real-value U-net 148 | self._d_r = d_r # Half the channel count: channels are split as [Real, Imag]. 149 | self.eps = eps 150 | self.decay = decay 151 | 152 | if initial_beta is None: 153 | self.beta = Parameter(th.zeros(d_r * 2, dtype=th.float32)) 154 | else: 155 | self.beta = Parameter(initial_beta) 156 | 157 | if initial_gamma_rr is None: 158 | self.gamma_rr = Parameter(th.ones(d_r, dtype=th.float32) / (2 ** 0.5)) 159 | self.gamma_ii = Parameter(th.ones(d_r, dtype=th.float32) / (2 ** 0.5)) 160 | self.gamma_ri = Parameter(th.ones(d_r, dtype=th.float32)) 161 | else: 162 | raise NotImplementedError("custom initial_gamma_* values are not supported") 163 | self.gamma_rr = Parameter(initial_gamma_rr) 164 | self.gamma_ii = Parameter(initial_gamma_ii) 165 | self.gamma_ri = Parameter(initial_gamma_ri) 166 | 167 | if initial_avg_mean is None: 168 | self._avg_mean = th.zeros(2 * d_r, dtype=th.float32) 169 | self._avg_mean = self._avg_mean[None, :, None, None] 170 | else: 171 | self._avg_mean = th.tensor(initial_avg_mean) 172 | self.register_buffer("avg_mean", self._avg_mean) 173 | 174 | if initial_avg_vrr is None: 175 | self._avg_vrr = th.ones(d_r, dtype=th.float32) / (2 ** 0.5) 176 | self._avg_vii = th.ones(d_r, dtype=th.float32) / (2 ** 0.5) 177 | self._avg_vri = th.zeros(d_r, dtype=th.float32) 178 | self._avg_vrr = self._avg_vrr[None, :, None, None] 179 | self._avg_vii = self._avg_vii[None, :, None,
None] 180 | self._avg_vri = self._avg_vri[None, :, None, None] 181 | else: 182 | self._avg_vrr = th.tensor(initial_avg_vrr) 183 | self._avg_vii = th.tensor(initial_avg_vii) 184 | self._avg_vri = th.tensor(initial_avg_vri) 185 | self.register_buffer("avg_vrr", self._avg_vrr) 186 | self.register_buffer("avg_vii", self._avg_vii) 187 | self.register_buffer("avg_vri", self._avg_vri) 188 | 189 | def forward(self, x, **kwargs): 190 | # assert x.shape[1] == self._d_r*2 191 | 192 | if self.training: 193 | # Calc. Statistic Values 194 | mean = th.mean(x, dim=(0, 2, 3), keepdims=True) 195 | x_centered = x - mean 196 | centered_squared = x_centered ** 2.0 197 | 198 | centered_real = x_centered[:, : self._d_r, Ellipsis] 199 | centered_imag = x_centered[:, self._d_r :, Ellipsis] 200 | centered_squared_real = centered_squared[:, : self._d_r, Ellipsis] 201 | centered_squared_imag = centered_squared[:, self._d_r :, Ellipsis] 202 | 203 | Vrr = ( 204 | th.mean(centered_squared_real, dim=(0, 2, 3), keepdims=True) + self.eps 205 | ) 206 | Vii = ( 207 | th.mean(centered_squared_imag, dim=(0, 2, 3), keepdims=True) + self.eps 208 | ) 209 | Vri = th.mean(centered_real * centered_imag, dim=(0, 2, 3), keepdims=True) 210 | 211 | # Saving Running Statistics 212 | self.avg_mean *= self.decay 213 | self.avg_mean += (1.0 - self.decay) * mean.data 214 | self.avg_vrr *= self.decay 215 | self.avg_vrr += (1.0 - self.decay) * Vrr.data 216 | self.avg_vii *= self.decay 217 | self.avg_vii += (1.0 - self.decay) * Vii.data 218 | self.avg_vri *= self.decay 219 | self.avg_vri += (1.0 - self.decay) * Vri.data 220 | else: 221 | # Calc. 
Statistic Values 222 | mean = self.avg_mean 223 | x_centered = x - mean 224 | centered_squared = x_centered ** 2.0 225 | 226 | centered_real = x_centered[:, : self._d_r, Ellipsis] 227 | centered_imag = x_centered[:, self._d_r :, Ellipsis] 228 | 229 | Vrr = self.avg_vrr 230 | Vii = self.avg_vii 231 | Vri = self.avg_vri 232 | 233 | # Whitening matrix W = V^(-1/2): inverse square root of the 2x2 covariance matrix 234 | tau = Vrr + Vii 235 | delta = (Vrr * Vii) - (Vri ** 2.0) 236 | s = th.sqrt(delta + np.finfo("float32").eps) 237 | t = th.sqrt(tau + 2.0 * s + np.finfo("float32").eps) 238 | inverse_st = 1.0 / ((s * t) + np.finfo("float32").eps) 239 | Wrr = (Vii + s) * inverse_st 240 | Wii = (Vrr + s) * inverse_st 241 | Wri = -Vri * inverse_st 242 | 243 | # Complex Standardization 244 | cat_W_4_real = th.cat((Wrr, Wii), dim=1) 245 | cat_W_4_imag = th.cat((Wri, Wri), dim=1) 246 | rolled_x = th.cat((centered_imag, centered_real), dim=1) 247 | x_stdized = cat_W_4_real * x_centered + cat_W_4_imag * rolled_x 248 | 249 | # Re-scaling using Gamma, Beta 250 | x_stdized_r = x_stdized[:, : self._d_r, Ellipsis] 251 | x_stdized_i = x_stdized[:, self._d_r :, Ellipsis] 252 | rolled_x_stdized = th.cat((x_stdized_i, x_stdized_r), dim=1) 253 | broadcast_gamma_rr = th.broadcast_to( 254 | th.reshape(self.gamma_rr, (1, self._d_r, 1, 1)), x_stdized_r.shape 255 | ) 256 | broadcast_gamma_ii = th.broadcast_to( 257 | th.reshape(self.gamma_ii, (1, self._d_r, 1, 1)), x_stdized_r.shape 258 | ) 259 | broadcast_gamma_ri = th.broadcast_to( 260 | th.reshape(self.gamma_ri, (1, self._d_r, 1, 1)), x_stdized_r.shape 261 | ) 262 | cat_gamma_4_real = th.cat((broadcast_gamma_rr, broadcast_gamma_ii), dim=1) 263 | cat_gamma_4_imag = th.cat((broadcast_gamma_ri, broadcast_gamma_ri), dim=1) 264 | 265 | out = cat_gamma_4_real * x_stdized + cat_gamma_4_imag * rolled_x_stdized 266 | out += th.reshape(self.beta, (1, self._d_r * 2, 1, 1)) 267 | 268 | return out 269 | 270 | 271 | class SiLU(nn.Module): 272 | def forward(self, x): 273 | return x * th.sigmoid(x) 274 | 275 |
276 | class GroupNorm32(nn.GroupNorm): 277 | def forward(self, x): 278 | return super().forward(x.float()).type(x.dtype) 279 | 280 | 281 | def conv_nd(dims, *args, complex_conv=False, **kwargs): 282 | """ 283 | Create a 1D, 2D, or 3D convolution module. 284 | """ 285 | if dims == 1: 286 | return nn.Conv1d(*args, **kwargs) 287 | elif dims == 2 and complex_conv: 288 | return ComplexConv2D( 289 | in_channel=args[0], out_channel=args[1], ksize=args[2], **kwargs 290 | ) 291 | elif dims == 2 and not (complex_conv): 292 | return nn.Conv2d(*args, **kwargs) 293 | elif dims == 3: 294 | return nn.Conv3d(*args, **kwargs) 295 | raise ValueError(f"unsupported dimensions: {dims}") 296 | 297 | 298 | def linear(*args, **kwargs): 299 | """ 300 | Create a linear module. 301 | """ 302 | return nn.Linear(*args, **kwargs) 303 | 304 | 305 | def avg_pool_nd(dims, *args, **kwargs): 306 | """ 307 | Create a 1D, 2D, or 3D average pooling module. 308 | """ 309 | if dims == 1: 310 | return nn.AvgPool1d(*args, **kwargs) 311 | elif dims == 2: 312 | return nn.AvgPool2d(*args, **kwargs) 313 | elif dims == 3: 314 | return nn.AvgPool3d(*args, **kwargs) 315 | raise ValueError(f"unsupported dimensions: {dims}") 316 | 317 | 318 | def update_ema(target_params, source_params, rate=0.99): 319 | """ 320 | Update target parameters to be closer to those of source parameters using 321 | an exponential moving average. 322 | 323 | :param target_params: the target parameter sequence. 324 | :param source_params: the source parameter sequence. 325 | :param rate: the EMA rate (closer to 1 means slower). 326 | """ 327 | for targ, src in zip(target_params, source_params): 328 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 329 | 330 | 331 | def zero_module(module): 332 | """ 333 | Zero out the parameters of a module and return it. 
334 | """ 335 | for p in module.parameters(): 336 | p.detach().zero_() 337 | return module 338 | 339 | 340 | def scale_module(module, scale): 341 | """ 342 | Scale the parameters of a module and return it. 343 | """ 344 | for p in module.parameters(): 345 | p.detach().mul_(scale) 346 | return module 347 | 348 | 349 | def mean_flat(tensor): 350 | """ 351 | Take the mean over all non-batch dimensions. 352 | """ 353 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 354 | 355 | 356 | def normalization(channels): 357 | """ 358 | Make a standard normalization layer. 359 | 360 | :param channels: number of input channels. 361 | :return: an nn.Module for normalization. 362 | """ 363 | return GroupNorm32(32, channels) 364 | 365 | 366 | def timestep_embedding(timesteps, dim, max_period=10000): 367 | """ 368 | Create sinusoidal timestep embeddings. 369 | 370 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 371 | These may be fractional. 372 | :param dim: the dimension of the output. 373 | :param max_period: controls the minimum frequency of the embeddings. 374 | :return: an [N x dim] Tensor of positional embeddings. 375 | """ 376 | half = dim // 2 377 | freqs = th.exp( 378 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 379 | ).to(device=timesteps.device) 380 | args = timesteps[:, None].float() * freqs[None] 381 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 382 | if dim % 2: 383 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 384 | return embedding 385 | 386 | 387 | def checkpoint(func, inputs, params, flag): 388 | """ 389 | Evaluate a function without caching intermediate activations, allowing for 390 | reduced memory at the expense of extra compute in the backward pass. 391 | 392 | :param func: the function to evaluate. 393 | :param inputs: the argument sequence to pass to `func`. 
394 | :param params: a sequence of parameters `func` depends on but does not 395 | explicitly take as arguments. 396 | :param flag: if False, disable gradient checkpointing. 397 | """ 398 | if flag: 399 | args = tuple(inputs) + tuple(params) 400 | return CheckpointFunction.apply(func, len(inputs), *args) 401 | else: 402 | return func(*inputs) 403 | 404 | 405 | class CheckpointFunction(th.autograd.Function): 406 | @staticmethod 407 | def forward(ctx, run_function, length, *args): 408 | ctx.run_function = run_function 409 | ctx.input_tensors = list(args[:length]) 410 | ctx.input_params = list(args[length:]) 411 | with th.no_grad(): 412 | output_tensors = ctx.run_function(*ctx.input_tensors) 413 | return output_tensors 414 | 415 | @staticmethod 416 | def backward(ctx, *output_grads): 417 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 418 | with th.enable_grad(): 419 | # Fixes a bug where the first op in run_function modifies the 420 | # Tensor storage in place, which is not allowed for detach()'d 421 | # Tensors. 422 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 423 | output_tensors = ctx.run_function(*shallow_copies) 424 | input_grads = th.autograd.grad( 425 | output_tensors, 426 | ctx.input_tensors + ctx.input_params, 427 | output_grads, 428 | allow_unused=True, 429 | ) 430 | del ctx.input_tensors 431 | del ctx.input_params 432 | del output_tensors 433 | return (None, None) + input_grads 434 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 
13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 
93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process.
12 | 13 | For example, if there are 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {desired_count} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return
set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | ) 25 | 26 | 27 | def classifier_defaults(): 28 | """ 29 | Defaults for classifier models. 30 | """ 31 | return dict( 32 | image_size=64, 33 | classifier_use_fp16=False, 34 | classifier_width=128, 35 | classifier_depth=2, 36 | classifier_attention_resolutions="32,16,8", # 16 37 | classifier_use_scale_shift_norm=True, # False 38 | classifier_resblock_updown=True, # False 39 | classifier_pool="attention", 40 | ) 41 | 42 | 43 | def model_and_diffusion_defaults(): 44 | """ 45 | Defaults for image training. 
46 | """ 47 | res = dict( 48 | image_size=64, 49 | num_channels=128, 50 | num_res_blocks=2, 51 | num_heads=4, 52 | num_heads_upsample=-1, 53 | num_head_channels=-1, 54 | attention_resolutions="16,8", 55 | channel_mult="", 56 | dropout=0.0, 57 | class_cond=False, 58 | use_checkpoint=False, 59 | use_scale_shift_norm=True, 60 | resblock_updown=False, 61 | use_fp16=False, 62 | use_new_attention_order=False, 63 | complex_conv=False, 64 | ) 65 | res.update(diffusion_defaults()) 66 | return res 67 | 68 | 69 | def classifier_and_diffusion_defaults(): 70 | res = classifier_defaults() 71 | res.update(diffusion_defaults()) 72 | return res 73 | 74 | 75 | def create_model_and_diffusion( 76 | image_size, 77 | class_cond, 78 | learn_sigma, 79 | num_channels, 80 | num_res_blocks, 81 | channel_mult, 82 | num_heads, 83 | num_head_channels, 84 | num_heads_upsample, 85 | attention_resolutions, 86 | dropout, 87 | diffusion_steps, 88 | noise_schedule, 89 | timestep_respacing, 90 | use_kl, 91 | predict_xstart, 92 | rescale_timesteps, 93 | rescale_learned_sigmas, 94 | use_checkpoint, 95 | use_scale_shift_norm, 96 | resblock_updown, 97 | use_fp16, 98 | use_new_attention_order, 99 | complex_conv, 100 | ): 101 | model = create_model( 102 | image_size, 103 | num_channels, 104 | num_res_blocks, 105 | channel_mult=channel_mult, 106 | learn_sigma=learn_sigma, 107 | class_cond=class_cond, 108 | use_checkpoint=use_checkpoint, 109 | attention_resolutions=attention_resolutions, 110 | num_heads=num_heads, 111 | num_head_channels=num_head_channels, 112 | num_heads_upsample=num_heads_upsample, 113 | use_scale_shift_norm=use_scale_shift_norm, 114 | dropout=dropout, 115 | resblock_updown=resblock_updown, 116 | use_fp16=use_fp16, 117 | use_new_attention_order=use_new_attention_order, 118 | complex_conv=complex_conv, 119 | ) 120 | diffusion = create_gaussian_diffusion( 121 | steps=diffusion_steps, 122 | learn_sigma=learn_sigma, 123 | noise_schedule=noise_schedule, 124 | use_kl=use_kl, 125 | 
predict_xstart=predict_xstart, 126 | rescale_timesteps=rescale_timesteps, 127 | rescale_learned_sigmas=rescale_learned_sigmas, 128 | timestep_respacing=timestep_respacing, 129 | ) 130 | return model, diffusion 131 | 132 | 133 | def create_model( 134 | image_size, 135 | num_channels, 136 | num_res_blocks, 137 | channel_mult="", 138 | learn_sigma=False, 139 | class_cond=False, 140 | use_checkpoint=False, 141 | attention_resolutions="16", 142 | num_heads=1, 143 | num_head_channels=-1, 144 | num_heads_upsample=-1, 145 | use_scale_shift_norm=False, 146 | dropout=0, 147 | resblock_updown=False, 148 | use_fp16=False, 149 | use_new_attention_order=False, 150 | complex_conv=False, 151 | ): 152 | if channel_mult == "": 153 | if image_size == 512: 154 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 155 | elif image_size == 256: 156 | channel_mult = (1, 1, 2, 2, 4, 4) 157 | elif image_size == 128: 158 | channel_mult = (1, 1, 2, 3, 4) 159 | elif image_size == 64: 160 | channel_mult = (1, 2, 3, 4) 161 | else: 162 | raise ValueError(f"unsupported image size: {image_size}") 163 | else: 164 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 165 | 166 | attention_ds = [] 167 | for res in attention_resolutions.split(","): 168 | attention_ds.append(image_size // int(res)) 169 | 170 | return UNetModel( 171 | image_size=image_size, 172 | in_channels=3, 173 | model_channels=num_channels, 174 | out_channels=(3 if not learn_sigma else 6), 175 | num_res_blocks=num_res_blocks, 176 | attention_resolutions=tuple(attention_ds), 177 | dropout=dropout, 178 | channel_mult=channel_mult, 179 | num_classes=(NUM_CLASSES if class_cond else None), 180 | use_checkpoint=use_checkpoint, 181 | use_fp16=use_fp16, 182 | num_heads=num_heads, 183 | num_head_channels=num_head_channels, 184 | num_heads_upsample=num_heads_upsample, 185 | use_scale_shift_norm=use_scale_shift_norm, 186 | resblock_updown=resblock_updown, 187 | use_new_attention_order=use_new_attention_order, 188 | 
complex_conv=complex_conv, 189 | ) 190 | 191 | 192 | def create_classifier_and_diffusion( 193 | image_size, 194 | classifier_use_fp16, 195 | classifier_width, 196 | classifier_depth, 197 | classifier_attention_resolutions, 198 | classifier_use_scale_shift_norm, 199 | classifier_resblock_updown, 200 | classifier_pool, 201 | learn_sigma, 202 | diffusion_steps, 203 | noise_schedule, 204 | timestep_respacing, 205 | use_kl, 206 | predict_xstart, 207 | rescale_timesteps, 208 | rescale_learned_sigmas, 209 | ): 210 | classifier = create_classifier( 211 | image_size, 212 | classifier_use_fp16, 213 | classifier_width, 214 | classifier_depth, 215 | classifier_attention_resolutions, 216 | classifier_use_scale_shift_norm, 217 | classifier_resblock_updown, 218 | classifier_pool, 219 | ) 220 | diffusion = create_gaussian_diffusion( 221 | steps=diffusion_steps, 222 | learn_sigma=learn_sigma, 223 | noise_schedule=noise_schedule, 224 | use_kl=use_kl, 225 | predict_xstart=predict_xstart, 226 | rescale_timesteps=rescale_timesteps, 227 | rescale_learned_sigmas=rescale_learned_sigmas, 228 | timestep_respacing=timestep_respacing, 229 | ) 230 | return classifier, diffusion 231 | 232 | 233 | def create_classifier( 234 | image_size, 235 | classifier_use_fp16, 236 | classifier_width, 237 | classifier_depth, 238 | classifier_attention_resolutions, 239 | classifier_use_scale_shift_norm, 240 | classifier_resblock_updown, 241 | classifier_pool, 242 | ): 243 | if image_size == 512: 244 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 245 | elif image_size == 256: 246 | channel_mult = (1, 1, 2, 2, 4, 4) 247 | elif image_size == 128: 248 | channel_mult = (1, 1, 2, 3, 4) 249 | elif image_size == 64: 250 | channel_mult = (1, 2, 3, 4) 251 | else: 252 | raise ValueError(f"unsupported image size: {image_size}") 253 | 254 | attention_ds = [] 255 | for res in classifier_attention_resolutions.split(","): 256 | attention_ds.append(image_size // int(res)) 257 | 258 | return EncoderUNetModel( 259 | 
image_size=image_size, 260 | in_channels=3, 261 | model_channels=classifier_width, 262 | out_channels=1000, 263 | num_res_blocks=classifier_depth, 264 | attention_resolutions=tuple(attention_ds), 265 | channel_mult=channel_mult, 266 | use_fp16=classifier_use_fp16, 267 | num_head_channels=64, 268 | use_scale_shift_norm=classifier_use_scale_shift_norm, 269 | resblock_updown=classifier_resblock_updown, 270 | pool=classifier_pool, 271 | ) 272 | 273 | 274 | def sr_model_and_diffusion_defaults(): 275 | res = model_and_diffusion_defaults() 276 | res["large_size"] = 256 277 | res["small_size"] = 64 278 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 279 | for k in res.copy().keys(): 280 | if k not in arg_names: 281 | del res[k] 282 | return res 283 | 284 | 285 | def sr_create_model_and_diffusion( 286 | large_size, 287 | small_size, 288 | class_cond, 289 | learn_sigma, 290 | num_channels, 291 | num_res_blocks, 292 | num_heads, 293 | num_head_channels, 294 | num_heads_upsample, 295 | attention_resolutions, 296 | dropout, 297 | diffusion_steps, 298 | noise_schedule, 299 | timestep_respacing, 300 | use_kl, 301 | predict_xstart, 302 | rescale_timesteps, 303 | rescale_learned_sigmas, 304 | use_checkpoint, 305 | use_scale_shift_norm, 306 | resblock_updown, 307 | use_fp16, 308 | ): 309 | model = sr_create_model( 310 | large_size, 311 | small_size, 312 | num_channels, 313 | num_res_blocks, 314 | learn_sigma=learn_sigma, 315 | class_cond=class_cond, 316 | use_checkpoint=use_checkpoint, 317 | attention_resolutions=attention_resolutions, 318 | num_heads=num_heads, 319 | num_head_channels=num_head_channels, 320 | num_heads_upsample=num_heads_upsample, 321 | use_scale_shift_norm=use_scale_shift_norm, 322 | dropout=dropout, 323 | resblock_updown=resblock_updown, 324 | use_fp16=use_fp16, 325 | ) 326 | diffusion = create_gaussian_diffusion( 327 | steps=diffusion_steps, 328 | learn_sigma=learn_sigma, 329 | noise_schedule=noise_schedule, 330 | use_kl=use_kl, 331 | 
predict_xstart=predict_xstart, 332 | rescale_timesteps=rescale_timesteps, 333 | rescale_learned_sigmas=rescale_learned_sigmas, 334 | timestep_respacing=timestep_respacing, 335 | ) 336 | return model, diffusion 337 | 338 | 339 | def sr_create_model( 340 | large_size, 341 | small_size, 342 | num_channels, 343 | num_res_blocks, 344 | learn_sigma, 345 | class_cond, 346 | use_checkpoint, 347 | attention_resolutions, 348 | num_heads, 349 | num_head_channels, 350 | num_heads_upsample, 351 | use_scale_shift_norm, 352 | dropout, 353 | resblock_updown, 354 | use_fp16, 355 | ): 356 | _ = small_size # hack to prevent unused variable 357 | 358 | if large_size == 512: 359 | channel_mult = (1, 1, 2, 2, 4, 4) 360 | elif large_size == 256: 361 | channel_mult = (1, 1, 2, 2, 4, 4) 362 | elif large_size == 64: 363 | channel_mult = (1, 2, 3, 4) 364 | else: 365 | raise ValueError(f"unsupported large size: {large_size}") 366 | 367 | attention_ds = [] 368 | for res in attention_resolutions.split(","): 369 | attention_ds.append(large_size // int(res)) 370 | 371 | return SuperResModel( 372 | image_size=large_size, 373 | in_channels=3, 374 | model_channels=num_channels, 375 | out_channels=(3 if not learn_sigma else 6), 376 | num_res_blocks=num_res_blocks, 377 | attention_resolutions=tuple(attention_ds), 378 | dropout=dropout, 379 | channel_mult=channel_mult, 380 | num_classes=(NUM_CLASSES if class_cond else None), 381 | use_checkpoint=use_checkpoint, 382 | num_heads=num_heads, 383 | num_head_channels=num_head_channels, 384 | num_heads_upsample=num_heads_upsample, 385 | use_scale_shift_norm=use_scale_shift_norm, 386 | resblock_updown=resblock_updown, 387 | use_fp16=use_fp16, 388 | ) 389 | 390 | 391 | def create_gaussian_diffusion( 392 | *, 393 | steps=1000, 394 | learn_sigma=False, 395 | sigma_small=False, 396 | noise_schedule="linear", 397 | use_kl=False, 398 | predict_xstart=False, 399 | rescale_timesteps=False, 400 | rescale_learned_sigmas=False, 401 | timestep_respacing="", 402 | ): 403 | 
betas = gd.get_named_beta_schedule(noise_schedule, steps) 404 | if use_kl: 405 | loss_type = gd.LossType.RESCALED_KL 406 | elif rescale_learned_sigmas: 407 | loss_type = gd.LossType.RESCALED_MSE 408 | else: 409 | loss_type = gd.LossType.MSE 410 | if not timestep_respacing: 411 | timestep_respacing = [steps] 412 | return SpacedDiffusion( 413 | use_timesteps=space_timesteps(steps, timestep_respacing), 414 | betas=betas, 415 | model_mean_type=( 416 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 417 | ), 418 | model_var_type=( 419 | ( 420 | gd.ModelVarType.FIXED_LARGE 421 | if not sigma_small 422 | else gd.ModelVarType.FIXED_SMALL 423 | ) 424 | if not learn_sigma 425 | else gd.ModelVarType.LEARNED_RANGE 426 | ), 427 | loss_type=loss_type, 428 | rescale_timesteps=rescale_timesteps, 429 | ) 430 | 431 | 432 | def add_dict_to_argparser(parser, default_dict): 433 | for k, v in default_dict.items(): 434 | v_type = type(v) 435 | if v is None: 436 | v_type = str 437 | elif isinstance(v, bool): 438 | v_type = str2bool 439 | parser.add_argument(f"--{k}", default=v, type=v_type) 440 | 441 | 442 | def args_to_dict(args, keys): 443 | return {k: getattr(args, k) for k in keys} 444 | 445 | 446 | def str2bool(v): 447 | """ 448 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 449 | """ 450 | if isinstance(v, bool): 451 | return v 452 | if v.lower() in ("yes", "true", "t", "y", "1"): 453 | return True 454 | elif v.lower() in ("no", "false", "f", "n", "0"): 455 | return False 456 | else: 457 | raise argparse.ArgumentTypeError("boolean value expected") 458 | -------------------------------------------------------------------------------- /guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from 
torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | batch_size, 30 | microbatch, 31 | lr, 32 | ema_rate, 33 | log_interval, 34 | save_interval, 35 | resume_checkpoint, 36 | use_fp16=False, 37 | fp16_scale_growth=1e-3, 38 | schedule_sampler=None, 39 | weight_decay=0.0, 40 | lr_anneal_steps=0, 41 | ): 42 | self.model = model 43 | self.diffusion = diffusion 44 | self.data = data 45 | self.batch_size = batch_size 46 | self.microbatch = microbatch if microbatch > 0 else batch_size 47 | self.lr = lr 48 | self.ema_rate = ( 49 | [ema_rate] 50 | if isinstance(ema_rate, float) 51 | else [float(x) for x in ema_rate.split(",")] 52 | ) 53 | self.log_interval = log_interval 54 | self.save_interval = save_interval 55 | self.resume_checkpoint = resume_checkpoint 56 | self.use_fp16 = use_fp16 57 | self.fp16_scale_growth = fp16_scale_growth 58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 59 | self.weight_decay = weight_decay 60 | self.lr_anneal_steps = lr_anneal_steps 61 | 62 | self.step = 0 63 | self.resume_step = 0 64 | self.global_batch = self.batch_size * dist.get_world_size() 65 | 66 | self.sync_cuda = th.cuda.is_available() 67 | 68 | self._load_and_sync_parameters() 69 | self.mp_trainer = MixedPrecisionTrainer( 70 | model=self.model, 71 | use_fp16=self.use_fp16, 72 | fp16_scale_growth=fp16_scale_growth, 73 | ) 74 | 75 | self.opt = AdamW( 76 | self.mp_trainer.master_params, lr=self.lr, 
weight_decay=self.weight_decay 77 | ) 78 | if self.resume_step: 79 | self._load_optimizer_state() 80 | # Model was resumed, either due to a restart or a checkpoint 81 | # being specified at the command line. 82 | self.ema_params = [ 83 | self._load_ema_parameters(rate) for rate in self.ema_rate 84 | ] 85 | else: 86 | self.ema_params = [ 87 | copy.deepcopy(self.mp_trainer.master_params) 88 | for _ in range(len(self.ema_rate)) 89 | ] 90 | 91 | if th.cuda.is_available(): 92 | self.use_ddp = True 93 | self.ddp_model = DDP( 94 | self.model, 95 | device_ids=[dist_util.dev()], 96 | output_device=dist_util.dev(), 97 | broadcast_buffers=False, 98 | bucket_cap_mb=128, 99 | find_unused_parameters=False, 100 | ) 101 | else: 102 | if dist.get_world_size() > 1: 103 | logger.warn( 104 | "Distributed training requires CUDA. " 105 | "Gradients will not be synchronized properly!" 106 | ) 107 | self.use_ddp = False 108 | self.ddp_model = self.model 109 | 110 | def _load_and_sync_parameters(self): 111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 112 | 113 | if resume_checkpoint: 114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 115 | if dist.get_rank() == 0: 116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 117 | self.model.load_state_dict( 118 | dist_util.load_state_dict( 119 | resume_checkpoint, map_location=dist_util.dev() 120 | ) 121 | ) 122 | 123 | dist_util.sync_params(self.model.parameters()) 124 | 125 | def _load_ema_parameters(self, rate): 126 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 127 | 128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 130 | if ema_checkpoint: 131 | if dist.get_rank() == 0: 132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 133 | state_dict = dist_util.load_state_dict( 134 | ema_checkpoint, map_location=dist_util.dev() 135 | ) 136 | 
ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 137 | 138 | dist_util.sync_params(ema_params) 139 | return ema_params 140 | 141 | def _load_optimizer_state(self): 142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 143 | opt_checkpoint = bf.join( 144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 145 | ) 146 | if bf.exists(opt_checkpoint): 147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 148 | state_dict = dist_util.load_state_dict( 149 | opt_checkpoint, map_location=dist_util.dev() 150 | ) 151 | self.opt.load_state_dict(state_dict) 152 | 153 | def run_loop(self): 154 | while ( 155 | not self.lr_anneal_steps 156 | or self.step + self.resume_step < self.lr_anneal_steps 157 | ): 158 | batch, cond = next(self.data) 159 | self.run_step(batch, cond) 160 | if self.step % self.log_interval == 0: 161 | logger.dumpkvs() 162 | if self.step % self.save_interval == 0: 163 | self.save() 164 | # Run for a finite amount of time in integration tests. 165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 166 | return 167 | self.step += 1 168 | # Save the last checkpoint if it wasn't already saved. 
169 | if (self.step - 1) % self.save_interval != 0: 170 | self.save() 171 | 172 | def run_step(self, batch, cond): 173 | self.forward_backward(batch, cond) 174 | took_step = self.mp_trainer.optimize(self.opt) 175 | if took_step: 176 | self._update_ema() 177 | self._anneal_lr() 178 | self.log_step() 179 | 180 | def forward_backward(self, batch, cond): 181 | self.mp_trainer.zero_grad() 182 | for i in range(0, batch.shape[0], self.microbatch): 183 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 184 | micro_cond = { 185 | k: v[i : i + self.microbatch].to(dist_util.dev()) 186 | if th.is_tensor(v[i : i + self.microbatch]) 187 | else v[i : i + self.microbatch] 188 | for k, v in cond.items() 189 | } 190 | last_batch = (i + self.microbatch) >= batch.shape[0] 191 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 192 | 193 | compute_losses = functools.partial( 194 | self.diffusion.training_losses, 195 | self.ddp_model, 196 | micro, 197 | t, 198 | model_kwargs=micro_cond, 199 | ) 200 | 201 | if last_batch or not self.use_ddp: 202 | losses = compute_losses() 203 | else: 204 | with self.ddp_model.no_sync(): 205 | losses = compute_losses() 206 | 207 | if isinstance(self.schedule_sampler, LossAwareSampler): 208 | self.schedule_sampler.update_with_local_losses( 209 | t, losses["loss"].detach() 210 | ) 211 | 212 | loss = (losses["loss"] * weights).mean() 213 | log_loss_dict( 214 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 215 | ) 216 | self.mp_trainer.backward(loss) 217 | if ( 218 | hasattr(self.mp_trainer.model, "grad_clip") 219 | and self.mp_trainer.model.grad_clip is not None 220 | ): 221 | th.nn.utils.clip_grad_norm_( 222 | self.mp_trainer.master_params, 223 | max_norm=self.mp_trainer.model.grad_clip, 224 | ) 225 | 226 | def _update_ema(self): 227 | for rate, params in zip(self.ema_rate, self.ema_params): 228 | update_ema(params, self.mp_trainer.master_params, rate=rate) 229 | 230 | def _anneal_lr(self): 231 | if not 
self.lr_anneal_steps: 232 | return 233 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 234 | lr = self.lr * (1 - frac_done) 235 | for param_group in self.opt.param_groups: 236 | param_group["lr"] = lr 237 | 238 | def log_step(self): 239 | logger.logkv("step", self.step + self.resume_step) 240 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 241 | 242 | def save(self): 243 | def save_checkpoint(rate, params): 244 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 245 | if dist.get_rank() == 0: 246 | logger.log(f"saving model {rate}...") 247 | if not rate: 248 | filename = f"model{(self.step+self.resume_step):06d}.pt" 249 | else: 250 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 251 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 252 | th.save(state_dict, f) 253 | 254 | save_checkpoint(0, self.mp_trainer.master_params) 255 | for rate, params in zip(self.ema_rate, self.ema_params): 256 | save_checkpoint(rate, params) 257 | 258 | if dist.get_rank() == 0: 259 | with bf.BlobFile( 260 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 261 | "wb", 262 | ) as f: 263 | th.save(self.opt.state_dict(), f) 264 | 265 | dist.barrier() 266 | 267 | 268 | def parse_resume_step_from_filename(filename): 269 | """ 270 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 271 | checkpoint's number of steps. 272 | """ 273 | split = filename.split("model") 274 | if len(split) < 2: 275 | return 0 276 | split1 = split[-1].split(".")[0] 277 | try: 278 | return int(split1) 279 | except ValueError: 280 | return 0 281 | 282 | 283 | def get_blob_logdir(): 284 | # You can change this to be a separate path to save checkpoints to 285 | # a blobstore or some external drive. 
286 | return logger.get_dir() 287 | 288 | 289 | def find_resume_checkpoint(): 290 | # On your infrastructure, you may want to override this to automatically 291 | # discover the latest checkpoint on your blob storage, etc. 292 | return None 293 | 294 | 295 | def find_ema_checkpoint(main_checkpoint, step, rate): 296 | if main_checkpoint is None: 297 | return None 298 | filename = f"ema_{rate}_{(step):06d}.pt" 299 | path = bf.join(bf.dirname(main_checkpoint), filename) 300 | if bf.exists(path): 301 | return path 302 | return None 303 | 304 | 305 | def log_loss_dict(diffusion, ts, losses): 306 | for key, values in losses.items(): 307 | logger.logkv_mean(key, values.mean().item()) 308 | # Log the quantiles (four quartiles, in particular). 309 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 310 | quartile = int(4 * sub_t / diffusion.num_timesteps) 311 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 312 | -------------------------------------------------------------------------------- /run_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # You can easily run our refiner on the VoiceBank+Demand results pre-processed by DCUnet. 4 | 5 | # By default, this example runs "Diffiner+". 6 | # Note that you need to add "--simple-diffiner" if you want to run just "Diffiner". 7 | # Please refer to the details described in README.md and Inference.md 8 | 9 | # Exit on error 10 | set -e 11 | set -o pipefail 12 | 13 | # Main storage directory. 14 | # If you start by downloading VoiceBank+Demand (VBD), you'll need enough disk space to store the VBD archive and its extracted wav files. 15 | vbd_dir=data 16 | 17 | # Path to the python you'll use for the experiment. Defaults to the current python 18 | python_path=python 19 | 20 | # Stage to start from: 0 = download the data, then refine; 1 = skip the download and refine only 21 | stage=0 # Controls from which stage to start 22 | 23 | # GPU index. If you set a negative number, inference runs on the CPU only (without a GPU).
24 | id=0 # $CUDA_VISIBLE_DEVICES 25 | 26 | if [[ $stage -le 0 ]]; then 27 | echo "Stage 0: Downloading VoiceBank+Demand and its pre-processed results by DCUnet into $vbd_dir" 28 | wget -c --tries=0 --read-timeout=20 https://datashare.ed.ac.uk/bitstream/handle/10283/2791/noisy_testset_wav.zip -P $vbd_dir 29 | mkdir -p $vbd_dir/logs 30 | unzip $vbd_dir/noisy_testset_wav.zip -d $vbd_dir >> $vbd_dir/logs/unzip_vbdtestset.log 31 | mkdir -p $vbd_dir/noisy_testset_wav/16kHz 32 | find $vbd_dir/noisy_testset_wav -name '*.wav' -printf '%f\n' | xargs -I % sox $vbd_dir/noisy_testset_wav/% -r 16000 $vbd_dir/noisy_testset_wav/16kHz/% 33 | wget -c --tries=0 --read-timeout=20 https://zenodo.org/record/7988790/files/proc_dcunet.tar.gz -P $vbd_dir 34 | tar xzvf $vbd_dir/proc_dcunet.tar.gz -C $vbd_dir >> $vbd_dir/logs/tar_procdcunet.log 35 | wget -c --tries=0 --read-timeout=20 https://zenodo.org/record/7988790/files/pretrained_diffiner_onVB.pt 36 | fi 37 | 38 | if [[ $stage -le 1 ]]; then 39 | echo "Stage 1: Evaluation" 40 | if [[ $id -lt 0 ]]; then 41 | $python_path scripts/run_refiner.py \ 42 | --no-gpu \ 43 | --root-noisy $vbd_dir/noisy_testset_wav/16kHz \ 44 | --root-proc $vbd_dir/proc_dcunet \ 45 | --model-path ./pretrained_diffiner_onVB.pt 46 | else 47 | CUDA_VISIBLE_DEVICES=$id $python_path scripts/run_refiner.py \ 48 | --root-noisy $vbd_dir/noisy_testset_wav/16kHz \ 49 | --root-proc $vbd_dir/proc_dcunet \ 50 | --model-path ./pretrained_diffiner_onVB.pt 51 | fi 52 | fi 53 | 54 | -------------------------------------------------------------------------------- /scripts/informed_denoiser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | 4 | from guided_diffusion.diffiner_util import create_model_and_diffusion 5 | 6 | 7 | def get_informed_denoiser(diffusion): 8 | 9 | """ 10 | Get the informed denoiser, which denoises data with a given noise map. 11 | args: 12 | diffusion : instance of diffusion class.
currently "diffusion" must be an instance of GaussianDiffusion in guided_diffusion/gaussian_diffusion.py. 14 | use_ddim (deprecated): flag indicating whether DDIM sampling is used. 15 | eta_ddim (deprecated): the parameter for DDIM sampling. 16 | returns: 17 | informed_denoiser : function 18 | """ 19 | 20 | def informed_denoiser( 21 | model, 22 | noisy_data, 23 | noise_map, 24 | clip_denoised=False, 25 | model_kwargs=None, 26 | etaA_ddrm=1.0, 27 | etaB_ddrm=1.0, 28 | ): 29 | 30 | """ 31 | Conduct informed denoising with the given noise map. 32 | args: 33 | model : score model, whose output should contain a "pred_xstart" field, that is, the estimate of x_0 given the noisy input x_t. 34 | noisy_data : (bsz, c, h, w) noisy data to be denoised. 35 | noise_map : (bsz, c, h, w) the amplitude of Gaussian noise at each pixel. 36 | clip_denoised (bool) : if True, clip the denoised signal into [-1, 1]. 37 | model_kwargs (dict) : if not None, a dict of extra keyword arguments to 38 | pass to the model. This can be used for conditioning. 39 | etaA_ddrm (double) : hyperparameter used where sigma_t < noise_map. 40 | if etaA_ddrm is close to 0.0, the method uses more information from the observation. 41 | if etaA_ddrm is close to 1.0, the method uses more information from the generative model. 42 | etaB_ddrm (double) : hyperparameter used where sigma_t > noise_map. 43 | if etaB_ddrm is close to 0.0, the method uses more information from the generative model. 44 | if etaB_ddrm is close to 1.0, the method uses more information from the observation.
45 | """ 46 | 47 | device = next(model.parameters()).device 48 | 49 | etaA_ddrm = torch.tensor(etaA_ddrm, device=device).float() 50 | etaB_ddrm = torch.tensor(etaB_ddrm, device=device).float() 51 | 52 | x_t = torch.randn_like(noisy_data) 53 | 54 | indices = list(range(diffusion.num_timesteps))[::-1] 55 | b, c, h, w = noisy_data.shape 56 | 57 | for i in tqdm.tqdm(indices): 58 | t = torch.tensor([i] * b, device=device) 59 | with torch.no_grad(): 60 | 61 | # x_t = \sqrt(cumalpha_t) * x_0 + \sqrt(1.0 - cumalpha_t) * z 62 | sqrt_1_m_cumalpha = torch.sqrt( 63 | torch.tensor(1.0 - diffusion.alphas_cumprod_prev[i], device=device) 64 | ).float() 65 | sqrt_cumalpha = torch.sqrt( 66 | torch.tensor( 67 | diffusion.alphas_cumprod_prev[i], device=device 68 | ).float() 69 | ) 70 | mask_sigmat_is_larger = 1.0 * ( 71 | sqrt_1_m_cumalpha[None, None, None, None] 72 | > sqrt_cumalpha * noise_map 73 | ) 74 | 75 | scale_x0_at_t = sqrt_cumalpha 76 | scaled_noisy_data = noisy_data * scale_x0_at_t 77 | 78 | noise = torch.randn_like(x_t) 79 | nonzero_mask = ( 80 | (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))) 81 | ) # no noise when t == 0 82 | 83 | # Estimation of x_0 with diffusion model 84 | res_p_mean_variance = diffusion.p_mean_variance( 85 | model, 86 | x_t, 87 | t, 88 | clip_denoised=clip_denoised, 89 | model_kwargs=model_kwargs, 90 | ) 91 | 92 | est_x_0 = res_p_mean_variance["pred_xstart"] * scale_x0_at_t 93 | 94 | # For sigma_t > noise_map 95 | sigma_for_larger_sigmat = torch.sqrt( 96 | mask_sigmat_is_larger 97 | * ( 98 | sqrt_1_m_cumalpha[None, None, None, None] ** 2 99 | - (etaB_ddrm ** 2) * ((sqrt_cumalpha ** 2) * (noise_map ** 2)) 100 | ) 101 | ) 102 | data_for_larger_sigmat = ( 103 | (1.0 - etaB_ddrm) * est_x_0 104 | + etaB_ddrm * scaled_noisy_data 105 | + nonzero_mask * sigma_for_larger_sigmat * torch.randn_like(x_t) 106 | ) 107 | 108 | # For sigma_t < noise_map 109 | sigma_for_smaller_sigmat = torch.sqrt( 110 | (1.0 - mask_sigmat_is_larger) 111 | * 
(sqrt_1_m_cumalpha ** 2) 112 | * (etaA_ddrm ** 2) 113 | ) 114 | coef = (sqrt_1_m_cumalpha / sqrt_cumalpha) / (noise_map + 1e-5) 115 | data_for_smaller_sigmat = ( 116 | est_x_0 117 | + torch.sqrt(1 - etaA_ddrm ** 2) 118 | * coef 119 | * (scaled_noisy_data - est_x_0) 120 | ) + nonzero_mask * sigma_for_smaller_sigmat * torch.randn_like(x_t) 121 | 122 | x_t = ( 123 | data_for_smaller_sigmat * (1.0 - mask_sigmat_is_larger) 124 | + data_for_larger_sigmat * mask_sigmat_is_larger 125 | ) 126 | 127 | return x_t 128 | 129 | return informed_denoiser 130 | 131 | 132 | def get_improved_informed_denoiser(diffusion): 133 | 134 | """ 135 | Get the improved informed denoiser (Diffiner+), which denoises data with a given noise map. 136 | args: 137 | diffusion : instance of diffusion class. 138 | currently "diffusion" must be an instance of GaussianDiffusion in guided_diffusion/gaussian_diffusion.py. 139 | returns: 140 | informed_denoiser_v2 : function 141 | """ 142 | 143 | def informed_denoiser_v2( 144 | model, 145 | noisy_data, 146 | noise_map, 147 | clip_denoised=False, 148 | model_kwargs=None, 149 | etaA=1.0, 150 | etaB=1.0, 151 | etaC=0.0, 152 | inp_mask=None, 153 | etaD=1.0, 154 | ): 155 | 156 | """ 157 | Conduct informed denoising with the given noise map. 158 | args: 159 | model : score model, whose output should contain a "pred_xstart" field, that is, the estimate of x_0 given the noisy input x_t. 160 | noisy_data : (bsz, c, h, w) noisy data to be denoised. 161 | noise_map : (bsz, c, h, w) the amplitude of Gaussian noise at each pixel. 162 | clip_denoised (bool) : if True, clip the denoised signal into [-1, 1]. 163 | model_kwargs (dict) : if not None, a dict of extra keyword arguments to 164 | pass to the model. This can be used for conditioning.
165 | etaA (double) : hyperparameter used where sigma_t < noise_map. 166 | etaB (double) : hyperparameter used where sigma_t > noise_map. 167 | etaC (double) : hyperparameter used where sigma_t < noise_map; etaA^2 + etaC^2 must not exceed 1.0. 168 | inp_mask : (bsz, c, h, w) elements where inp_mask == 1.0 are generated with DDIM sampling using the parameter etaD (double). 169 | """ 170 | 171 | assert (etaA ** 2 + etaC ** 2) <= 1.0 + 1e-10, "etaA^2 + etaC^2 <= 1.0" 172 | 173 | device = next(model.parameters()).device 174 | 175 | etaA = torch.tensor(etaA, device=device).float() 176 | etaB = torch.tensor(etaB, device=device).float() 177 | etaC = torch.tensor(etaC, device=device).float() 178 | 179 | x_t = torch.randn_like(noisy_data) 180 | if inp_mask is None: 181 | inp_mask = torch.zeros_like(noisy_data) 182 | 183 | indices = list(range(diffusion.num_timesteps))[::-1] 184 | b, c, h, w = noisy_data.shape 185 | 186 | for i in tqdm.tqdm(indices): 187 | t = torch.tensor([i] * b, device=device) 188 | with torch.no_grad(): 189 | 190 | # x_t = \sqrt(cumalpha_t) * x_0 + \sqrt(1.0 - cumalpha_t) * z 191 | sqrt_1_m_cumalpha = torch.sqrt( 192 | torch.tensor(1.0 - diffusion.alphas_cumprod_prev[i], device=device) 193 | ).float() 194 | sqrt_cumalpha = torch.sqrt( 195 | torch.tensor( 196 | diffusion.alphas_cumprod_prev[i], device=device 197 | ).float() 198 | ) 199 | mask_sigmat_is_larger = 1.0 * ( 200 | sqrt_1_m_cumalpha[None, None, None, None] 201 | > sqrt_cumalpha * noise_map 202 | ) 203 | 204 | scale_x0_at_t = sqrt_cumalpha 205 | scaled_noisy_data = noisy_data * scale_x0_at_t 206 | 207 | noise = torch.randn_like(x_t) 208 | nonzero_mask = ( 209 | (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))) 210 | ) # no noise when t == 0 211 | 212 | # Estimation of x_0 with diffusion model 213 | res_p_mean_variance = diffusion.p_mean_variance( 214 | model, 215 | x_t, 216 | t, 217 | clip_denoised=clip_denoised, 218 | model_kwargs=model_kwargs, 219 | ) 220 | 221 | est_x_0 =
res_p_mean_variance["pred_xstart"] * scale_x0_at_t 222 | 223 | # For sigma_t > noise_map 224 | sigma_for_larger_sigmat = torch.sqrt( 225 | mask_sigmat_is_larger 226 | * ( 227 | sqrt_1_m_cumalpha[None, None, None, None] ** 2 228 | - (etaB ** 2) * ((sqrt_cumalpha ** 2) * (noise_map ** 2)) 229 | ) 230 | ) 231 | data_for_larger_sigmat = ( 232 | (1.0 - etaB) * est_x_0 233 | + etaB * scaled_noisy_data 234 | + nonzero_mask * sigma_for_larger_sigmat * torch.randn_like(x_t) 235 | ) 236 | 237 | # For sigma_t < noise_map 238 | eps = diffusion._predict_eps_from_xstart( 239 | x_t, t, res_p_mean_variance["pred_xstart"] 240 | ) # eps is not scaled (original scale) 241 | data_for_smaller_sigmat_A = eps * sqrt_1_m_cumalpha 242 | 243 | coef = (sqrt_1_m_cumalpha / sqrt_cumalpha) / (noise_map + 1e-5) 244 | data_for_smaller_sigmat_C = coef * (scaled_noisy_data - est_x_0) 245 | 246 | masked_v_a = ( 247 | data_for_smaller_sigmat_A * (1.0 - mask_sigmat_is_larger) + 1e-5 248 | ) 249 | masked_v_c = ( 250 | data_for_smaller_sigmat_C * (1.0 - mask_sigmat_is_larger) + 1e-5 251 | ) 252 | cos_sim = torch.nn.CosineSimilarity(dim=0)( 253 | torch.flatten(masked_v_a), torch.flatten(masked_v_c) 254 | ) 255 | 256 | # assert ((etaA**2 + etaC**2 + 2*etaA*etaC*cos_sim) <= 1.0 + 1e-10), "etaA^2 + etaC^2 + 2*etaA*etaC*cos_sim <= 1.0" 257 | if (etaA ** 2 + etaC ** 2 + 2 * etaA * etaC * cos_sim) > 1.0 + 1e-10: 258 | x_t = None 259 | break 260 | sigma_for_smaller_sigmat = torch.sqrt( 261 | (1.0 - mask_sigmat_is_larger) 262 | * (sqrt_1_m_cumalpha ** 2) 263 | * (1.0 - etaA ** 2 - etaC ** 2 - 2 * etaA * etaC * cos_sim) 264 | ) 265 | 266 | data_for_smaller_sigmat = ( 267 | est_x_0 268 | + etaA * data_for_smaller_sigmat_A 269 | + etaC * data_for_smaller_sigmat_C 270 | + nonzero_mask * sigma_for_smaller_sigmat * torch.randn_like(x_t) 271 | ) 272 | 273 | sigma_for_inp_mask = torch.sqrt( 274 | (1.0 - mask_sigmat_is_larger) 275 | * (sqrt_1_m_cumalpha ** 2) 276 | * (1.0 - etaD) 277 | ) 278 | data_for_inp_mask = ( 279 | 
est_x_0 280 | + etaD * data_for_smaller_sigmat_A 281 | + nonzero_mask * sigma_for_inp_mask * torch.randn_like(x_t) 282 | ) 283 | 284 | data_for_smaller_sigmat = ( 285 | 1.0 - inp_mask 286 | ) * data_for_smaller_sigmat + inp_mask * data_for_inp_mask 287 | 288 | x_t = ( 289 | data_for_smaller_sigmat * (1.0 - mask_sigmat_is_larger) 290 | + data_for_larger_sigmat * mask_sigmat_is_larger 291 | ) 292 | 293 | return x_t 294 | 295 | return informed_denoiser_v2 296 | -------------------------------------------------------------------------------- /scripts/run_refiner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import sys 5 | from pathlib import Path 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 8 | 9 | import yaml 10 | import numpy as np 11 | import torch as th 12 | 13 | from guided_diffusion import dist_util, logger 14 | from guided_diffusion.script_util import ( 15 | NUM_CLASSES, 16 | model_and_diffusion_defaults, 17 | add_dict_to_argparser, 18 | args_to_dict, 19 | ) 20 | 21 | from guided_diffusion.diffiner_util import create_model_and_diffusion 22 | 23 | import torchaudio.transforms 24 | from speech_dataloader.data import load_datasets, AlignedDataset 25 | from informed_denoiser import get_informed_denoiser, get_improved_informed_denoiser 26 | 27 | 28 | def prepare_detaset(args, fft_settings): 29 | dataset_noisy_kwargs = { 30 | "root": Path(args.root_noisy), 31 | "seq_duration": None, 32 | "fft_settings": fft_settings, 33 | } 34 | 35 | dataset_proc_kwargs = { 36 | "root": Path(args.root_proc), 37 | "seq_duration": None, 38 | "fft_settings": fft_settings, 39 | } 40 | 41 | dataset_noisy = AlignedDataset( 42 | random_chunks=False, samples_per_track=1, **dataset_noisy_kwargs 43 | ) 44 | 45 | dataset_proc = AlignedDataset( 46 | random_chunks=False, samples_per_track=1, **dataset_proc_kwargs 47 | ) 48 | 49 | sampler_noisy = th.utils.data.DataLoader( 50 | dataset_noisy, 
batch_size=args.batch_size, shuffle=False, num_workers=0 51 | ) 52 | 53 | sampler_proc = th.utils.data.DataLoader( 54 | dataset_proc, batch_size=args.batch_size, shuffle=False, num_workers=0 55 | ) 56 | 57 | return sampler_noisy, sampler_proc 58 | 59 | 60 | def genwav_from_compspec(x, fft_settings): 61 | 62 | """ 63 | args: 64 | x : (b, 2, nb, nf) complex spectrogram stored as two real channels 65 | (channel 0: real part, channel 1: imaginary part) 66 | fft_settings: dict with the STFT parameters "fftsize", "shiftsize" and "wsize" 67 | Returns : 68 | waveform : (b, timedomain_samples) 69 | """ 70 | 71 | x_comp = x[:, 0, :, :] + 1j * x[:, 1, :, :] # (b, nb, nf); no squeeze(), so the batch axis survives when b == 1 72 | 73 | waveform = th.istft( 74 | x_comp, 75 | n_fft=fft_settings["fftsize"], 76 | hop_length=fft_settings["shiftsize"], 77 | win_length=fft_settings["wsize"], 78 | window=th.hann_window(fft_settings["wsize"]), 79 | ) 80 | 81 | return waveform 82 | 83 | 84 | def main(): 85 | 86 | parser = argparse.ArgumentParser() 87 | 88 | # Diffiner / Diffiner+ (default: Diffiner+) 89 | parser.add_argument("--simple-diffiner", action="store_true") 90 | 91 | # Data & Model 92 | parser.add_argument( 93 | "--root-noisy", 94 | type=str, 95 | help="Path to a directory storing the target noisy (=unprocessed) speeches.", 96 | ) 97 | parser.add_argument( 98 | "--root-proc", 99 | type=str, 100 | help="Path to a directory storing the target pre-processed speeches (aiming to refine).", 101 | ) 102 | parser.add_argument( 103 | "--max-dur", 104 | type=float, 105 | default=10.0, 106 | help="Expected maximum duration of input speech.
Longer inputs are truncated automatically.", 107 | ) 108 | 109 | # Model 110 | parser.add_argument( 111 | "--model-path", 112 | type=str, 113 | help="Path to the pretrained model used to run Diffiner/Diffiner+.", 114 | ) 115 | parser.add_argument("--image-size", type=int, default=256) 116 | parser.add_argument("--num-channels", type=int, default=128) 117 | parser.add_argument("--num-res-blocks", type=int, default=3) 118 | 119 | # Parameters 120 | parser.add_argument( 121 | "--etas", nargs="*", type=float, help="list of etas: (eta_a, eta_b) for Diffiner, or (eta_a, eta_b, eta_c, eta_d) for Diffiner+" 122 | ) 123 | parser.add_argument("--diffusion-steps", type=int, default=4000) 124 | parser.add_argument("--clip-denoised", action="store_true") 125 | parser.add_argument( 126 | "--noise-scheduler", 127 | type=str, 128 | default="linear", 129 | help="noise scheduler that was used to train the model", 130 | ) 131 | parser.add_argument( 132 | "--timestep-respacing", 133 | type=str, 134 | default="ddim200", 135 | help="Which of the training time steps to run during sampling (e.g. ddim200).", 136 | ) 137 | 138 | # Inference (a larger batch size finishes sooner but needs more memory) 139 | parser.add_argument("--batch-size", type=int, default=8) 140 | 141 | # Misc 142 | parser.add_argument("--no-gpu", action="store_true") 143 | parser.add_argument("--use-fp16", action="store_true") 144 | 145 | add_dict_to_argparser(parser, model_and_diffusion_defaults()) 146 | args = parser.parse_args() 147 | 148 | if args.no_gpu: 149 | device = "cpu" 150 | else: 151 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 152 | device = "cuda:0" 153 | 154 | # Set the etas; if none are given, the values used in our paper are applied.
155 | if args.simple_diffiner and args.etas is None: 156 | args.eta_a = 0.9 157 | args.eta_b = 0.9 158 | args.eta_c = None 159 | args.eta_d = None 160 | elif not (args.simple_diffiner) and args.etas is None: 161 | args.eta_a = 0.4 162 | args.eta_b = 0.6 163 | args.eta_c = 0.0 164 | args.eta_d = 0.6 165 | elif args.simple_diffiner and args.etas is not None: 166 | args.eta_a = args.etas[0] 167 | args.eta_b = args.etas[1] 168 | args.eta_c = None 169 | args.eta_d = None 170 | elif not (args.simple_diffiner) and args.etas is not None: 171 | args.eta_a = args.etas[0] 172 | args.eta_b = args.etas[1] 173 | args.eta_c = args.etas[2] 174 | args.eta_d = args.etas[3] 175 | 176 | # FFT settings 177 | fft_settings = {} 178 | fft_settings["shiftsize"] = 256 179 | fft_settings["wsize"] = 512 180 | fft_settings["fftsize"] = 512 181 | sr = 16000.0 182 | crit_frames = args.image_size # depending on network architecture and training 183 | if args.max_dur < (crit_frames * (fft_settings["shiftsize"] / sr)): 184 | fft_settings["nf"] = crit_frames # width 185 | else: 186 | block_num = args.max_dur / (crit_frames * (fft_settings["shiftsize"] / sr)) 187 | block_num = int(np.ceil(block_num)) 188 | fft_settings["nf"] = block_num * crit_frames 189 | fft_settings["spec_type"] = "complex" 190 | 191 | logger.configure() 192 | 193 | # Model preparation 194 | logger.log("creating model and diffusion...") 195 | dict_for_create_model = args_to_dict(args, model_and_diffusion_defaults().keys()) 196 | dict_for_create_model["image_channels"] = ( 197 | 2 if fft_settings["spec_type"] == "complex" else 1 198 | ) 199 | model, diffusion = create_model_and_diffusion(**dict_for_create_model) 200 | model.load_state_dict( 201 | dist_util.load_state_dict(args.model_path, map_location="cpu") 202 | ) 203 | model.to(device) 204 | if args.use_fp16: 205 | model.convert_to_fp16() 206 | model.eval() 207 | model_kwargs = {} 208 | 209 | # How to denoise: Diffiner / Diffiner+ 210 | if args.simple_diffiner: 211 | 
informed_denoiser = get_informed_denoiser(diffusion) 212 | else: 213 | informed_denoiser = get_improved_informed_denoiser(diffusion) 214 | 215 | # data preparation 216 | sampler_noisy, sampler_proc = prepare_detaset(args, fft_settings) 217 | 218 | # Make output dir under the "root_proc" 219 | if args.simple_diffiner: 220 | dir_name = "diffiner_etaA={}_etaB={}".format(args.eta_a, args.eta_b) 221 | else: 222 | dir_name = "diffiner+_etaA={}_etaB={}_etaC={}_etaD={}".format( 223 | args.eta_a, args.eta_b, args.eta_c, args.eta_d 224 | ) 225 | dir_audio = os.path.join(args.root_proc, dir_name) 226 | dir_audio = os.path.abspath(dir_audio) 227 | os.makedirs(dir_audio, exist_ok=True) 228 | 229 | # Save settings 230 | with open(os.path.join(dir_audio, "config.yml"), "w") as yf: 231 | yaml.dump( 232 | { 233 | "Model": { 234 | "path": os.path.abspath(args.model_path), 235 | "image_size": args.image_size, 236 | "num_channels": args.num_channels, 237 | "num_res_blocks": args.num_res_blocks, 238 | }, 239 | "Used data": { 240 | "noisy": args.root_noisy, 241 | "pre-processed": args.root_proc, 242 | }, 243 | "fft_settings": fft_settings, 244 | "DDRM_settings": { 245 | "type": "diffiner" if args.simple_diffiner else "diffiner_plus", 246 | "eta_A": args.eta_a, 247 | "eta_B": args.eta_b, 248 | "eta_C": args.eta_c, 249 | "eta_D": args.eta_d, 250 | "timestep_respacing": args.timestep_respacing, 251 | "diffusion_steps": args.diffusion_steps, 252 | "noise_schedule": args.noise_schedule, 253 | }, 254 | }, 255 | yf, 256 | default_flow_style=False, 257 | ) 258 | 259 | # Run diffiner / diffiner+ 260 | n_wav_saved = 0 261 | for (x_noisy, info_noisy), (x_proc, info_proc) in zip(sampler_noisy, sampler_proc): 262 | 263 | x_noisy = x_noisy.to(device) 264 | x_proc = x_proc.to(device) 265 | assert x_noisy.shape == x_proc.shape 266 | 267 | n_batch, _, _, nf = x_proc.shape 268 | 269 | noise_stft = x_noisy - x_proc 270 | noise_map = ( 271 | noise_stft.pow(2).sum(1, keepdim=True).pow(1.0 / 2.0).repeat((1, 
2, 1, 1)) 272 | ) 273 | 274 | if not (args.simple_diffiner) and args.eta_a ** 2 + args.eta_c ** 2 > 1.0: 275 | print("args.eta_a={} and eta_c={}:".format(args.eta_a, args.eta_c)) 276 | print("args.eta_a^2 + args.eta_c^2 must not exceed 1.0, so skip this batch") 277 | continue 278 | 279 | if args.simple_diffiner: 280 | x_deno = informed_denoiser( 281 | model, 282 | x_noisy, 283 | noise_map, 284 | clip_denoised=args.clip_denoised, 285 | model_kwargs=model_kwargs, 286 | etaA_ddrm=args.eta_a, 287 | etaB_ddrm=args.eta_b, 288 | ) 289 | else: 290 | x_deno = informed_denoiser( 291 | model, 292 | x_noisy, 293 | noise_map, 294 | clip_denoised=args.clip_denoised, 295 | model_kwargs=model_kwargs, 296 | etaA=args.eta_a, 297 | etaB=args.eta_b, 298 | etaC=args.eta_c, 299 | inp_mask=None, 300 | etaD=args.eta_d, 301 | ) 302 | cat_x_deno = th.cat((th.zeros(n_batch, 2, 1, nf), x_deno.to("cpu")), dim=2) 303 | waveform_deno = genwav_from_compspec(cat_x_deno, fft_settings)[ 304 | :, None, : 305 | ] # -> n_batch x 1 x n_sample 306 | for i_batch in range(n_batch): 307 | fname = info_proc["file_path"][i_batch].split("/")[-1] 308 | n_sample = min(info_proc["samples"][i_batch].item(), waveform_deno.shape[2]) 309 | torchaudio.save( 310 | os.path.join(dir_audio, fname), 311 | waveform_deno[i_batch, :, :n_sample], 312 | sample_rate=16000, 313 | encoding="PCM_S", 314 | bits_per_sample=16, 315 | ) 316 | n_wav_saved += n_batch 317 | 318 | 319 | if __name__ == "__main__": 320 | main() 321 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train Diffiner 3 | """ 4 | 5 | import argparse 6 | 7 | import os 8 | import sys 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 11 | 12 | from guided_diffusion import dist_util, logger 13 | from guided_diffusion.image_datasets import load_data 14 | from 
guided_diffusion.resample import
create_named_schedule_sampler 15 | from guided_diffusion.script_util import ( 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | args_to_dict, 19 | add_dict_to_argparser, 20 | ) 21 | from guided_diffusion.train_util import TrainLoop 22 | 23 | from guided_diffusion.diffiner_util import create_model_and_diffusion 24 | 25 | from datetime import datetime 26 | from pytz import timezone 27 | 28 | from speech_dataloader import data, utils 29 | from speech_dataloader.data import load_datasets 30 | from torchviz import make_dot 31 | from torchinfo import summary 32 | 33 | import torch 34 | 35 | 36 | def main(): 37 | 38 | dt_now = datetime.now(timezone("Asia/Tokyo")) 39 | dt_str = dt_now.strftime("%Y_%m%d_%H%M") 40 | log_dir = "./exp_logs_diffiner_train_" + dt_str 41 | if not os.path.exists(log_dir): 42 | os.mkdir(log_dir) 43 | 44 | os.environ["OPENAI_LOGDIR"] = log_dir 45 | 46 | parser = create_argparser() 47 | parser.add_argument( 48 | "--complex-conv", 49 | action="store_true", 50 | default=False, 51 | help="Whether to use ComplexConv2D. Default is False.", 52 | ) 53 | args = parser.parse_args() 54 | 55 | dist_util.setup_dist() 56 | logger.configure() 57 | 58 | logger.log("creating model and diffusion...") 59 | dict_for_create_model = args_to_dict(args, model_and_diffusion_defaults().keys()) 60 | if args.spec_type == "amplitude": 61 | raise NotImplementedError("Currently, only complex spectrogram is supported.")
62 | dict_for_create_model["image_channels"] = 1 63 | elif args.spec_type == "complex": 64 | dict_for_create_model["image_channels"] = 2 65 | model, diffusion = create_model_and_diffusion(**dict_for_create_model) 66 | 67 | # Show the model summary 68 | summary( 69 | model, 70 | input_size=[(1, 2, 256, 256), (1,)], 71 | device="cpu", 72 | ) 73 | 74 | model.to(dist_util.dev()) 75 | 76 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 77 | 78 | logger.log("creating data loader...") 79 | 80 | # FFT settings 81 | fft_settings = {} 82 | fft_settings["shiftsize"] = 256 83 | fft_settings["wsize"] = 512 84 | fft_settings["fftsize"] = 512 # height = ['fftsize']/2.0 85 | fft_settings["nf"] = 256 # width 86 | fft_settings["spec_type"] = args.spec_type 87 | 88 | train_dataset, args = load_datasets(parser, args, fft_settings) 89 | train_dataset.seq_duration = args.seq_dur 90 | train_dataset.random_chunks = True 91 | 92 | train_sampler = torch.utils.data.DataLoader( 93 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0 94 | ) 95 | 96 | data = create_generator_from_dataloader(train_sampler) 97 | 98 | logger.log("training...") 99 | TrainLoop( 100 | model=model, 101 | diffusion=diffusion, 102 | data=data, 103 | batch_size=args.batch_size, 104 | microbatch=args.microbatch, 105 | lr=args.lr, 106 | ema_rate=args.ema_rate, 107 | log_interval=args.log_interval, 108 | save_interval=args.save_interval, 109 | resume_checkpoint=args.resume_checkpoint, 110 | use_fp16=args.use_fp16, 111 | fp16_scale_growth=args.fp16_scale_growth, 112 | schedule_sampler=schedule_sampler, 113 | weight_decay=args.weight_decay, 114 | lr_anneal_steps=args.lr_anneal_steps, 115 | ).run_loop() 116 | 117 | 118 | def create_generator_from_dataloader(dataloader): 119 | 120 | while True: 121 | yield from dataloader 122 | 123 | 124 | def create_argparser(): 125 | defaults = dict( 126 | root="", 127 | schedule_sampler="uniform", 128 | lr=1e-4, 129 | weight_decay=0.0, 
130 | lr_anneal_steps=0, 131 | batch_size=1, 132 | microbatch=-1, # -1 disables microbatches 133 | ema_rate="0.9999", # comma-separated list of EMA values 134 | log_interval=10, 135 | save_interval=10000, 136 | resume_checkpoint="", 137 | use_fp16=False, 138 | fp16_scale_growth=1e-3, 139 | target="vocals", 140 | seq_dur=4.2, # Duration of <=0.0 will result in the full audio 141 | samples_per_track=1, # Number of trimmed patches per track. 142 | spec_type="complex", 143 | ) 144 | defaults.update(model_and_diffusion_defaults()) 145 | parser = argparse.ArgumentParser() 146 | add_dict_to_argparser(parser, defaults) 147 | return parser 148 | 149 | 150 | if __name__ == "__main__": 151 | main() 152 | -------------------------------------------------------------------------------- /speech_dataloader/data.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 4 | 5 | from speech_dataloader.utils import load_audio, load_info 6 | from pathlib import Path 7 | import numpy as np 8 | import torch.utils.data 9 | import argparse 10 | import random 11 | import torch 12 | import torch.nn as nn 13 | import tqdm 14 | import glob 15 | import matplotlib 16 | 17 | matplotlib.use("Agg") 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | class Compose(object): 22 | """Composes several augmentation transforms. 23 | Args: 24 | transforms: list of transforms to compose.
25 | """ 26 | 27 | def __init__(self, transforms): 28 | self.transforms = transforms 29 | 30 | def __call__(self, audio): 31 | for t in self.transforms: 32 | audio = t(audio) 33 | return audio 34 | 35 | 36 | def _augment_gain(audio, low=0.25, high=1.25): 37 | """Applies a random gain between `low` and `high`""" 38 | g = low + torch.rand(1) * (high - low) 39 | return audio * g 40 | 41 | 42 | def _augment_channelswap(audio): 43 | """Swap channels of stereo signals with a probability of p=0.5""" 44 | if audio.shape[0] == 2 and torch.FloatTensor(1).uniform_() < 0.5: 45 | return torch.flip(audio, [0]) 46 | else: 47 | return audio 48 | 49 | 50 | def load_datasets(parser, args, fft_settings): 51 | """Loads the specified dataset from command-line arguments 52 | 53 | Returns: 54 | train_dataset, args (re-parsed, including --input-file/--output-file) 55 | """ 56 | parser.add_argument("--input-file", type=str) 57 | parser.add_argument("--output-file", type=str) 58 | 59 | args = parser.parse_args() 60 | # args now also holds --input-file and --output-file 61 | 62 | dataset_kwargs = { 63 | "root": Path(args.root), 64 | "seq_duration": args.seq_dur, 65 | "input_file": args.input_file, 66 | "output_file": args.output_file, 67 | "fft_settings": fft_settings, 68 | } 69 | 70 | train_dataset = AlignedDataset( 71 | random_chunks=True, samples_per_track=args.samples_per_track, **dataset_kwargs 72 | ) 73 | 74 | # valid_dataset = AlignedDataset( 75 | # **dataset_kwargs 76 | # ) 77 | 78 | # return train_dataset, valid_dataset, args 79 | return train_dataset, args 80 | 81 | 82 | class AlignedDataset(torch.utils.data.Dataset): 83 | def __init__( 84 | self, 85 | root, 86 | fft_settings, 87 | samples_per_track, 88 | input_file=None, 89 | output_file=None, 90 | seq_duration=None, 91 | random_chunks=False, 92 | get_spec=True, 93 | sample_rate=16000, # 44100 94 | ): 95 | """A dataset that assumes multiple track folders 96 | where each track includes an input and an output file 97 | 
which directly corresponds to the input and the
98 | output of the model. This dataset is the most basic of 99 | all datasets provided here. Because it does the least 100 | preprocessing, it is also the fastest option; however, 101 | it lacks any kind of source augmentation or custom mixing. 102 | 103 | Typical use cases: 104 | 105 | * Source Separation (Mixture -> Target) 106 | * Denoising (Noisy -> Clean) 107 | * Bandwidth Extension (Low Bandwidth -> High Bandwidth) 108 | 109 | Example 110 | ======= 111 | data/train/01/mixture.wav --> input 112 | data/train/01/vocals.wav ---> output 113 | 114 | """ 115 | self.root = Path(root).expanduser() 116 | self.samples_per_track = samples_per_track 117 | 118 | self._nf = fft_settings["nf"] 119 | # self._shift_nf = fft_settings['shift_nf'] 120 | self._wsize = fft_settings["wsize"] 121 | self._fftsize = fft_settings["fftsize"] 122 | self._shiftsize = fft_settings["shiftsize"] 123 | 124 | self.sample_rate = sample_rate 125 | self.seq_duration = seq_duration 126 | self.random_chunks = random_chunks 127 | self.get_spec = get_spec 128 | self.spec_type = fft_settings["spec_type"] 129 | if self.get_spec: 130 | self.window = nn.Parameter( 131 | torch.hann_window(self._wsize), requires_grad=False 132 | ) 133 | 134 | # set the input and output files (accept glob) 135 | self.input_file = input_file 136 | self.output_file = output_file 137 | self.tuple_paths = sorted(list(self._get_paths())) 138 | if not self.tuple_paths: 139 | raise RuntimeError("Dataset is empty, please check parameters") 140 | 141 | def __getitem__(self, index): 142 | input_path, _ = self.tuple_paths[index // self.samples_per_track] 143 | 144 | input_info = load_info(input_path) 145 | input_info["file_path"] = input_path 146 | if self.random_chunks: 147 | duration = input_info["duration"] 148 | # output_info = load_info(output_path) 149 | # duration = min(input_info['duration'], output_info['duration']) 150 | start = random.uniform(0, duration - self.seq_duration) 151 | else: 152 | start = 0 
153 | duration =
self.seq_duration 154 | 155 | X_audio = load_audio(input_path, start=start, dur=self.seq_duration) 156 | 157 | # Convert to mono by keeping only the first channel. 158 | if X_audio.shape[0] > 1: 159 | X_audio = X_audio[0, :].unsqueeze(0) 160 | 161 | # Apply STFT 162 | if self.get_spec: 163 | X_stft = torch.stft( 164 | X_audio, 165 | n_fft=self._fftsize, 166 | hop_length=self._shiftsize, 167 | window=self.window, 168 | center=True, 169 | normalized=False, 170 | onesided=True, 171 | pad_mode="reflect", 172 | return_complex=False, 173 | ) 174 | 175 | if self.spec_type == "amplitude": 176 | X_spec = X_stft.pow(2).sum(-1).pow(1.0 / 2.0) 177 | elif self.spec_type == "complex": 178 | _, _nb, _nf, _ = X_stft.shape 179 | X_spec = torch.permute(torch.squeeze(X_stft), (2, 0, 1)) 180 | elif self.spec_type == "power": 181 | X_spec = X_stft.pow(2).sum(-1) 182 | else: 183 | raise NotImplementedError(self.spec_type) 184 | 185 | # Confirm by visualizing spectrogram 186 | # tmp = X_spec[0, ...].data 187 | # plt.imshow(10.*np.log( tmp + 1.0 ), aspect="auto", cmap="jet") 188 | # plt.savefig('conf.png') 189 | 190 | # Drop the DC bin (n_fft // 2 + 1 -> n_fft // 2 bins) 191 | X_spec = X_spec[:, 1::] 192 | # Crop or zero-pad the time axis to self._nf frames 193 | if X_spec.shape[2] >= self._nf: 194 | X_spec = X_spec[:, :, : self._nf] 195 | else: 196 | _b, _nb, _nf = X_spec.shape 197 | X_spec_ = torch.zeros( 198 | (_b, _nb, self._nf), dtype=X_spec.dtype, device=X_spec.device 199 | ) 200 | X_spec_[:, :, :_nf] = X_spec 201 | X_spec = X_spec_ 202 | 203 | # print(X_audio.shape, X_spec.shape) # torch.Size([1, 80000]) torch.Size([1, 256, 256]) 204 | # return X_spec, {"y": duration} # Y_audio 205 | return X_spec, input_info # Y_audio 206 | else: 207 | return X_audio, input_info # Y_audio 208 | 209 | def __len__(self): 210 | return len(self.tuple_paths) * self.samples_per_track 211 | 212 | def _get_paths(self): 213 | """Loads input tracks""" 214 | p = Path(self.root) # , 
self.split) 215 | 216 | self.s_list = glob.glob(str(p) + "/**/*.wav", recursive=True) 217 | #
print(len(self.s_list)) # 4000 218 | for input_path in tqdm.tqdm(self.s_list): 219 | # output_path = list(track_path.glob(self.output_file)) 220 | if self.seq_duration is not None: 221 | input_info = load_info(input_path) 222 | # output_info = load_info(output_path[0]) 223 | # min_duration = min( 224 | # input_info['duration'], output_info['duration'] 225 | # ) 226 | min_duration = input_info["duration"] 227 | 228 | # keep only files longer than the requested chunk duration 229 | if min_duration > self.seq_duration: 230 | yield input_path, None # output_path[0] 231 | else: 232 | # raise ValueError("TBD func.") 233 | yield input_path, None # output_path[0] 234 | 235 | 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser(description="Dataloader for DGM SE model") 238 | parser.add_argument("--root", type=str, help="root path of dataset") 239 | 240 | parser.add_argument("--target", type=str, default="vocals") 241 | 242 | # I/O Parameters 243 | parser.add_argument( 244 | "--seq-dur", 245 | type=float, 246 | default=4.2, 247 | help="Duration of <=0.0 will result in the full audio", 248 | ) 249 | 250 | parser.add_argument( 251 | "--samples-per-track", 252 | type=int, 253 | default=1, 254 | help="Number of trimmed patches per track.", 255 | ) 256 | 257 | parser.add_argument("--batch-size", type=int, default=16) 258 | 259 | args, _ = parser.parse_known_args() 260 | 261 | # FFT settings 262 | fft_settings = {} 263 | fft_settings["shiftsize"] = 256 264 | fft_settings["wsize"] = 512 265 | fft_settings["fftsize"] = 512 # height = ['fftsize']/2.0 266 | fft_settings["nf"] = 256 # width 267 | fft_settings["spec_type"] = "complex" 268 | # fft_settings['shift_nf'] = 16 269 | 270 | train_dataset, args = load_datasets(parser, args, fft_settings) 271 | 272 | # Iterate over training dataset 273 | # total_training_duration = 0 274 | # for k in tqdm.tqdm(range(len(train_dataset))): 275 | # x, dur = train_dataset[k] 276 | # total_training_duration += dur 
# /
train_dataset.sample_rate 277 | # print("Total training duration (h): ", total_training_duration / 3600) 278 | # print("Number of train samples: ", len(train_dataset)) 279 | 280 | # iterate over dataloader 281 | train_dataset.seq_duration = args.seq_dur 282 | train_dataset.random_chunks = True # if True, the patch position is chosen at random. 283 | 284 | # Test sampler 285 | train_sampler = torch.utils.data.DataLoader( 286 | train_dataset, 287 | batch_size=args.batch_size, 288 | shuffle=True, 289 | num_workers=0, 290 | ) 291 | 292 | cnt = 0 293 | for x, y in tqdm.tqdm(train_sampler): 294 | cnt += 1 295 | print("Iter#{}, sample shape={}".format(cnt, x.shape)) 296 | pass 297 | -------------------------------------------------------------------------------- /speech_dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import torch 3 | import os 4 | import numpy as np 5 | 6 | 7 | def _sndfile_available(): 8 | try: 9 | import soundfile 10 | except ImportError: 11 | return False 12 | 13 | return True 14 | 15 | 16 | def _torchaudio_available(): 17 | try: 18 | import torchaudio 19 | except ImportError: 20 | return False 21 | 22 | return False # NOTE: False even if torchaudio imports, so the soundfile backend is always used 23 | 24 | 25 | def get_loading_backend(): 26 | if _torchaudio_available(): 27 | return torchaudio_loader 28 | 29 | if _sndfile_available(): 30 | return soundfile_loader 31 | 32 | 33 | def get_info_backend(): 34 | if _torchaudio_available(): 35 | return torchaudio_info 36 | 37 | if _sndfile_available(): 38 | return soundfile_info 39 | 40 | 41 | def soundfile_info(path): 42 | import soundfile 43 | 44 | info = {} 45 | sfi = soundfile.info(path) 46 | info["samplerate"] = sfi.samplerate 47 | info["samples"] = round(sfi.duration * sfi.samplerate) 48 | info["duration"] = sfi.duration 49 | return info 50 | 51 | 52 | def soundfile_loader(path, start=0, dur=None): 53 | 
import soundfile 54 | 55 | # get metadata 56 | info = soundfile_info(path) 57 | start = int(start *
info["samplerate"]) 58 | # check if dur is None (or 0) 59 | if dur: 60 | # soundfile's `stop` is given in samples, not seconds 61 | stop = start + int(dur * info["samplerate"]) 62 | else: 63 | # set to None for reading complete file 64 | stop = None 65 | audio, _ = soundfile.read(path, always_2d=True, start=start, stop=stop) 66 | return torch.FloatTensor(audio.T) 67 | 68 | 69 | def torchaudio_info(path): 70 | import torchaudio 71 | 72 | # get length of file in samples 73 | info = {} 74 | si, _ = torchaudio.info(str(path)) 75 | info["samplerate"] = si.rate 76 | info["samples"] = round(si.length / si.channels) 77 | info["duration"] = info["samples"] / si.rate 78 | return info 79 | 80 | 81 | def torchaudio_loader(path, start=0, dur=None): 82 | import torchaudio 83 | 84 | info = torchaudio_info(path) 85 | # loads the full track duration 86 | if dur is None: 87 | sig, rate = torchaudio.load(path) 88 | return sig 89 | # otherwise loads an excerpt of `dur` seconds starting at `start` 90 | else: 91 | num_frames = int(dur * info["samplerate"]) 92 | offset = int(start * info["samplerate"]) 93 | sig, rate = torchaudio.load(path, num_frames=num_frames, offset=offset) 94 | return sig 95 | 96 | 97 | def load_info(path): 98 | loader = get_info_backend() 99 | return loader(path) 100 | 101 | 102 | def load_audio(path, start=0, dur=None): 103 | loader = get_loading_backend() 104 | return loader(path, start=start, dur=dur) 105 | 106 | 107 | def bandwidth_to_max_bin(rate, n_fft, bandwidth): 108 | freqs = np.linspace(0, float(rate) / 2, n_fft // 2 + 1, endpoint=True) 109 | 110 | return np.max(np.where(freqs <= bandwidth)[0]) + 1 111 | 112 | 113 | def save_checkpoint(state, is_best, path, target): 114 | # save full checkpoint including optimizer 115 | torch.save(state, os.path.join(path, target + ".chkpnt")) 116 | if is_best: 117 | # save just the weights 118 | torch.save(state["state_dict"], os.path.join(path, target + ".pth")) 119 | 120 | 121 | class AverageMeter(object): 122 | 
"""Computes and stores the average and current
value""" 123 | 124 | def __init__(self): 125 | self.reset() 126 | 127 | def reset(self): 128 | self.val = 0 129 | self.avg = 0 130 | self.sum = 0 131 | self.count = 0 132 | 133 | def update(self, val, n=1): 134 | self.val = val 135 | self.sum += val * n 136 | self.count += n 137 | self.avg = self.sum / self.count 138 | 139 | 140 | class EarlyStopping(object): 141 | def __init__(self, mode="min", min_delta=0, patience=10): 142 | self.mode = mode 143 | self.min_delta = min_delta 144 | self.patience = patience 145 | self.best = None 146 | self.num_bad_epochs = 0 147 | self.is_better = None 148 | self._init_is_better(mode, min_delta) 149 | 150 | if patience == 0: 151 | self.is_better = lambda a, b: True 152 | 153 | def step(self, metrics): 154 | if self.best is None: 155 | self.best = metrics 156 | return False 157 | 158 | if np.isnan(metrics): 159 | return True 160 | 161 | if self.is_better(metrics, self.best): 162 | self.num_bad_epochs = 0 163 | self.best = metrics 164 | else: 165 | self.num_bad_epochs += 1 166 | 167 | if self.num_bad_epochs >= self.patience: 168 | return True 169 | 170 | return False 171 | 172 | def _init_is_better(self, mode, min_delta): 173 | if mode not in {"min", "max"}: 174 | raise ValueError("mode " + mode + " is unknown!") 175 | if mode == "min": 176 | self.is_better = lambda a, best: a < best - min_delta 177 | if mode == "max": 178 | self.is_better = lambda a, best: a > best + min_delta 179 | -------------------------------------------------------------------------------- /start-container-torch1p11_diffiner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker run --ipc=host --rm --gpus device=0 --shm-size=64g -itd \ 3 | -v /home/$USER:/home/$USER \ 4 | -v /hdd:/hdd \ 5 | -w /home/$USER/project/DDGM/se_ddgm \ 6 | diffiner --------------------------------------------------------------------------------
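
The tensor bookkeeping split between `AlignedDataset.__getitem__` and `run_refiner.py` (dropping the DC bin, cropping or zero-padding the time axis to `nf` frames, building the two-channel noise map from `x_noisy - x_proc`, and prepending a zero DC row before inversion) can be checked in isolation. Below is a minimal sketch, assuming PyTorch 1.11 or later and the `fft_settings` used in `train.py` (`fftsize=512`, `shiftsize=256`, `nf=256`). The helper names `to_model_input`, `noise_map_from` and `restore_dc` are hypothetical and do not exist in the repository. The sketch calls `torch.stft` with `return_complex=True` followed by `torch.view_as_real`, rather than the repository's `return_complex=False`; both should give the same `(1, 257, n_frames, 2)` layout.

```python
import torch

# Hypothetical helpers (not part of the repository) that mirror the tensor
# shaping done in speech_dataloader/data.py and scripts/run_refiner.py.


def to_model_input(x_stft: torch.Tensor, nf: int = 256) -> torch.Tensor:
    """(1, n_bins, n_frames, 2) real-view STFT -> (2, n_bins - 1, nf)."""
    # [real, imag] become the channel axis; squeeze only the batch dim.
    x_spec = x_stft.squeeze(0).permute(2, 0, 1)
    # Drop the DC bin: 257 -> 256 bins for n_fft = 512.
    x_spec = x_spec[:, 1:]
    if x_spec.shape[2] >= nf:
        return x_spec[:, :, :nf]  # crop long inputs
    # Zero-pad short inputs on the right of the time axis.
    padded = x_spec.new_zeros((x_spec.shape[0], x_spec.shape[1], nf))
    padded[:, :, : x_spec.shape[2]] = x_spec
    return padded


def noise_map_from(x_noisy: torch.Tensor, x_proc: torch.Tensor) -> torch.Tensor:
    """Magnitude of (noisy - processed) per T-F bin, copied to both channels."""
    noise_stft = x_noisy - x_proc  # (batch, 2, F, T)
    mag = noise_stft.pow(2).sum(1, keepdim=True).sqrt()  # (batch, 1, F, T)
    return mag.repeat((1, 2, 1, 1))


def restore_dc(x_deno: torch.Tensor) -> torch.Tensor:
    """Prepend an all-zero DC row so the inverse STFT sees n_fft // 2 + 1 bins."""
    n_batch, n_ch, _, nf = x_deno.shape
    zeros = torch.zeros(n_batch, n_ch, 1, nf, dtype=x_deno.dtype, device=x_deno.device)
    return torch.cat((zeros, x_deno), dim=2)


# Demo on 4 s of random "audio" at 16 kHz (dummy data, not a real recording).
wave = torch.randn(1, 4 * 16000)
spec = torch.stft(
    wave,
    n_fft=512,
    hop_length=256,
    window=torch.hann_window(512),
    center=True,
    return_complex=True,
)
x_in = to_model_input(torch.view_as_real(spec))  # (2, 256, 256)
x_noisy = x_in.unsqueeze(0).repeat(2, 1, 1, 1)  # fake batch of 2
x_proc = 0.5 * x_noisy  # stand-in for a pre-processed estimate
noise_map = noise_map_from(x_noisy, x_proc)
x_full = restore_dc(x_proc)  # (2, 2, 257, 256)
print(x_in.shape, noise_map.shape, x_full.shape)
```

With `center=True`, a 64000-sample clip yields 1 + 64000 // 256 = 251 frames, so the last five time frames of `x_in` are zero padding. The sketch squeezes only the leading batch axis, whereas the repository's bare `torch.squeeze` would also drop a singleton frame axis on extremely short clips.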