├── src ├── __init__.py ├── models │ ├── __init__.py │ └── components │ │ ├── __init__.py │ │ ├── learnable_image.py │ │ └── auffusion_converter.py ├── utils │ ├── __init__.py │ ├── rich_utils.py │ ├── re_ranking.py │ ├── pylogger.py │ ├── consistency_check.py │ └── animation_with_text.py ├── colorization │ ├── __init__.py │ ├── views.py │ ├── create_color_video.py │ ├── colorizer.py │ └── samplers.py ├── evaluator │ ├── __init__.py │ ├── clap.py │ ├── clip.py │ └── eval.py ├── guidance │ ├── __init__.py │ ├── deepfloyd.py │ ├── stable_diffusion.py │ └── auffusion.py ├── transformation │ ├── __init__.py │ ├── identity.py │ ├── block_rearrange.py │ ├── random_crop.py │ └── img_to_spec.py ├── main_imprint.py ├── main_sds.py └── main_denoise.py ├── huggingface_login.py ├── assets └── teaser.jpg ├── .project-root ├── configs ├── main_denoise │ ├── debug │ │ └── default.yaml │ ├── hydra │ │ └── default.yaml │ ├── experiment │ │ └── examples │ │ │ ├── dog.yaml │ │ │ ├── tiger.yaml │ │ │ ├── train.yaml │ │ │ ├── bell.yaml │ │ │ ├── corgi.yaml │ │ │ ├── pond.yaml │ │ │ ├── garden-v2.yaml │ │ │ ├── horse.yaml │ │ │ ├── kitten.yaml │ │ │ ├── race.yaml │ │ │ └── garden.yaml │ └── main.yaml ├── main_imprint │ ├── debug │ │ └── default.yaml │ ├── hydra │ │ └── default.yaml │ ├── experiment │ │ └── examples │ │ │ ├── dog.yaml │ │ │ ├── bell.yaml │ │ │ ├── tiger.yaml │ │ │ ├── train.yaml │ │ │ ├── kitten.yaml │ │ │ ├── garden.yaml │ │ │ └── race.yaml │ └── main.yaml └── main_sds │ ├── debug │ └── default.yaml │ ├── hydra │ └── default.yaml │ ├── experiment │ └── examples │ │ ├── dog.yaml │ │ ├── bell.yaml │ │ ├── kitten.yaml │ │ └── garden.yaml │ └── main.yaml ├── LICENSE ├── environment.yml ├── .gitignore └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/colorization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/guidance/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/transformation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /huggingface_login.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import login 2 | login() -------------------------------------------------------------------------------- /assets/teaser.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IFICL/images-that-sound/HEAD/assets/teaser.jpg -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /configs/main_denoise/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | hydra: 10 | job_logging: 11 | root: 12 | level: DEBUG 13 | 14 | extras: 15 | ignore_warnings: false 16 | 17 | trainer: 18 | num_inference_steps: 100 -------------------------------------------------------------------------------- /configs/main_imprint/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | hydra: 10 | job_logging: 11 | root: 12 | level: DEBUG 13 | 14 | extras: 15 | ignore_warnings: false 16 | 17 | trainer: 18 | num_inference_steps: 100 -------------------------------------------------------------------------------- /configs/main_sds/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | hydra: 10 | job_logging: 11 | root: 12 | level: DEBUG 13 | 14 | extras: 15 | ignore_warnings: false 16 | 17 | trainer: 18 | num_iteration: 10 19 | save_step: 1 20 | visualize_step: 1 -------------------------------------------------------------------------------- /configs/main_sds/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${output_dir} 11 | 12 | job_logging: 13 | handlers: 14 | file: 15 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 16 | filename: ${output_dir}/main.log -------------------------------------------------------------------------------- /src/transformation/identity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class NaiveIdentity(nn.Module): 7 | def __init__( 8 | self, 9 | **kwargs 10 | ): 11 | '''We implement an identity transformation to support our code 12 | ''' 13 | super().__init__() 14 | 15 | def forward(self, x, **kwargs): 16 | ''' 17 | Input: x 18 | Output: x 19 | ''' 20 | return x 21 | -------------------------------------------------------------------------------- 
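The NaiveIdentity module above is the default latent_transformation referenced by the experiment configs further down, where it appears as a Hydra _target_ entry. The following is a minimal sketch, not part of the repository, of how such an entry resolves to the class at runtime; it assumes Hydra, OmegaConf, and PyTorch are installed and that the project root (marked by .project-root) is on the Python path.

import torch
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Illustrative fragment mirroring the latent_transformation entries in the configs below.
cfg = OmegaConf.create({"_target_": "src.transformation.identity.NaiveIdentity"})

transform = instantiate(cfg)            # builds NaiveIdentity(**kwargs)
latents = torch.randn(1, 4, 32, 128)    # any (B, C, H, W) tensor
out = transform(latents)                # identity: the input passes through unchanged
assert torch.equal(out, latents)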
/configs/main_denoise/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${output_dir} 11 | 12 | job_logging: 13 | handlers: 14 | file: 15 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 16 | filename: ${output_dir}/main.log -------------------------------------------------------------------------------- /configs/main_imprint/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${output_dir} 11 | 12 | job_logging: 13 | handlers: 14 | file: 15 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 16 | filename: ${output_dir}/main.log -------------------------------------------------------------------------------- /src/colorization/views.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ColorLView(): 5 | def __init__(self): 6 | pass 7 | 8 | def view(self, im): 9 | return im 10 | 11 | def inverse_view(self, noise): 12 | # Get L color by averaging color channels 13 | noise[:3] = 2 * torch.stack([noise[:3].mean(0)] * 3) 14 | 15 | return noise 16 | 17 | 18 | class ColorABView(): 19 | def __init__(self): 20 | pass 21 | 22 | def view(self, im): 23 | return im 24 | 25 | def inverse_view(self, noise): 26 | # Get AB color by taking residual 27 | noise[:3] = 2 * (noise[:3] - torch.stack([noise[:3].mean(0)] * 3)) 28 | 29 | return noise 30 | 31 | -------------------------------------------------------------------------------- /src/transformation/block_rearrange.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from kornia.augmentation import RandomCrop 6 | 7 | 8 | class BlockRearranger(nn.Module): 9 | def __init__( 10 | self, 11 | **kwargs 12 | ): 13 | '''We implement an easy block rearrange operation 14 | ''' 15 | super().__init__() 16 | 17 | def forward(self, latents, inverse=False): 18 | ''' 19 | Input: (1, C, H, W) 20 | Output: (n_view, C, h, w) 21 | ''' 22 | B, C, H, W = latents.shape 23 | if not inverse: # convert square to rectangle 24 | transformed_latents = torch.cat([latents[:, :, :H//2, :], latents[:, :, H//2:, :]], dim=-1) 25 | else: # convert rectangle to square 26 | transformed_latents = torch.cat([latents[:, :, :, :W//2 ], latents[:, :, :, W//2:]], dim=-2) 27 | return transformed_latents 28 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/dog.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/selected/dog' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | 
cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of cute dogs, grayscale' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'dog barking' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_imprint/experiment/examples/dog.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-imprint/examples/dog' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 40 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.1 25 | 26 | # image guidance 27 | image_prompt: 'a painting of cute dogs, grayscale' 28 | image_guidance_scale: 7.5 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'dog barking' 33 | audio_guidance_scale: 7.5 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/tiger.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/tiger' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of tigers, grayscale' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'tiger growling' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/train' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of trains, grayscale' 28 | 
image_guidance_scale: 10 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'train whistling' 33 | audio_guidance_scale: 10 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/bell.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/bell' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of castle towers, grayscale' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'bell ringing' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_imprint/experiment/examples/bell.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-imprint/examples/bell' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 40 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | 23 | enable_clip_rank: false 24 | # enable_rank: True 25 | top_ranks: 0.1 26 | 27 | # image guidance 28 | image_prompt: 'a painting of castle towers, grayscale' 29 | image_guidance_scale: 7.5 30 | image_start_step: 10 31 | 32 | # audio guidance 33 | audio_prompt: 'bell ringing' 34 | audio_guidance_scale: 7.5 35 | audio_start_step: 0 36 | audio_weight: 0.5 37 | 38 | latent_transformation: 39 | _target_: src.transformation.identity.NaiveIdentity 40 | -------------------------------------------------------------------------------- /configs/main_imprint/experiment/examples/tiger.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-imprint/examples/tiger-long' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 40 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.1 25 | 26 | # image guidance 27 | image_prompt: 'a painting of tigers, grayscale' 28 | image_guidance_scale: 7.5 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'tiger growling' 33 | audio_guidance_scale: 7.5 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 
38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_imprint/experiment/examples/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-imprint/examples/train-long' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 40 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | 19 | cutoff_latent: false 20 | crop_image: true 21 | use_colormap: true 22 | 23 | enable_clip_rank: false 24 | # enable_rank: True 25 | top_ranks: 0.1 26 | 27 | # image guidance 28 | image_prompt: 'a painting of trains, grayscale' 29 | image_guidance_scale: 7.5 30 | image_start_step: 10 31 | 32 | # audio guidance 33 | audio_prompt: 'train whistling' 34 | audio_guidance_scale: 7.5 35 | audio_start_step: 0 36 | audio_weight: 0.5 37 | 38 | latent_transformation: 39 | _target_: src.transformation.identity.NaiveIdentity 40 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/corgi.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/corgi' 9 | 10 | seed: 2034 11 | 12 | trainer: 13 | num_samples: 50 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of corgis, grayscale, black background' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'dog barking' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/pond.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/pond' 9 | 10 | seed: 2024 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a pond full of water lilies, grayscale, lithograph style' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'frog croaking' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- 
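Each example file above is headed with @package _global_, so Hydra merges it over the top-level keys of configs/main_denoise/main.yaml when the experiment group is selected. The snippet below is a minimal sketch, not part of the repository, of that composition using Hydra's compose API; it assumes execution from the repository root, and the launch command in the comment is inferred from src/main_denoise.py in the file tree rather than a documented invocation (the README is the authoritative reference for the exact command).

# A run would typically be launched with something along the lines of
#   python src/main_denoise.py experiment=examples/pond
# (entry-point name assumed from the file tree shown above).
from hydra import compose, initialize

with initialize(version_base=None, config_path="configs/main_denoise"):
    cfg = compose(config_name="main", overrides=["experiment=examples/pond"])

# The experiment file overrides top-level and trainer keys of main.yaml:
print(cfg.task_name)                     # 'soundify-denoise/examples/pond'
print(cfg.trainer.audio_prompt)          # 'frog croaking'
print(cfg.trainer.image_guidance_scale)  # 10.0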
/configs/main_denoise/experiment/examples/garden-v2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/selected/garden' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of a blooming garden, grayscale' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'birds singing sweetly' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/horse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/horse' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 1000 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of horse heads, grayscale, black background' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'horse neighing' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/kitten.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/kitten' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of furry kittens, grayscale' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'a kitten meowing for attention' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/race.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 
| # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/examples/race' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of auto racing game, grayscale' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: "a race car passing by and disappearing" 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_imprint/experiment/examples/kitten.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-imprint/examples/kitten-long' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 40 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | 19 | cutoff_latent: false 20 | crop_image: true 21 | use_colormap: true 22 | 23 | enable_clip_rank: false 24 | # enable_rank: True 25 | top_ranks: 0.1 26 | 27 | # image guidance 28 | image_prompt: 'a painting of furry kittens, grayscale' 29 | image_guidance_scale: 7.5 30 | image_start_step: 10 31 | 32 | # audio guidance 33 | audio_prompt: 'a kitten meowing for attention' 34 | audio_guidance_scale: 7.5 35 | audio_start_step: 0 36 | audio_weight: 0.5 37 | 38 | latent_transformation: 39 | _target_: src.transformation.identity.NaiveIdentity 40 | -------------------------------------------------------------------------------- /configs/main_denoise/experiment/examples/garden.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-denoise/selected/garden' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 100 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | cutoff_latent: false 19 | crop_image: true 20 | use_colormap: true 21 | 22 | enable_clip_rank: false 23 | # enable_rank: True 24 | top_ranks: 0.2 25 | 26 | # image guidance 27 | image_prompt: 'a painting of a blooming garden with many birds, grayscale' 28 | image_guidance_scale: 10.0 29 | image_start_step: 10 30 | 31 | # audio guidance 32 | audio_prompt: 'birds singing sweetly' 33 | audio_guidance_scale: 10.0 34 | audio_start_step: 0 35 | audio_weight: 0.5 36 | 37 | latent_transformation: 38 | _target_: src.transformation.identity.NaiveIdentity 39 | -------------------------------------------------------------------------------- /configs/main_imprint/experiment/examples/garden.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified 
parameters 7 | 8 | task_name: 'soundify-imprint/examples/garden' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 40 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | 19 | cutoff_latent: false 20 | crop_image: true 21 | use_colormap: true 22 | 23 | enable_clip_rank: false 24 | # enable_rank: True 25 | top_ranks: 0.1 26 | 27 | # image guidance 28 | image_prompt: 'a painting of a blooming garden with many birds, grayscale' 29 | image_guidance_scale: 7.5 30 | image_start_step: 10 31 | 32 | # audio guidance 33 | audio_prompt: 'birds singing sweetly' 34 | audio_guidance_scale: 7.5 35 | audio_start_step: 0 36 | audio_weight: 0.5 37 | 38 | latent_transformation: 39 | _target_: src.transformation.identity.NaiveIdentity 40 | -------------------------------------------------------------------------------- /configs/main_imprint/experiment/examples/race.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-imprint/examples/racing-long' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_samples: 40 14 | num_inference_steps: 100 15 | img_height: 256 16 | img_width: 1024 17 | 18 | 19 | cutoff_latent: false 20 | crop_image: true 21 | use_colormap: true 22 | 23 | enable_clip_rank: false 24 | # enable_rank: True 25 | top_ranks: 0.1 26 | 27 | # image guidance 28 | image_prompt: 'a painting of auto racing game, grayscale' 29 | image_guidance_scale: 7.5 30 | image_start_step: 10 31 | 32 | # audio guidance 33 | audio_prompt: "a race car passing by and disappearing" 34 | audio_guidance_scale: 7.5 35 | audio_start_step: 0 36 | audio_weight: 0.5 37 | 38 | latent_transformation: 39 | _target_: src.transformation.identity.NaiveIdentity 40 | -------------------------------------------------------------------------------- /src/transformation/random_crop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from kornia.augmentation import RandomCrop 6 | 7 | 8 | class ImageRandomCropper(nn.Module): 9 | def __init__( 10 | self, 11 | size, 12 | n_view=1, 13 | padding=None, 14 | cropping_mode="slice", 15 | p=1.0, 16 | **kwargs 17 | ): 18 | '''We implement an easy random cropping operation 19 | ''' 20 | super().__init__() 21 | 22 | self.transformation = RandomCrop( 23 | size=size, 24 | padding=padding, 25 | p=p, 26 | cropping_mode=cropping_mode 27 | ) 28 | self.n_view = n_view 29 | 30 | def forward(self, x): 31 | ''' 32 | Input: (1, C, H, W) 33 | Output: (n_view, C, h, w) 34 | ''' 35 | if x.shape[0] == 1: 36 | x = x.repeat(self.n_view, 1, 1, 1) 37 | 38 | if x.shape[1] == 1: 39 | x = x.repeat(1, 3, 1, 1) 40 | 41 | x = self.transformation(x) 42 | return x 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ziyang Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to 
whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/main_sds/experiment/examples/dog.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-sds/examples/dog' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_iteration: 40000 14 | batch_size: 8 15 | save_step: 50000 16 | visualize_step: 10000 17 | accumulate_grad_batches: 1 18 | 19 | use_colormap: true 20 | crop_image: true 21 | 22 | # image guidance 23 | image_prompt: 'a painting of cute dogs, grayscale' 24 | image_start_step: 5000 25 | image_guidance_scale: 80 26 | image_weight: 0.4 27 | 28 | # audio guidance 29 | audio_prompt: 'dog barking' 30 | audio_guidance_scale: 10 31 | audio_weight: 1 32 | 33 | image_learner: 34 | _target_: src.models.components.learnable_image.LearnableImageFourier 35 | height: 256 36 | width: 1024 37 | num_channels: 1 38 | 39 | audio_transformation: 40 | _target_: src.transformation.img_to_spec.ImageToSpec 41 | inverse: false 42 | flip: false 43 | rgb2gray: mean 44 | 45 | image_diffusion_guidance: 46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance 47 | repo_id: DeepFloyd/IF-I-M-v1.0 48 | fp16: true 49 | t_consistent: true 50 | t_range: [0.02, 0.98] -------------------------------------------------------------------------------- /configs/main_sds/experiment/examples/bell.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-sds/examples/bell' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_iteration: 40000 14 | batch_size: 8 15 | save_step: 50000 16 | visualize_step: 10000 17 | accumulate_grad_batches: 1 18 | 19 | use_colormap: true 20 | crop_image: true 21 | 22 | # image guidance 23 | image_prompt: 'a painting of a castle towers, grayscale' 24 | image_start_step: 5000 25 | image_guidance_scale: 80 26 | image_weight: 0.4 27 | 28 | # audio guidance 29 | audio_prompt: 'Bell ringing' 30 | audio_guidance_scale: 10 31 | audio_weight: 1 32 | 33 | image_learner: 34 | _target_: src.models.components.learnable_image.LearnableImageFourier 35 | height: 256 36 | width: 1024 37 | num_channels: 1 38 | 39 | audio_transformation: 40 | _target_: src.transformation.img_to_spec.ImageToSpec 41 | inverse: false 42 | flip: false 43 | rgb2gray: mean 44 | 45 | image_diffusion_guidance: 46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance 47 | repo_id: DeepFloyd/IF-I-M-v1.0 48 | fp16: true 49 | t_consistent: true 50 | 
t_range: [0.02, 0.98] -------------------------------------------------------------------------------- /configs/main_sds/experiment/examples/kitten.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-sds/examples/kitten' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_iteration: 40000 14 | batch_size: 8 15 | save_step: 50000 16 | visualize_step: 10000 17 | accumulate_grad_batches: 1 18 | 19 | use_colormap: true 20 | crop_image: true 21 | 22 | # image guidance 23 | image_prompt: 'a painting of furry kittens, grayscale' 24 | image_start_step: 5000 25 | image_guidance_scale: 80 26 | image_weight: 0.4 27 | 28 | # audio guidance 29 | audio_prompt: 'a kitten meowing for attention' 30 | audio_guidance_scale: 10 31 | audio_weight: 1 32 | 33 | image_learner: 34 | _target_: src.models.components.learnable_image.LearnableImageFourier 35 | height: 256 36 | width: 1024 37 | num_channels: 1 38 | 39 | audio_transformation: 40 | _target_: src.transformation.img_to_spec.ImageToSpec 41 | inverse: false 42 | flip: false 43 | rgb2gray: mean 44 | 45 | image_diffusion_guidance: 46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance 47 | repo_id: DeepFloyd/IF-I-M-v1.0 48 | fp16: true 49 | t_consistent: true 50 | t_range: [0.02, 0.98] -------------------------------------------------------------------------------- /configs/main_sds/experiment/examples/garden.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | # this allows you to overwrite only specified parameters 7 | 8 | task_name: 'soundify-sds/examples/garden' 9 | 10 | seed: 1234 11 | 12 | trainer: 13 | num_iteration: 40000 14 | batch_size: 8 15 | save_step: 50000 16 | visualize_step: 10000 17 | accumulate_grad_batches: 1 18 | 19 | use_colormap: true 20 | crop_image: true 21 | 22 | # image guidance 23 | image_prompt: 'a painting of a blooming garden with many birds, grayscale' 24 | image_start_step: 5000 25 | image_guidance_scale: 80 26 | image_weight: 0.4 27 | 28 | # audio guidance 29 | audio_prompt: 'birds singing sweetly' 30 | audio_guidance_scale: 10 31 | audio_weight: 1 32 | 33 | image_learner: 34 | _target_: src.models.components.learnable_image.LearnableImageFourier 35 | height: 256 36 | width: 1024 37 | num_channels: 1 38 | 39 | audio_transformation: 40 | _target_: src.transformation.img_to_spec.ImageToSpec 41 | inverse: false 42 | flip: false 43 | rgb2gray: mean 44 | 45 | image_diffusion_guidance: 46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance 47 | repo_id: DeepFloyd/IF-I-M-v1.0 48 | fp16: true 49 | t_consistent: true 50 | t_range: [0.02, 0.98] -------------------------------------------------------------------------------- /configs/main_imprint/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - hydra: default 4 | # experiment configs allow for version control of specific hyperparameters 5 | # e.g. 
best hyperparameters for given model and datamodule 6 | - experiment: null 7 | - debug: null 8 | 9 | seed: 2024 10 | log_dir: 'logs' 11 | task_name: 'soundify-imprint' 12 | output_dir: ${log_dir}/${task_name} 13 | 14 | trainer: 15 | num_samples: 5 16 | num_inference_steps: 100 17 | img_height: 256 18 | img_width: 1024 19 | 20 | mag_ratio: 0.5 21 | inverse_image: true 22 | crop_image: false 23 | use_colormap: true 24 | 25 | enable_rank: False 26 | enable_clip_rank: False 27 | top_ranks: 0.2 28 | 29 | # image guidance 30 | image_prompt: 'a castle with bell towers, grayscale, lithograph style' 31 | image_neg_prompt: '' 32 | image_guidance_scale: 7.5 33 | image_start_step: 0 34 | 35 | # audio guidance 36 | audio_prompt: 'bell ringing' 37 | audio_neg_prompt: '' 38 | audio_guidance_scale: 7.5 39 | audio_start_step: 0 40 | audio_weight: 0.5 41 | 42 | audio_diffusion_guidance: 43 | _target_: src.guidance.auffusion.AuffusionGuidance 44 | repo_id: auffusion/auffusion-full-no-adapter 45 | fp16: True 46 | t_range: [0.02, 0.98] 47 | 48 | image_diffusion_guidance: 49 | _target_: src.guidance.stable_diffusion.StableDiffusionGuidance 50 | repo_id: runwayml/stable-diffusion-v1-5 51 | fp16: True 52 | t_consistent: True 53 | t_range: [0.02, 0.98] 54 | 55 | 56 | latent_transformation: 57 | _target_: src.transformation.identity.NaiveIdentity 58 | 59 | audio_evaluator: 60 | _target_: src.evaluator.clap.CLAPEvaluator 61 | 62 | visual_evaluator: 63 | _target_: src.evaluator.clip.CLIPEvaluator 64 | 65 | extras: 66 | ignore_warnings: true 67 | print_config: true 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /src/transformation/img_to_spec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from kornia.augmentation import RandomCrop 6 | 7 | 8 | class ImageToSpec(nn.Module): 9 | def __init__( 10 | self, 11 | inverse=False, 12 | flip=False, 13 | rgb2gray='mean', 14 | **kwargs 15 | ): 16 | '''We implement an simple image-to-spectrogram transformation 17 | ''' 18 | super().__init__() 19 | 20 | self.inverse = inverse 21 | self.flip = flip 22 | self.rgb2gray = rgb2gray 23 | 24 | if self.rgb2gray == 'mean': 25 | self.coefficients = torch.ones(3).float() / 3.0 26 | elif self.rgb2gray in ['NTSC', 'ntsc']: 27 | self.coefficients = torch.tensor([0.299, 0.587, 0.114]) 28 | elif self.rgb2gray == 'luminance': 29 | self.coefficients = torch.tensor([0.2126, 0.7152, 0.0722]) 30 | elif self.rgb2gray == 'r_channel': 31 | self.coefficients = torch.tensor([1.0, 0.0, 0.0]) 32 | elif self.rgb2gray == 'g_channel': 33 | self.coefficients = torch.tensor([0.0, 1.0, 0.0]) 34 | elif self.rgb2gray == 'b_channel': 35 | self.coefficients = torch.tensor([0.0, 0.0, 1.0]) 36 | 37 | def forward(self, x): 38 | ''' 39 | Input: (1, C, H, W) 40 | Output: (1, 1, H, W) 41 | ''' 42 | 43 | if x.shape[1] == 1: 44 | x = x.repeat(1, 3, 1, 1) 45 | 46 | coefficients = self.coefficients.view(1, -1, 1, 1).to(dtype=x.dtype, device=x.device) 47 | x = torch.sum(x * coefficients, dim=1, keepdim=True) 48 | 49 | if self.inverse: 50 | x = 1.0 - x 51 | 52 | if self.flip: 53 | x = torch.flip(x, [2]) 54 | 55 | return x 56 | -------------------------------------------------------------------------------- /configs/main_sds/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - hydra: default 4 | # experiment configs allow for version control of 
specific hyperparameters 5 | # e.g. best hyperparameters for given model and datamodule 6 | - experiment: null 7 | - debug: null 8 | 9 | seed: 2024 10 | log_dir: 'logs' 11 | output_dir: ${log_dir}/${task_name} 12 | task_name: 'soundify-sds' 13 | 14 | trainer: 15 | num_iteration: 40000 16 | batch_size: 8 17 | save_step: 50000 18 | visualize_step: 10000 19 | accumulate_grad_batches: 1 20 | 21 | use_colormap: false 22 | crop_image: false 23 | 24 | # image guidance 25 | image_prompt: 'a painting of castle towers, grayscale' 26 | image_start_step: 5000 27 | image_guidance_scale: 80 28 | image_weight: 0.4 29 | 30 | # audio guidance 31 | audio_prompt: 'bell ringing' 32 | audio_guidance_scale: 10 33 | audio_weight: 1 34 | 35 | 36 | image_learner: 37 | _target_: src.models.components.learnable_image.LearnableImageFourier 38 | height: 256 39 | width: 1024 40 | num_channels: 3 41 | 42 | 43 | audio_diffusion_guidance: 44 | _target_: src.guidance.auffusion.AuffusionGuidance 45 | repo_id: auffusion/auffusion-full-no-adapter 46 | fp16: True 47 | t_range: [0.02, 0.98] 48 | 49 | audio_transformation: 50 | _target_: src.transformation.img_to_spec.ImageToSpec 51 | inverse: false 52 | flip: false 53 | rgb2gray: mean 54 | 55 | 56 | image_diffusion_guidance: 57 | _target_: src.guidance.deepfloyd.DeepfloydGuidance 58 | repo_id: DeepFloyd/IF-I-M-v1.0 59 | fp16: true 60 | t_consistent: true 61 | t_range: [0.02, 0.98] 62 | 63 | image_transformation: 64 | _target_: src.transformation.random_crop.ImageRandomCropper 65 | size: [256, 256] 66 | n_view: ${trainer.batch_size} 67 | 68 | 69 | optimizer: 70 | _target_: torch.optim.AdamW 71 | _partial_: true 72 | lr: 0.0001 73 | weight_decay: 0.001 74 | 75 | extras: 76 | ignore_warnings: true 77 | print_config: true 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /configs/main_denoise/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - hydra: default 4 | # experiment configs allow for version control of specific hyperparameters 5 | # e.g. 
best hyperparameters for given model and datamodule 6 | - experiment: null 7 | - debug: null 8 | 9 | seed: 2024 10 | log_dir: 'logs' 11 | task_name: 'soundify-denoise' 12 | output_dir: ${log_dir}/${task_name} 13 | 14 | trainer: 15 | num_samples: 50 16 | num_inference_steps: 100 17 | img_height: 256 18 | img_width: 1024 19 | cutoff_latent: false 20 | crop_image: false 21 | use_colormap: true 22 | 23 | enable_rank: False 24 | enable_clip_rank: False 25 | top_ranks: 0.2 26 | 27 | # image guidance 28 | image_prompt: 'a castle with bell towers, grayscale, lithograph style' 29 | image_neg_prompt: '' 30 | image_guidance_scale: 10.0 31 | image_start_step: 10 32 | 33 | # audio guidance 34 | audio_prompt: 'bell ringing' 35 | audio_neg_prompt: '' 36 | audio_guidance_scale: 10.0 37 | audio_start_step: 0 38 | audio_weight: 0.5 39 | 40 | audio_diffusion_guidance: 41 | _target_: src.guidance.auffusion.AuffusionGuidance 42 | repo_id: auffusion/auffusion-full-no-adapter 43 | fp16: True 44 | t_range: [0.02, 0.98] 45 | 46 | image_diffusion_guidance: 47 | _target_: src.guidance.stable_diffusion.StableDiffusionGuidance 48 | repo_id: runwayml/stable-diffusion-v1-5 49 | fp16: True 50 | t_consistent: True 51 | t_range: [0.02, 0.98] 52 | 53 | diffusion_scheduler: 54 | _target_: diffusers.DDIMScheduler.from_pretrained 55 | pretrained_model_name_or_path: runwayml/stable-diffusion-v1-5 56 | subfolder: "scheduler" 57 | torch_dtype: torch.float16 58 | 59 | latent_transformation: 60 | _target_: src.transformation.identity.NaiveIdentity 61 | 62 | audio_evaluator: 63 | _target_: src.evaluator.clap.CLAPEvaluator 64 | 65 | visual_evaluator: 66 | _target_: src.evaluator.clip.CLIPEvaluator 67 | 68 | extras: 69 | ignore_warnings: true 70 | print_config: true 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /src/evaluator/clap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import soundfile as sf 4 | import librosa 5 | from lightning import seed_everything 6 | from transformers import AutoProcessor, ClapModel 7 | 8 | class CLAPEvaluator(nn.Module): 9 | def __init__( 10 | self, 11 | repo_id="laion/clap-htsat-unfused", 12 | **kwargs 13 | ): 14 | super(CLAPEvaluator, self).__init__() 15 | self.repo_id = repo_id 16 | 17 | # create model and load pretrained weights from huggingface 18 | self.model = ClapModel.from_pretrained(self.repo_id) 19 | self.processor = AutoProcessor.from_pretrained(self.repo_id) 20 | 21 | def forward(self, text, audio, sampling_rate=16000): 22 | return self.calc_score(text, audio, sampling_rate=sampling_rate) 23 | 24 | def calc_score(self, text, audio, sampling_rate=16000): 25 | if sampling_rate != 48000: 26 | audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=48000) 27 | audio = librosa.util.normalize(audio) 28 | inputs = self.processor(text=text, audios=audio, return_tensors="pt", sampling_rate=48000, padding=True).to(self.model.device) 29 | outputs = self.model(**inputs) 30 | logits_per_audio = outputs.logits_per_audio # this is the audio-text similarity score 31 | score = logits_per_audio / self.model.logit_scale_a.exp() 32 | score = score.squeeze().item() 33 | score = max(score, 0.0) 34 | return score 35 | 36 | 37 | if __name__ == "__main__": 38 | # import pdb; pdb.set_trace() 39 | seed_everything(2024, workers=True) 40 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 41 | 42 | text = 'Birds singing sweetly' 43 | 
audio_path = '/home/czyang/Workspace/images-that-sound/data/audios/audio_02.wav' 44 | # import pdb; pdb.set_trace() 45 | audio, sr = sf.read(audio_path) 46 | 47 | clap = CLAPEvaluator().to(device) 48 | score = clap.calc_score(text, audio, sampling_rate=sr) 49 | print(score) 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # reasons you might want to use `environment.yml` instead of `requirements.txt`: 2 | # - pip installs packages in a loop, without ensuring dependencies across all packages 3 | # are fulfilled simultaneously, but conda achieves proper dependency control across 4 | # all packages 5 | # - conda allows for installing packages without requiring certain compilers or 6 | # libraries to be available in the system, since it installs precompiled binaries 7 | 8 | name: soundify 9 | 10 | channels: 11 | - pytorch 12 | - conda-forge 13 | - defaults 14 | - huggingface 15 | - nvidia 16 | 17 | # it is strongly recommended to specify versions of packages installed through conda 18 | # to avoid situation when version-unspecified packages install their latest major 19 | # versions which can sometimes break things 20 | 21 | # current approach below keeps the dependencies in the same major versions across all 22 | # users, but allows for different minor and patch versions of packages where backwards 23 | # compatibility is usually guaranteed 24 | 25 | dependencies: 26 | - python=3.10 27 | - pytorch=2.1.2 28 | - torchvision=0.16 29 | - torchaudio=2.1.2 30 | - pytorch-cuda=12.1 31 | - lightning=2.2.0 32 | - torchmetrics=1.* 33 | - hydra-core=1.3 34 | - rich=13.* 35 | # - pre-commit=3.* 36 | # - pytest=7.* 37 | - diffusers=0.25.* 38 | - transformers=4.36.* 39 | - pysoundfile=0.12.1 40 | - kornia=0.7.0 41 | - tensorboard=2.15.1 42 | - accelerate 43 | - librosa 44 | 45 | - matplotlib # need to include this package in the yml file else the conda is not able to solve the conflict. 
this will significantly slow down the speed, remove it if you don't need 46 | - moviepy 47 | - imagemagick 48 | 49 | 50 | # --------- loggers --------- # 51 | # - wandb 52 | # - neptune-client 53 | # - mlflow 54 | # - comet-ml 55 | # - aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 56 | 57 | - pip>=23 58 | - pip: 59 | - hydra-optuna-sweeper 60 | - hydra-colorlog 61 | - rootutils 62 | - gpustat 63 | - nvitop 64 | - sentencepiece==0.2.0 65 | - bitsandbytes==0.43.1 66 | -------------------------------------------------------------------------------- /src/evaluator/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms.functional as TF 4 | 5 | from PIL import Image 6 | 7 | from torchmetrics.multimodal.clip_score import CLIPScore 8 | from torchmetrics.functional.multimodal import clip_score 9 | from functools import partial 10 | 11 | 12 | class CLIPEvaluator(nn.Module): 13 | def __init__( 14 | self, 15 | repo_id='openai/clip-vit-base-patch16', 16 | **kwargs 17 | ): 18 | super(CLIPEvaluator, self).__init__() 19 | self.repo_id = repo_id 20 | 21 | # create model and load pretrained weights from huggingface 22 | # self.clip = partial(clip_score, model_name_or_path=repo_id) 23 | self.clip = CLIPScore(model_name_or_path=repo_id) 24 | self.clip.reset() 25 | 26 | def forward(self, image, text): 27 | image, text = self.processing(image, text) 28 | image_int = (image * 255).to(torch.uint8) 29 | score = self.clip(image_int, text) 30 | score = score.item() / 100 31 | return score 32 | 33 | def processing(self, image, text): 34 | bsz, C, H, W = image.shape 35 | # import pdb; pdb.set_trace() 36 | if H != W: 37 | image = image.unfold(dimension=3, size=H, step=H//8) # cut image into several HxH images 38 | image = image.permute(0, 3, 1, 2, 4).squeeze(0) 39 | text = [text] * image.shape[0] 40 | return image, text 41 | 42 | if __name__ == "__main__": 43 | # import pdb; pdb.set_trace() 44 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 45 | # Load image 46 | # im_path = "logs/soundify-denoise/colorization/bell_example_29/img.png" 47 | # text = "a castle with bell towers, grayscale, lithograph style" 48 | im_path = "/home/czyang/Workspace/images-that-sound/logs/soundify-denoise/colorization/bell_example_29/img.png" 49 | text = "a castle with bell towers, grayscale, lithograph style" 50 | 51 | im = Image.open(im_path) 52 | im = TF.to_tensor(im).to(device) 53 | im = im.unsqueeze(0) 54 | # import pdb; pdb.set_trace() 55 | clip = CLIPEvaluator().to(device) 56 | score = clip(im, text) 57 | print(score) 58 | 59 | score = clip(im, text) 60 | print(score) 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | import os, sys 4 | 5 | import rich 6 | import rich.syntax 7 | import rich.tree 8 | from hydra.core.hydra_config import HydraConfig 9 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 10 | from omegaconf import DictConfig, OmegaConf, open_dict 11 | from rich.prompt import Prompt 12 | 13 | 14 | from src.utils import pylogger 15 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | 
"extras", 22 | ), 23 | resolve: bool = False, 24 | save_to_file: bool = False, 25 | ) -> None: 26 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 27 | 28 | :param cfg: A DictConfig composed by Hydra. 29 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 30 | "callbacks", "logger", "trainer", "paths", "extras")``. 31 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 32 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 33 | """ 34 | style = "dim" 35 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 36 | 37 | queue = [] 38 | 39 | # add fields from `print_order` to queue 40 | for field in print_order: 41 | queue.append(field) if field in cfg else log.warning( 42 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 43 | ) 44 | 45 | # add all the other fields to queue (not specified in `print_order`) 46 | for field in cfg: 47 | if field not in queue: 48 | queue.append(field) 49 | 50 | # generate config tree from queue 51 | for field in queue: 52 | branch = tree.add(field, style=style, guide_style=style) 53 | 54 | config_group = cfg[field] 55 | if isinstance(config_group, DictConfig): 56 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 57 | else: 58 | branch_content = str(config_group) 59 | 60 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 61 | 62 | # print config tree 63 | rich.print(tree) 64 | -------------------------------------------------------------------------------- /src/utils/re_ranking.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import shutil 5 | 6 | import rootutils 7 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 8 | from src.utils.pylogger import RankedLogger 9 | log = RankedLogger(__name__, rank_zero_only=True) 10 | 11 | 12 | def select_top_k_ranking(cfg, clip_scores, clap_scores): 13 | top_ranks = cfg.trainer.get("top_ranks", 0.2) 14 | clap_scores = np.array(clap_scores) 15 | clip_scores = np.array(clip_scores) 16 | top_ranks = int(clip_scores.shape[0] * top_ranks) 17 | 18 | # Get top-k indices for each array 19 | top_k_clap_indices = np.argsort(clap_scores)[::-1][:top_ranks] 20 | top_k_clip_indices = np.argsort(clip_scores)[::-1][:top_ranks] 21 | 22 | # Find the joint indices of top-k indices 23 | joint_indices = set(top_k_clap_indices).intersection(set(top_k_clip_indices)) 24 | joint_indices = list(joint_indices) 25 | 26 | ori_dir = os.path.join(cfg.output_dir, 'results') 27 | examples = glob.glob(f'{ori_dir}/*') 28 | examples.sort() 29 | selected_dir = os.path.join(cfg.output_dir, 'results_selected') 30 | os.makedirs(selected_dir, exist_ok=True) 31 | 32 | log.info(f"Selected {len(joint_indices)} examples.") 33 | for ind in joint_indices: 34 | example_dir_path = examples[ind] 35 | dir_name = example_dir_path.split('/')[-1] 36 | save_dir_path = os.path.join(selected_dir, dir_name) 37 | shutil.copytree(example_dir_path, save_dir_path) 38 | return 39 | 40 | def select_top_k_clip_ranking(cfg, clip_scores): 41 | # import pdb; pdb.set_trace() 42 | top_ranks = cfg.trainer.get("top_ranks", 0.1) 43 | clip_scores = np.array(clip_scores) 44 | top_ranks = int(clip_scores.shape[0] * top_ranks) 45 | 46 | # Get top-k indices 47 | top_k_clip_indices = np.argsort(clip_scores)[::-1][:top_ranks] 48 | 49 | ori_dir = 
os.path.join(cfg.output_dir, 'results') 50 | examples = glob.glob(f'{ori_dir}/*') 51 | examples.sort() 52 | selected_dir = os.path.join(cfg.output_dir, 'results_selected') 53 | os.makedirs(selected_dir, exist_ok=True) 54 | 55 | log.info(f"Selected {len(top_k_clip_indices)} examples.") 56 | for ind in top_k_clip_indices: 57 | example_dir_path = examples[ind] 58 | dir_name = example_dir_path.split('/')[-1] 59 | save_dir_path = os.path.join(selected_dir, dir_name) 60 | shutil.copytree(example_dir_path, save_dir_path) 61 | return 62 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = False, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 22 | """ 23 | logger = logging.getLogger(name) 24 | logging.getLogger('PIL').setLevel(logging.WARNING) 25 | super().__init__(logger=logger, extra=extra) 26 | self.rank_zero_only = rank_zero_only 27 | 28 | def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: 29 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 30 | of the process it's being logged from. If `'rank'` is provided, then the log will only 31 | occur on that rank/process. 32 | 33 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 34 | :param msg: The message to log. 35 | :param rank: The rank to log at. 36 | :param args: Additional args to pass to the underlying logging function. 37 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 
38 | """ 39 | if self.isEnabledFor(level): 40 | msg, kwargs = self.process(msg, kwargs) 41 | current_rank = getattr(rank_zero_only, "rank", None) 42 | if current_rank is None: 43 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") 44 | msg = rank_prefixed_message(msg, current_rank) 45 | if self.rank_zero_only: 46 | if current_rank == 0: 47 | self.logger.log(level, msg, *args, **kwargs) 48 | else: 49 | if rank is None: 50 | self.logger.log(level, msg, *args, **kwargs) 51 | elif current_rank == rank: 52 | self.logger.log(level, msg, *args, **kwargs) 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | # /data/ 150 | /logs/ 151 | .env 152 | 153 | # Aim logging 154 | .aim 155 | 156 | /data 157 | /data/* 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Images that Sound 2 | 3 | [Ziyang Chen](https://ificl.github.io/), [Daniel Geng](https://dangeng.github.io/), [Andrew Owens](https://andrewowens.com/) 4 | 5 | University of Michigan, Ann Arbor 6 | 7 | arXiv 2024 8 | 9 | [[Paper](https://arxiv.org/abs/2405.12221)] [[Project Page](https://ificl.github.io/images-that-sound/)] 10 |
11 | 12 | This repository contains the code to generate *images that sound*, a special spectrogram that can be seen as images and played as sound. 13 |
14 | ![teaser](assets/teaser.jpg) 15 |
16 | 17 | 18 | ## Environment 19 | To set up the environment, simply run: 20 | ```bash 21 | conda env create -f environment.yml 22 | conda activate soundify 23 | ``` 24 | ***Pro tip***: *we highly recommend using [mamba](https://github.com/conda-forge/miniforge) instead of conda for much faster environment solving and installation.* 25 | 26 | **DeepFloyd**: our repo also uses [DeepFloyd IF](https://huggingface.co/docs/diffusers/api/pipelines/deepfloyd_if). To use DeepFloyd IF, you must accept its usage conditions. To do so: 27 | 28 | 1. Sign up for or log in to your [Hugging Face account](https://huggingface.co/join). 29 | 2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). 30 | 3. Log in locally by running `python huggingface_login.py` and entering your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens) when prompted. It does not matter how you answer the `Add token as git credential? (Y/n)` question. 31 | 32 | 33 | ## Usage 34 | 35 | We use the pretrained image latent diffusion model [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and the pretrained audio latent diffusion model [Auffusion](https://huggingface.co/auffusion/auffusion-full-no-adapter), which is fine-tuned from Stable Diffusion. We provide the code (including visualization) and instructions for our approach (multimodal denoising) and two proposed baselines: Imprint and SDS. Our code is built on [hydra](https://github.com/facebookresearch/hydra), so you can override any configuration parameter from the command line using hydra's syntax. 36 | 37 | ### Multimodal denoising 38 | To create *images that sound* using our multimodal denoising method, run the code with config files under `configs/main_denoise/experiment`: 39 | ```bash 40 | python src/main_denoise.py experiment=examples/bell 41 | ``` 42 | **Note:** our method does not have a high success rate, since it is zero-shot and highly dependent on the initial random noise. We recommend generating many samples (e.g., N=100) and hand-picking the high-quality results (see the batch sampling example at the end of this section). 43 | 44 | 45 | ### Imprint baseline 46 | To create *images that sound* using our proposed imprint baseline method, run the code with config files under `configs/main_imprint/experiment`: 47 | ```bash 48 | python src/main_imprint.py experiment=examples/bell 49 | ``` 50 | 51 | ### SDS baseline 52 | To create *images that sound* using our proposed multimodal SDS baseline method, run the code with config files under `configs/main_sds/experiment`: 53 | ```bash 54 | python src/main_sds.py experiment=examples/bell 55 | ``` 56 | **Note:** we find that audio SDS does not work for many audio prompts. We hypothesize that this is because latent diffusion models do not work as well as pixel-based diffusion models for SDS. 57 | 58 | ### Colorization 59 | We also provide colorization code under `src/colorization`, adapted from [Factorized Diffusion](https://github.com/dangeng/visual_anagrams). To directly generate colorized videos with audio, run: 60 | ```bash 61 | python src/colorization/create_color_video.py \ 62 | --sample_dir /path/to/generated/sample/dir \ 63 | --prompt "a colorful photo of [object]" \ 64 | --num_samples 16 --guidance_scale 10 \ 65 | --num_inference_steps 30 --start_diffusion_step 7 66 | ``` 67 | **Note:** since our generated grayscale images fall outside the colorization model's training distribution, we recommend running more trials (e.g., num_samples=16) and selecting the best colorized results.
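### Batch sampling and evaluation

Since results vary a lot across random seeds, it is convenient to generate a batch of samples and score them afterwards. Below is a minimal sketch: it assumes the entry point you run exposes a `trainer.num_samples` key (this is the key used by `src/main_imprint.py`; check the corresponding `configs/*/main.yaml` for the exact name in the other entry points) and that your runs live under a common log directory.

```bash
# generate a batch of samples by overriding hydra parameters on the command line
python src/main_imprint.py experiment=examples/bell trainer.num_samples=100

# score the generated samples with CLIP (image prompt) and CLAP (audio prompt);
# eval_path should contain run folders, each with a results/ subdirectory
python src/evaluator/eval.py --eval_path /path/to/log/dir
```

`src/evaluator/eval.py` prints the averaged CLIP and CLAP scores with a 95% confidence margin, and `src/utils/re_ranking.py` can copy the top-ranked samples into a `results_selected` folder when re-ranking is enabled in the config.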
68 | 69 | 70 | 71 | ## Acknowledgement 72 | Our code is based on [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), [diffusers](https://github.com/huggingface/diffusers), [stable-dreamfusion](https://github.com/ashawkey/stable-dreamfusion), [Diffusion-Illusions](https://github.com/RyannDaGreat/Diffusion-Illusions), [Auffusion](https://github.com/happylittlecat2333/Auffusion), and [visual-anagrams](https://github.com/dangeng/visual_anagrams). We appreciate their open-source codes. -------------------------------------------------------------------------------- /src/evaluator/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from PIL import Image 3 | from tqdm import tqdm 4 | import os 5 | import glob 6 | import soundfile as sf 7 | from omegaconf import OmegaConf, DictConfig 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torchvision.transforms.functional as TF 13 | 14 | from lightning import seed_everything 15 | 16 | import rootutils 17 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 18 | 19 | from src.evaluator.clap import CLAPEvaluator 20 | from src.evaluator.clip import CLIPEvaluator 21 | from src.utils.consistency_check import griffin_lim 22 | 23 | 24 | def bootstrap_confidence_intervals(data, num_bootstraps=10000): 25 | # Bootstrap resampling 26 | bootstrap_samples = np.random.choice(data, size=(num_bootstraps, data.shape[0]), replace=True) 27 | bootstrap_means = np.mean(bootstrap_samples, axis=1) 28 | 29 | # Compute confidence interval 30 | confidence_interval = np.percentile(bootstrap_means, [2.5, 97.5]) 31 | 32 | # Calculate point estimate (mean) 33 | sample_mean = np.mean(data) 34 | 35 | # Calculate margin of error 36 | margin_of_error = (confidence_interval[1] - confidence_interval[0]) / 2 37 | 38 | return sample_mean, margin_of_error 39 | 40 | 41 | def eval(args): 42 | seed_everything(2024, workers=True) 43 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 44 | 45 | tqdm.write("Preparing samples to be evaluated...") 46 | eval_samples = glob.glob(f'{args.eval_path}/*/results/*') 47 | eval_samples.sort() 48 | 49 | if args.max_sample != -1: 50 | eval_samples = eval_samples[:args.max_sample] 51 | 52 | tqdm.write("Instantiating audio evaluator...") 53 | audio_evaluator = CLAPEvaluator().to(device) 54 | 55 | tqdm.write("Instantiating visual evaluator...") 56 | visual_evaluator = CLIPEvaluator().to(device) 57 | 58 | clip_scores = [] 59 | clap_scores = [] 60 | 61 | for sample in tqdm(eval_samples, desc="Evaluating"): 62 | clip_score, clap_score = evaluate_single_sample(sample, audio_evaluator, visual_evaluator, device, use_griffin_lim=args.use_griffin_lim) 63 | 64 | clip_scores.append(clip_score) 65 | clap_scores.append(clap_score) 66 | 67 | clip_scores = np.array(clip_scores) 68 | clap_scores = np.array(clap_scores) 69 | 70 | # Choose a multiplier (e.g., for 95% confidence interval, multiplier is approximately 1.96) 71 | confidence_multiplier = 1.96 72 | n = clip_scores.shape[0] 73 | 74 | avg_clip_score = clip_scores.mean() 75 | # _, clip_error = bootstrap_confidence_intervals(clip_scores) 76 | clip_error = confidence_multiplier * np.std(clip_scores) / np.sqrt(n) 77 | tqdm.write(f"Averaged CLIP score: {avg_clip_score * 100} | margin of error: {clip_error * 100}") 78 | 79 | avg_clap_score = clap_scores.mean() 80 | # _, clap_error = bootstrap_confidence_intervals(clap_scores) 81 | clap_error = 
confidence_multiplier * np.std(clap_scores) / np.sqrt(n) 82 | tqdm.write(f"Averaged CLAP score: {avg_clap_score * 100} | margin of error: {clap_error * 100}") 83 | 84 | 85 | def evaluate_single_sample(sample_dir, audio_evaluator, visual_evaluator, device, use_griffin_lim=False): 86 | # import pdb; pdb.set_trace() 87 | # read sample dir 88 | gray_im_path = f'{sample_dir}/spec.png' 89 | audio_path = f'{sample_dir}/audio.wav' 90 | config_path = f'{sample_dir}/config.yaml' 91 | cfg = OmegaConf.load(config_path) 92 | image_prompt = cfg.trainer.image_prompt 93 | audio_prompt = cfg.trainer.audio_prompt 94 | 95 | # Load gray image and evaluate 96 | gray_im = Image.open(gray_im_path) 97 | gray_im = TF.to_tensor(gray_im).to(device) 98 | spec = gray_im.detach().cpu() 99 | gray_im = gray_im.mean(dim=0, keepdim=True).repeat(3, 1, 1) 100 | gray_im = gray_im.unsqueeze(0) 101 | 102 | clip_score = visual_evaluator(gray_im, image_prompt) 103 | 104 | # load audio waveform and evaluate 105 | audio, sr = sf.read(audio_path) 106 | 107 | if use_griffin_lim: 108 | # import pdb; pdb.set_trace() 109 | audio = griffin_lim(spec, audio) 110 | 111 | clap_score = audio_evaluator(audio_prompt, audio, sampling_rate=sr) 112 | 113 | return clip_score, clap_score 114 | 115 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/auffusion" 116 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/AV-IF-SDS-V2" 117 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/AV-Denoise-cfg7.5" 118 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/AV-Denoise-notime" 119 | 120 | 121 | if __name__ == '__main__': 122 | # Parse args 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--eval_path", required=True, type=str) 125 | parser.add_argument('--use_griffin_lim', default=False, action='store_true') 126 | parser.add_argument("--max_sample", type=int, default=-1) 127 | 128 | args = parser.parse_args() 129 | 130 | eval(args) -------------------------------------------------------------------------------- /src/utils/consistency_check.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from PIL import Image 5 | import shutil 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.utils import save_image 10 | import torchvision.transforms.functional as TF 11 | import torchaudio 12 | 13 | import soundfile as sf 14 | import librosa 15 | 16 | import rootutils 17 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 18 | 19 | from src.models.components.auffusion_converter import mel_spectrogram, normalize_spectrogram, denormalize_spectrogram 20 | 21 | 22 | 23 | def wav2spec(audio, sr): 24 | audio = torch.FloatTensor(audio) 25 | audio = audio.unsqueeze(0) 26 | spec = mel_spectrogram(audio, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False) 27 | spec = normalize_spectrogram(spec) 28 | return spec 29 | 30 | 31 | def griffin_lim(mel_spec, ori_audio): 32 | mel_spec = denormalize_spectrogram(mel_spec) 33 | mel_spec = torch.exp(mel_spec) 34 | 35 | audio = librosa.feature.inverse.mel_to_audio( 36 | mel_spec.numpy(), 37 | sr=16000, 38 | n_fft=2048, 39 | hop_length=160, 40 | win_length=1024, 41 | power=1, 42 | center=True, 43 | # length=ori_audio.shape[0] 44 | ) 45 | 46 | length = ori_audio.shape[0] 47 | 48 | if audio.shape[0] > length: 49 | audio = audio[:length] 50 | elif audio.shape[0] < length: 51 | audio = np.pad(audio, (0, 
length - audio.shape[0]), mode='constant') 52 | 53 | audio = np.clip(audio, a_min=-1, a_max=1) 54 | return audio 55 | 56 | def inverse_stft(mel_spec, ori_audio): 57 | mel_spec = denormalize_spectrogram(mel_spec) 58 | mel_spec = torch.exp(mel_spec) 59 | 60 | n_fft = 2048 61 | hop_length = 160 62 | win_length = 1024 63 | power = 1 64 | center = False 65 | 66 | spec_mag = librosa.feature.inverse.mel_to_stft(mel_spec.numpy(), sr=16000, n_fft=n_fft, power=power) 67 | spec_mag = torch.tensor(spec_mag).float() 68 | 69 | audio_length = ori_audio.shape[0] 70 | ori_audio = torch.tensor(ori_audio) 71 | ori_audio = ori_audio.unsqueeze(0) 72 | ori_audio = torch.nn.functional.pad(ori_audio.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') 73 | ori_audio = ori_audio.squeeze(1) 74 | vocoder_spec = torch.stft(ori_audio, n_fft, hop_length=hop_length, win_length=win_length, window=torch.hann_window(win_length), center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 75 | # Get the phase from the complex spectrogram 76 | vocoder_phase = torch.angle(vocoder_spec).float() 77 | # import pdb; pdb.set_trace() 78 | 79 | # Combine the new magnitude with the original phase 80 | # We use polar coordinates to transform magnitude and phase into a complex number 81 | reconstructed_complex_spec = torch.polar(spec_mag.unsqueeze(0), vocoder_phase) 82 | 83 | # Perform the ISTFT to convert the spectrogram back to time domain audio signal 84 | reconstructed_audio = torch.istft( 85 | reconstructed_complex_spec, 86 | n_fft, 87 | hop_length=hop_length, 88 | win_length=win_length, 89 | window=torch.hann_window(win_length), 90 | center=True, 91 | normalized=False, 92 | onesided=True, 93 | length=audio_length # Ensure output audio length matches original audio 94 | ) 95 | 96 | reconstructed_audio = reconstructed_audio.squeeze(0).numpy() 97 | reconstructed_audio = np.clip(reconstructed_audio, a_min=-1, a_max=1) 98 | 99 | return reconstructed_audio 100 | 101 | 102 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/colorization/bell_example_005" 103 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/colorization/tiger_example_002" 104 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/colorization/dog_example_06" 105 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/debug/results/example_015" 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument("--dir", required=False, type=str, default="logs/soundify-denoise/colorization/bell_example_29") 112 | 113 | args = parser.parse_args() 114 | save_dir = f"logs/consistency-check/{args.dir.split('/')[-1]}" 115 | os.makedirs(save_dir, exist_ok=True) 116 | 117 | # import pdb; pdb.set_trace() 118 | 119 | # audio from vocoder to spectrogram 120 | audio_path = os.path.join(args.dir, "audio.wav") 121 | audio_data, sampling_rate = sf.read(audio_path) 122 | spec = wav2spec(audio_data, sampling_rate) 123 | 124 | save_path = os.path.join(save_dir, f"respec-hifi.png") 125 | save_image(spec, save_path, padding=0) 126 | 127 | # spectrogram to audio using griffin-lim 128 | spec_path = os.path.join(args.dir, "spec.png") 129 | spec = Image.open(spec_path) 130 | spec = TF.to_tensor(spec) 131 | 132 | save_path = os.path.join(save_dir, f"spec.png") 133 | shutil.copyfile(spec_path, save_path) 134 | 135 | audio_istft = inverse_stft(spec, audio_data) 136 | save_audio_path = os.path.join(save_dir, 
f"audio-istft.wav") 137 | sf.write(save_audio_path, audio_istft, samplerate=16000) 138 | spec_istft = wav2spec(audio_istft, sampling_rate) 139 | 140 | save_path = os.path.join(save_dir, f"respec-istft.png") 141 | save_image(spec_istft, save_path, padding=0) 142 | 143 | 144 | audio_gl = griffin_lim(spec, audio_data) 145 | save_audio_path = os.path.join(save_dir, f"audio-griffin-lim.wav") 146 | sf.write(save_audio_path, audio_gl, samplerate=16000) 147 | 148 | # audio_data, sampling_rate = sf.read(save_audio_path) 149 | 150 | spec_gl = wav2spec(audio_gl, sampling_rate) 151 | 152 | save_path = os.path.join(save_dir, f"respec-gl.png") 153 | save_image(spec_gl, save_path, padding=0) 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /src/colorization/create_color_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import glob 6 | import os 7 | from omegaconf import OmegaConf, DictConfig, open_dict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torchvision.transforms.functional as TF 12 | from torchvision.utils import save_image 13 | 14 | import rootutils 15 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 16 | 17 | from src.colorization.colorizer import FactorizedColorization 18 | from src.utils.animation_with_text import create_animation_with_text 19 | 20 | 21 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/bell_example_29 --prompt "a colorful photo of a castle with bell towers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 22 | 23 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/bell_example_005 --prompt "a colorful photo of a castle with bell towers" --num_samples 12 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 24 | 25 | 26 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/tiger_example_02 --prompt "a colorful photo of a tigers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 27 | 28 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/tiger_example_06 --prompt "a colorful photo of a tigers" --num_samples 8 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 29 | 30 | 31 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/race_example_002 --prompt "a colorful photo of a auto racing game" --num_samples 12 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 32 | 33 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/bird_example_40 --prompt "a blooming garden with many birds" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 34 | 35 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/kitten_example_08 --prompt "a colorful photo of kittens" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 36 | 37 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/kitten_example_08_v2 --prompt "a colorful photo of kittens with blue eyes and pink noses" --num_samples 16 --guidance_scale 10 
--num_inference_steps 30 --start_diffusion_step 7 38 | 39 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/dog_example_06 --prompt "a colorful photo of dogs" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 40 | 41 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/train_example_02 --prompt "a colorful photo of a long train" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 42 | 43 | if __name__ == '__main__': 44 | # Parse args 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--sample_dir", required=True, type=str) 47 | parser.add_argument("--prompt", required=True, type=str, help='Prompts to use for colorization') 48 | parser.add_argument("--num_samples", type=int, default=4) 49 | parser.add_argument("--depth", type=int, default=0) 50 | # parser.add_argument('--no_colormap', default=False, action='store_true') 51 | parser.add_argument("--guidance_scale", type=float, default=10.0) 52 | parser.add_argument("--num_inference_steps", type=int, default=30) 53 | parser.add_argument("--seed", type=int, default=0) 54 | parser.add_argument("--device", type=str, default='cuda') 55 | parser.add_argument("--noise_level", type=int, default=50, help='Noise level for stage 2') 56 | parser.add_argument("--start_diffusion_step", type=int, default=7, help='What step to start the diffusion process') 57 | 58 | 59 | args = parser.parse_args() 60 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 61 | 62 | # create diffusion colorization instance 63 | colorizer = FactorizedColorization( 64 | inverse_color=False, 65 | num_inference_steps=args.num_inference_steps, 66 | guidance_scale=args.guidance_scale, 67 | start_diffusion_step=args.start_diffusion_step, 68 | noise_level=args.noise_level, 69 | ).to(device) 70 | 71 | 72 | sample_dirs = glob.glob(f"{args.sample_dir}" + "/*" * args.depth) 73 | sample_dirs.sort() 74 | 75 | # read sample dir 76 | for sample_dir in sample_dirs: 77 | gray_im_path = f'{sample_dir}/img.png' 78 | spec = f'{sample_dir}/spec_colormap.png' 79 | if not os.path.exists(spec): 80 | spec = f'{sample_dir}/spec.png' 81 | 82 | audio = f'{sample_dir}/audio.wav' 83 | config_path = f'{sample_dir}/config.yaml' 84 | cfg = OmegaConf.load(config_path) 85 | image_prompt = args.prompt 86 | audio_prompt = cfg.trainer.audio_prompt 87 | 88 | with open_dict(cfg): 89 | cfg.trainer.colorization_prompt = args.prompt 90 | OmegaConf.save(cfg, config_path) 91 | 92 | # Load gray image 93 | gray_im = Image.open(gray_im_path) 94 | gray_im = TF.to_tensor(gray_im).to(device) 95 | 96 | img_save_dir = os.path.join(sample_dir, 'colorized_imgs') 97 | os.makedirs(img_save_dir, exist_ok=True) 98 | 99 | video_save_dir = os.path.join(sample_dir, 'colorized_videos') 100 | os.makedirs(video_save_dir, exist_ok=True) 101 | 102 | # Sample illusions 103 | for i in tqdm(range(args.num_samples), desc="Sampling images"): 104 | generator = torch.manual_seed(args.seed + i) 105 | image = colorizer(gray_im, args.prompt, generator=generator) 106 | img_save_path = f'{img_save_dir}/{i:04}.png' 107 | save_image(image, img_save_path, padding=0) 108 | 109 | video_save_path = f'{video_save_dir}/{i:04}.mp4' 110 | create_animation_with_text(img_save_path, spec, audio, video_save_path, image_prompt, audio_prompt) 111 | 112 | -------------------------------------------------------------------------------- 
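For reference, here is a sketch of the per-sample directory layout that `create_color_video.py` above reads and writes, derived from the paths used in the script (the first four entries are produced by the generation scripts; the two folders are created by this script):

```
<sample_dir>/
├── img.png              # grayscale image used as the colorization input
├── spec.png             # spectrogram image (spec_colormap.png is preferred if present)
├── audio.wav            # waveform paired with the spectrogram
├── config.yaml          # hydra config of the sample (audio prompt read from trainer.audio_prompt)
├── colorized_imgs/      # written by the script: colorized samples 0000.png, 0001.png, ...
└── colorized_videos/    # written by the script: animations 0000.mp4, 0001.mp4, ...
```

With the default `--depth 0` the script treats `--sample_dir` itself as a single sample; a larger `--depth` globs that many directory levels below it.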
/src/models/components/learnable_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class LearnableImage(nn.Module): 7 | def __init__( 8 | self, 9 | height :int, 10 | width :int, 11 | num_channels:int 12 | ): 13 | '''This is an abstract class, and is meant to be subclassed before use upon calling forward(), retuns a tensor of shape (num_channels, height, width) 14 | ''' 15 | super().__init__() 16 | 17 | self.height = height 18 | self.width = width 19 | self.num_channels = num_channels 20 | 21 | def as_numpy_image(self): 22 | image = self.forward() 23 | image = image.cpu().numpy() 24 | image = image.transpose(1, 2, 0) 25 | return image 26 | 27 | 28 | class LearnableImageFourier(LearnableImage): 29 | def __init__( 30 | self, 31 | height :int=256, # Height of the learnable images 32 | width :int=256, # Width of the learnable images 33 | num_channels:int=3 , # Number of channels in the images 34 | hidden_dim :int=256, # Number of dimensions per hidden layer of the MLP 35 | num_features:int=128, # Number of fourier features per coordinate 36 | scale :int=10 , # Magnitude of the initial feature noise 37 | renormalize :bool=False 38 | ): 39 | super().__init__(height, width, num_channels) 40 | 41 | self.hidden_dim = hidden_dim 42 | self.num_features = num_features 43 | self.scale = scale 44 | self.renormalize = renormalize 45 | 46 | # The following objects do NOT have parameters, and are not changed while optimizing this class 47 | self.uv_grid = nn.Parameter(get_uv_grid(height, width, batch_size=1), requires_grad=False) 48 | self.feature_extractor = GaussianFourierFeatureTransform(2, num_features, scale) 49 | self.features = nn.Parameter(self.feature_extractor(self.uv_grid), requires_grad=False) # pre-compute this if we're regressing on images 50 | 51 | H = hidden_dim # Number of hidden features. These 1x1 convolutions act as a per-pixel MLP 52 | C = num_channels # Shorter variable names let us align the code better 53 | M = 2 * num_features 54 | self.model = nn.Sequential( 55 | nn.Conv2d(M, H, kernel_size=1), nn.ReLU(), nn.BatchNorm2d(H), 56 | nn.Conv2d(H, H, kernel_size=1), nn.ReLU(), nn.BatchNorm2d(H), 57 | nn.Conv2d(H, H, kernel_size=1), nn.ReLU(), nn.BatchNorm2d(H), 58 | nn.Conv2d(H, C, kernel_size=1), 59 | nn.Sigmoid(), 60 | ) 61 | 62 | def forward(self): 63 | features = self.features 64 | 65 | output = self.model(features).squeeze(0) 66 | 67 | assert output.shape==(self.num_channels, self.height, self.width) 68 | 69 | if self.renormalize: 70 | output = output * 2 - 1.0 # renormalize to [-1, 1] 71 | 72 | return output 73 | 74 | 75 | class LearnableImageParam(LearnableImage): 76 | def __init__( 77 | self, 78 | height :int=256, # Height of the learnable images 79 | width :int=256, # Width of the learnable images 80 | num_channels:int=3 , # Number of channels in the images\ 81 | **kwargs 82 | ): 83 | super().__init__(height, width, num_channels) 84 | 85 | self.model = nn.Parameter(torch.randn(num_channels, height, width)) 86 | 87 | 88 | def forward(self): 89 | x = self.model 90 | return x 91 | 92 | 93 | ################################## 94 | ######## HELPER FUNCTIONS ######## 95 | ################################## 96 | 97 | class GaussianFourierFeatureTransform(nn.Module): 98 | """ 99 | Original authors: https://github.com/ndahlquist/pytorch-fourier-feature-networks 100 | 101 | An implementation of Gaussian Fourier feature mapping. 
102 | 103 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": 104 | https://arxiv.org/abs/2006.10739 105 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 106 | 107 | Given an input of size [batches, num_input_channels, width, height], 108 | returns a tensor of size [batches, num_features*2, width, height]. 109 | """ 110 | def __init__(self, num_channels, num_features=256, scale=10): 111 | #It generates fourier components of Arandom frequencies, not all of them. 112 | #The frequencies are determined by a random normal distribution, multiplied by "scale" 113 | #So, when "scale" is higher, the fourier features will have higher frequencies 114 | #In learnable_image_tutorial.ipynb, this translates to higher fidelity images. 115 | #In other words, 'scale' loosely refers to the X,Y scale of the images 116 | #With a high scale, you can learn detailed images with simple MLP's 117 | #If it's too high though, it won't really learn anything but high frequency noise 118 | 119 | super().__init__() 120 | 121 | self.num_channels = num_channels 122 | self.num_features = num_features 123 | 124 | #freqs are n-dimensional spatial frequencies, where n=num_channels 125 | self.freqs = nn.Parameter(torch.randn(num_channels, num_features) * scale, requires_grad=False) 126 | 127 | def forward(self, x): 128 | assert x.dim() == 4, 'Expected 4D input (got {}D input)'.format(x.dim()) 129 | 130 | batch_size, num_channels, height, width = x.shape 131 | 132 | assert num_channels == self.num_channels,\ 133 | "Expected input to have {} channels (got {} channels)".format(self.num_channels, num_channels) 134 | 135 | # Make shape compatible for matmul with freqs. 136 | # From [B, C, H, W] to [(B*H*W), C]. 137 | x = x.permute(0, 2, 3, 1).reshape(batch_size * height * width, num_channels) 138 | 139 | # [(B*H*W), C] x [C, F] = [(B*H*W), F] 140 | x = x @ self.freqs 141 | 142 | # From [(B*H*W), F] to [B, H, W, F] 143 | x = x.view(batch_size, height, width, self.num_features) 144 | # From [B, H, W, F] to [B, F, H, W 145 | x = x.permute(0, 3, 1, 2) 146 | 147 | x = 2 * torch.pi * x 148 | 149 | output = torch.cat([torch.sin(x), torch.cos(x)], dim=1) 150 | 151 | assert output.shape==(batch_size, 2*self.num_features, height, width) 152 | 153 | return output 154 | 155 | 156 | def get_uv_grid(height:int, width:int, batch_size:int=1)->torch.Tensor: 157 | #Returns a torch cpu tensor of shape (batch_size,2,height,width) 158 | #Note: batch_size can probably be removed from this function after refactoring this file. It's always 1 in all usages. 
159 | #The second dimension is (x,y) coordinates, which go from [0 to 1) from edge to edge 160 | #(In other words, it will include x=y=0, but instead of x=y=1 the other corner will be x=y=.999) 161 | #(this is so it doesn't wrap around the texture 360 degrees) 162 | 163 | # import pdb; pdb.set_trace() 164 | assert height>0 and width>0 and batch_size>0,'All dimensions must be positive integers' 165 | 166 | y_coords = np.linspace(0, 1, height, endpoint=False) 167 | x_coords = np.linspace(0, 1, width , endpoint=False) 168 | 169 | uv_grid = np.stack(np.meshgrid(y_coords, x_coords), -1) 170 | uv_grid = torch.tensor(uv_grid).unsqueeze(0).permute(0, 3, 2, 1).float().contiguous() 171 | uv_grid = uv_grid.repeat(batch_size,1,1,1) 172 | 173 | assert tuple(uv_grid.shape)==(batch_size,2,height,width) 174 | 175 | return uv_grid 176 | 177 | -------------------------------------------------------------------------------- /src/colorization/colorizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import os 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms.functional as TF 10 | from torchvision.utils import save_image 11 | 12 | from diffusers import DiffusionPipeline 13 | 14 | 15 | import rootutils 16 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 17 | 18 | from src.colorization.views import ColorLView, ColorABView 19 | from src.colorization.samplers import sample_stage_1, sample_stage_2 20 | 21 | 22 | 23 | class FactorizedColorization(nn.Module): 24 | '''Colorization diffusion model by Factorized Diffusion 25 | ''' 26 | def __init__( 27 | self, 28 | inverse_color=False, 29 | **kwargs 30 | ): 31 | super().__init__() 32 | 33 | # Make DeepFloyd IF stage I 34 | self.stage_1 = DiffusionPipeline.from_pretrained( 35 | "DeepFloyd/IF-I-M-v1.0", 36 | variant="fp16", 37 | torch_dtype=torch.float16 38 | ) 39 | self.stage_1.enable_model_cpu_offload() 40 | 41 | # Make DeepFloyd IF stage II 42 | self.stage_2 = DiffusionPipeline.from_pretrained( 43 | "DeepFloyd/IF-II-M-v1.0", 44 | text_encoder=None, 45 | variant="fp16", 46 | torch_dtype=torch.float16, 47 | ) 48 | self.stage_2.enable_model_cpu_offload() 49 | 50 | # if inverse the gray scale 51 | self.inverse_color = inverse_color 52 | 53 | # get views 54 | self.views = [ColorLView(), ColorABView()] 55 | 56 | self.num_inference_steps = kwargs.get("num_inference_steps", 30) 57 | self.guidance_scale = kwargs.get("guidance_scale", 10.0) 58 | self.start_diffusion_step = kwargs.get("start_diffusion_step", 0) 59 | self.noise_level = kwargs.get("noise_level", 50) 60 | 61 | 62 | @torch.no_grad() 63 | def get_text_embeds(self, prompt): 64 | # Get prompt embeddings (need two, because code is designed for 65 | # two components: L and ab) 66 | prompts = [prompt] * 2 67 | prompt_embeds = [self.stage_1.encode_prompt(p) for p in prompts] 68 | prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds) 69 | prompt_embeds = torch.cat(prompt_embeds) 70 | negative_prompt_embeds = torch.cat(negative_prompt_embeds) # These are just null embeds 71 | return prompt_embeds, negative_prompt_embeds 72 | 73 | def forward( 74 | self, 75 | gray_im, 76 | prompt, 77 | num_inference_steps=None, 78 | guidance_scale=None, 79 | start_diffusion_step=None, 80 | noise_level=None, 81 | generator=None 82 | ): 83 | # 1. 
overwrite the hyparams if provided 84 | num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps 85 | guidance_scale = self.guidance_scale if guidance_scale is None else guidance_scale 86 | start_diffusion_step = self.start_diffusion_step if start_diffusion_step is None else start_diffusion_step 87 | noise_level = self.noise_level if noise_level is None else noise_level 88 | 89 | # 2. prepare the text embeddings 90 | prompt_embeds, negative_prompt_embeds = self.get_text_embeds(prompt) 91 | 92 | # import pdb; pdb.set_trace() 93 | 94 | # 3. prepare grayscale image 95 | _, height, width = gray_im.shape 96 | if self.inverse_color: 97 | gray_im = 1.0 - gray_im 98 | 99 | gray_im = gray_im * 2.0 - 1 # normalize the pixel value 100 | 101 | # 4. Sample 64x64 image 102 | image = sample_stage_1( 103 | self.stage_1, 104 | prompt_embeds, 105 | negative_prompt_embeds, 106 | self.views, 107 | height=height // 4, 108 | width=width // 4, 109 | fixed_im=gray_im, 110 | num_inference_steps=num_inference_steps, 111 | guidance_scale=guidance_scale, 112 | generator=generator, 113 | start_diffusion_step=start_diffusion_step 114 | ) 115 | 116 | # 5. Sample 256x256 image, by upsampling 64x64 image 117 | image = sample_stage_2( 118 | self.stage_2, 119 | image, 120 | prompt_embeds, 121 | negative_prompt_embeds, 122 | self.views, 123 | height=height, 124 | width=width, 125 | fixed_im=gray_im, 126 | num_inference_steps=num_inference_steps, 127 | guidance_scale=guidance_scale, 128 | noise_level=noise_level, 129 | generator=generator 130 | ) 131 | 132 | # 6. return the final image 133 | image = image / 2 + 0.5 134 | return image 135 | 136 | 137 | # python colorizer.py --name colorize.castle.full --gray_im_path ./imgs/castle.full.png --prompt "a colorful photo of a white castle with bell towers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 138 | 139 | # python colorizer.py --name colorize.racing.full --gray_im_path ./imgs/racing.full.png --prompt "a colorful photo of a auto racing game" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 140 | 141 | # python colorizer.py --name colorize.tiger.full --gray_im_path ./imgs/tiger.full.png --prompt "a colorful photo of a tigers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7 142 | 143 | # python colorizer.py --name colorize.dog.full --gray_im_path ./imgs/dog.full.png --prompt "a colorful photo of puppies on green grass" --num_samples 4 --guidance_scale 10.0 --num_inference_steps 30 --start_diffusion_step 7 144 | 145 | # python colorizer.py --name colorize.spec.full --gray_im_path ./imgs/spec.full.png --prompt "a colorful photo of kittens" --num_samples 8 --guidance_scale 10.0 --num_inference_steps 30 --start_diffusion_step 0 146 | 147 | if __name__ == '__main__': 148 | # Parse args 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--name", required=True, type=str) 151 | parser.add_argument("--gray_im_path", required=True, type=str) 152 | parser.add_argument("--prompt", required=True, type=str, help='Prompts to use for colorization') 153 | parser.add_argument("--num_samples", type=int, default=4) 154 | parser.add_argument("--guidance_scale", type=float, default=10.0) 155 | parser.add_argument("--num_inference_steps", type=int, default=30) 156 | parser.add_argument("--seed", type=int, default=0) 157 | parser.add_argument("--device", type=str, default='cuda') 158 | parser.add_argument("--noise_level", 
type=int, default=50, help='Noise level for stage 2') 159 | parser.add_argument("--start_diffusion_step", type=int, default=7, help='What step to start the diffusion process') 160 | 161 | 162 | args = parser.parse_args() 163 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 164 | 165 | # create diffusion colorization instance 166 | colorizer = FactorizedColorization( 167 | inverse_color=False, 168 | num_inference_steps=args.num_inference_steps, 169 | guidance_scale=args.guidance_scale, 170 | start_diffusion_step=args.start_diffusion_step, 171 | noise_level=args.noise_level, 172 | ).to(device) 173 | 174 | # Load gray image 175 | gray_im = Image.open(args.gray_im_path) 176 | gray_im = TF.to_tensor(gray_im).to(device) 177 | 178 | save_dir = os.path.join('results', args.name) 179 | os.makedirs(save_dir, exist_ok=True) 180 | 181 | # Sample illusions 182 | for i in tqdm(range(args.num_samples), desc="Sampling images"): 183 | generator = torch.manual_seed(args.seed + i) 184 | image = colorizer(gray_im, args.prompt, generator=generator) 185 | save_image(image, f'{save_dir}/{i:04}.png', padding=0) 186 | -------------------------------------------------------------------------------- /src/guidance/deepfloyd.py: -------------------------------------------------------------------------------- 1 | from diffusers import IFPipeline, DDPMScheduler 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from lightning import seed_everything 8 | 9 | 10 | class DeepfloydGuidance(nn.Module): 11 | def __init__( 12 | self, 13 | repo_id='DeepFloyd/IF-I-XL-v1.0', 14 | fp16=True, 15 | t_range=[0.02, 0.98], 16 | t_consistent=False, 17 | **kwargs 18 | ): 19 | super().__init__() 20 | 21 | self.repo_id = repo_id 22 | self.precision_t = torch.float16 if fp16 else torch.float32 23 | 24 | # Create model 25 | pipe = IFPipeline.from_pretrained(repo_id, torch_dtype=self.precision_t) 26 | 27 | self.unet = pipe.unet 28 | self.tokenizer = pipe.tokenizer 29 | self.text_encoder = pipe.text_encoder 30 | self.unet = pipe.unet 31 | self.scheduler = pipe.scheduler 32 | 33 | self.pipe = pipe 34 | 35 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps 36 | self.min_step = int(self.num_train_timesteps * t_range[0]) 37 | self.max_step = int(self.num_train_timesteps * t_range[1]) 38 | self.t_consistent = t_consistent 39 | 40 | self.register_buffer('alphas', self.scheduler.alphas_cumprod) # for convenience 41 | 42 | @torch.no_grad() 43 | def get_text_embeds(self, prompt, device): 44 | # prompt: [str] 45 | prompt = self.pipe._text_preprocessing(prompt, clean_caption=False) 46 | inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt') 47 | embeddings = self.text_encoder(inputs.input_ids.to(device))[0] 48 | embeddings = embeddings.to(dtype=self.text_encoder.dtype, device=device) 49 | return embeddings 50 | 51 | 52 | def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, t=None, grad_scale=1): 53 | pred_rgb = pred_rgb.to(self.unet.dtype) 54 | # [0, 1] to [-1, 1] and make sure shape is [64, 64] 55 | images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 56 | 57 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level 58 | if t is None: 59 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level 60 | if self.t_consistent: 61 | t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=images.device) 
62 | t = t.repeat(images.shape[0]) 63 | else: 64 | t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=images.device) 65 | else: 66 | t = t.to(dtype=torch.long, device=images.device) 67 | 68 | # predict the noise residual with unet, NO grad! 69 | with torch.no_grad(): 70 | # add noise 71 | noise = torch.randn_like(images) 72 | images_noisy = self.scheduler.add_noise(images, noise, t) 73 | 74 | # pred noise 75 | model_input = torch.cat([images_noisy] * 2) 76 | model_input = self.scheduler.scale_model_input(model_input, t) 77 | tt = torch.cat([t] * 2) 78 | noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample 79 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 80 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) 81 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) 82 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 83 | 84 | # w(t), sigma_t^2 85 | w = (1 - self.alphas[t]) 86 | grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) 87 | grad = torch.nan_to_num(grad) 88 | 89 | targets = (images - grad).detach() 90 | loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0] 91 | 92 | return loss 93 | 94 | 95 | @torch.no_grad() 96 | def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5): 97 | 98 | images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype) 99 | images = images * self.scheduler.init_noise_sigma 100 | 101 | self.scheduler.set_timesteps(num_inference_steps) 102 | 103 | for i, t in enumerate(self.scheduler.timesteps): 104 | # expand the image if we are doing classifier-free guidance to avoid doing two forward passes. 
105 | model_input = torch.cat([images] * 2) 106 | model_input = self.scheduler.scale_model_input(model_input, t) 107 | 108 | # predict the noise residual 109 | noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample 110 | 111 | # perform guidance 112 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 113 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) 114 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) 115 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 116 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) 117 | 118 | # compute the previous noisy sample x_t -> x_t-1 119 | images = self.scheduler.step(noise_pred, t, images).prev_sample 120 | 121 | images = (images + 1) / 2 122 | 123 | return images 124 | 125 | def prompt_to_img(self, prompts, negative_prompts='', height=64, width=64, num_inference_steps=50, guidance_scale=7.5, device=None): 126 | 127 | if isinstance(prompts, str): 128 | prompts = [prompts] 129 | 130 | if isinstance(negative_prompts, str): 131 | negative_prompts = [negative_prompts] 132 | 133 | # Prompts -> text embeds 134 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768] 135 | neg_embeds = self.get_text_embeds(negative_prompts, device) 136 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] 137 | 138 | # Text embeds -> img 139 | imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64] 140 | 141 | # Img to Numpy 142 | imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() 143 | imgs = (imgs * 255).round().astype('uint8') 144 | 145 | return imgs 146 | 147 | 148 | if __name__ == '__main__': 149 | import argparse 150 | from PIL import Image 151 | import os 152 | 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument('--prompt', type=str, default='an oil paint of modern city, street view') 155 | parser.add_argument('--negative', default='', type=str) 156 | parser.add_argument('--repo_id', type=str, default='DeepFloyd/IF-I-XL-v1.0', help="stable diffusion version") 157 | parser.add_argument('--fp16', action='store_true', help="use float16 for training") 158 | parser.add_argument('--H', type=int, default=64) 159 | parser.add_argument('--W', type=int, default=64) 160 | parser.add_argument('--seed', type=int, default=0) 161 | parser.add_argument('--steps', type=int, default=50) 162 | opt = parser.parse_args() 163 | 164 | seed_everything(opt.seed) 165 | 166 | device = torch.device('cuda') 167 | 168 | sd = DeepfloydGuidance(repo_id=opt.repo_id, fp16=opt.fp16).to(device) 169 | 170 | imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, device=device) 171 | 172 | # visualize image 173 | save_path = 'logs/test' 174 | os.makedirs(save_path, exist_ok=True) 175 | image = Image.fromarray(imgs[0], mode='RGB') 176 | image.save(os.path.join(save_path, f'{opt.prompt}.png')) -------------------------------------------------------------------------------- /src/main_imprint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from tqdm import tqdm 4 | import warnings 5 | import soundfile as sf 6 | import numpy as np 7 | import shutil 8 | import glob 9 | import copy 10 | import matplotlib.pyplot as plt 11 | 12 | import hydra 13 | from omegaconf import OmegaConf, DictConfig, open_dict 14 
| import torch 15 | import torch.nn as nn 16 | from transformers import logging 17 | from lightning import seed_everything 18 | 19 | from torchvision.utils import save_image 20 | 21 | 22 | import rootutils 23 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 24 | 25 | from src.utils.rich_utils import print_config_tree 26 | from src.utils.animation_with_text import create_animation_with_text, create_single_image_animation_with_text 27 | from src.utils.re_ranking import select_top_k_ranking, select_top_k_clip_ranking 28 | from src.utils.pylogger import RankedLogger 29 | log = RankedLogger(__name__, rank_zero_only=True) 30 | 31 | 32 | 33 | def save_audio(audio, save_path): 34 | sf.write(save_path, audio, samplerate=16000) 35 | 36 | 37 | def encode_prompt(prompt, diffusion_guidance, device, negative_prompt='', time_repeat=1): 38 | '''Encode text prompts into embeddings 39 | ''' 40 | prompts = [prompt] * time_repeat 41 | negative_prompts = [negative_prompt] * time_repeat 42 | 43 | # Prompts -> text embeds 44 | cond_embeds = diffusion_guidance.get_text_embeds(prompts, device) # [B, 77, 768] 45 | uncond_embeds = diffusion_guidance.get_text_embeds(negative_prompts, device) # [B, 77, 768] 46 | text_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0) # [2 * B, 77, 768] 47 | return text_embeds 48 | 49 | def estimate_noise(diffusion, latents, t, text_embeddings, guidance_scale): 50 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 51 | latent_model_input = torch.cat([latents] * 2) 52 | # predict the noise residual 53 | noise_pred = diffusion.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] 54 | 55 | # perform guidance 56 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 57 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 58 | return noise_pred 59 | 60 | 61 | 62 | @hydra.main(version_base="1.3", config_path="../configs/main_imprint", config_name="main.yaml") 63 | def main(cfg: DictConfig) -> Optional[float]: 64 | """Main function for training 65 | """ 66 | 67 | if cfg.extras.get("ignore_warnings"): 68 | log.info("Disabling python warnings! 
") 69 | warnings.filterwarnings("ignore") 70 | logging.set_verbosity_error() 71 | 72 | if cfg.extras.get("print_config"): 73 | print_config_tree(cfg, resolve=True) 74 | 75 | # set seed for random number generators in pytorch, numpy and python.random 76 | if cfg.get("seed"): 77 | seed_everything(cfg.seed, workers=True) 78 | 79 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 80 | 81 | log.info(f"Instantiating Image Diffusion model <{cfg.image_diffusion_guidance._target_}>") 82 | image_diffusion_guidance = hydra.utils.instantiate(cfg.image_diffusion_guidance).to(device) 83 | 84 | log.info(f"Instantiating Audio Diffusion guidance model <{cfg.audio_diffusion_guidance._target_}>") 85 | audio_diffusion_guidance = hydra.utils.instantiate(cfg.audio_diffusion_guidance).to(device) 86 | 87 | # create transformation 88 | log.info(f"Instantiating latent transformation <{cfg.latent_transformation._target_}>") 89 | latent_transformation = hydra.utils.instantiate(cfg.latent_transformation).to(device) 90 | 91 | # create audio evaluator 92 | if cfg.audio_evaluator: 93 | log.info(f"Instantiating audio evaluator <{cfg.audio_evaluator._target_}>") 94 | audio_evaluator = hydra.utils.instantiate(cfg.audio_evaluator).to(device) 95 | else: 96 | audio_evaluator = None 97 | 98 | if cfg.visual_evaluator: 99 | log.info(f"Instantiating visual evaluator <{cfg.visual_evaluator._target_}>") 100 | visual_evaluator = hydra.utils.instantiate(cfg.visual_evaluator).to(device) 101 | else: 102 | visual_evaluator = None 103 | 104 | clip_scores = [] 105 | clap_scores = [] 106 | log.info(f"Starting sampling!") 107 | for idx in tqdm(range(cfg.trainer.num_samples), desc='Sampling'): 108 | clip_score, clap_score = create_sample(cfg, image_diffusion_guidance, audio_diffusion_guidance, latent_transformation, visual_evaluator, audio_evaluator, idx, device) 109 | clip_scores.append(clip_score) 110 | clap_scores.append(clap_score) 111 | 112 | # re-ranking by metrics 113 | enable_rank = cfg.trainer.get("enable_rank", False) 114 | if enable_rank: 115 | log.info(f"Starting re-ranking and selection!") 116 | select_top_k_ranking(cfg, clip_scores, clap_scores) 117 | 118 | enable_clip_rank = cfg.trainer.get("enable_clip_rank", False) 119 | if enable_clip_rank: 120 | log.info(f"Starting re-ranking and selection by CLIP score!") 121 | select_top_k_clip_ranking(cfg, clip_scores) 122 | 123 | log.info(f"Finished!") 124 | 125 | 126 | @torch.no_grad() 127 | def create_sample(cfg, image_diffusion, audio_diffusion, latent_transformation, visual_evaluator, audio_evaluator, idx, device): 128 | image_guidance_scale, audio_guidance_scale = cfg.trainer.image_guidance_scale, cfg.trainer.audio_guidance_scale 129 | height, width = cfg.trainer.img_height, cfg.trainer.img_width 130 | inverse_image = cfg.trainer.get("inverse_image", False) 131 | use_colormap = cfg.trainer.get("use_colormap", False) 132 | crop_image = cfg.trainer.get("crop_image", False) 133 | 134 | generator = torch.manual_seed(cfg.seed + idx) 135 | 136 | # obtain the image and spec for each modality's diffusion process 137 | image = image_diffusion.prompt_to_img(cfg.trainer.image_prompt, negative_prompts=cfg.trainer.image_neg_prompt, height=height, width=width, num_inference_steps=50, guidance_scale=image_guidance_scale, device=device, generator=generator) 138 | image = image.mean(dim=1) # make grayscale image 139 | 140 | spec = audio_diffusion.prompt_to_spec(cfg.trainer.audio_prompt, negative_prompts=cfg.trainer.audio_neg_prompt, height=height, width=width, 
num_inference_steps=100, guidance_scale=audio_guidance_scale, device=device, generator=generator) 141 | spec = spec.mean(dim=1) # make a single channel 142 | 143 | # perform the naive baseline 144 | mag_ratio = cfg.trainer.get("mag_ratio", 0.5) 145 | if inverse_image: 146 | image = 1.0 - image 147 | image_mask = 1 - mag_ratio * image 148 | spec_new = spec * image_mask 149 | img = image 150 | # import pdb; pdb.set_trace() 151 | audio = audio_diffusion.spec_to_audio(spec_new) 152 | audio = np.ravel(audio) 153 | 154 | if crop_image: 155 | pixel = 32 156 | audio_length = int(pixel / width * audio.shape[0]) 157 | img = img[..., :-pixel] 158 | spec = spec[..., :-pixel] 159 | spec_new = spec_new[..., :-pixel] 160 | audio = audio[:-audio_length] 161 | 162 | # evaluate with CLIP 163 | if visual_evaluator is not None: 164 | clip_score = visual_evaluator(img.repeat(3, 1, 1).unsqueeze(0), cfg.trainer.image_prompt) 165 | else: 166 | clip_score = None 167 | 168 | # evaluate with CLAP 169 | if audio_evaluator is not None: 170 | clap_score = audio_evaluator(cfg.trainer.audio_prompt, audio) 171 | else: 172 | clap_score = None 173 | 174 | sample_dir = os.path.join(cfg.output_dir, 'results', f'example_{str(idx+1).zfill(3)}') 175 | os.makedirs(sample_dir, exist_ok=True) 176 | 177 | # import pdb; pdb.set_trace() 178 | # save config with example-specific information 179 | cfg_save_path = os.path.join(sample_dir, 'config.yaml') 180 | current_cfg = copy.deepcopy(cfg) 181 | current_cfg.seed = cfg.seed + idx 182 | with open_dict(current_cfg): 183 | current_cfg.clip_score = clip_score 184 | current_cfg.clap_score = clap_score 185 | OmegaConf.save(current_cfg, cfg_save_path) 186 | 187 | # save image 188 | img_save_path = os.path.join(sample_dir, f'img.png') 189 | save_image(img, img_save_path) 190 | 191 | # save audio 192 | audio_save_path = os.path.join(sample_dir, f'audio.wav') 193 | save_audio(audio, audio_save_path) 194 | 195 | # save spec 196 | spec_save_path = os.path.join(sample_dir, f'spec_ori.png') 197 | save_image(spec, spec_save_path) 198 | 199 | spec_save_path = os.path.join(sample_dir, f'spec.png') 200 | save_image(spec_new, spec_save_path) 201 | 202 | # save spec with colormap (renormalize the spectrogram range) 203 | if use_colormap: 204 | spec_save_path = os.path.join(sample_dir, f'spec_colormap.png') 205 | spec_colormap = spec_new.mean(dim=0).cpu().numpy() 206 | plt.imsave(spec_save_path, spec_colormap, cmap='gray') 207 | 208 | 209 | # save video 210 | video_output_path = os.path.join(sample_dir, f'video.mp4') 211 | if img.shape[-2:] == spec.shape[-2:]: 212 | create_single_image_animation_with_text(spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt) 213 | else: 214 | create_animation_with_text(img_save_path, spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt) 215 | return clip_score, clap_score 216 | 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /src/main_sds.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from tqdm import tqdm 4 | import warnings 5 | import soundfile as sf 6 | import numpy as np 7 | import shutil 8 | import matplotlib.pyplot as plt 9 | 10 | import hydra 11 | from omegaconf import OmegaConf, DictConfig 12 | import torch 13 | import torch.nn as nn 14 | from transformers import logging 15 | from lightning 
import seed_everything 16 | 17 | from torchvision.utils import save_image 18 | 19 | 20 | import rootutils 21 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 22 | 23 | from src.utils.rich_utils import print_config_tree 24 | from src.utils.animation_with_text import create_animation_with_text, create_single_image_animation_with_text 25 | from src.utils.pylogger import RankedLogger 26 | log = RankedLogger(__name__, rank_zero_only=True) 27 | 28 | 29 | def save_model(output_dir, step, net): 30 | save_dir = os.path.join(output_dir, 'checkpoints') 31 | os.makedirs(save_dir, exist_ok=True) 32 | path = os.path.join(save_dir, 'checkpoint_latest.pth.tar') 33 | torch.save( 34 | { 35 | 'step': step, 36 | 'state_dict': net.state_dict(), 37 | }, 38 | path 39 | ) 40 | 41 | def save_audio(audio, save_path): 42 | sf.write(save_path, audio, samplerate=16000) 43 | 44 | 45 | def encode_prompt(prompt, diffusion_guidance, device, time_repeat=1): 46 | '''Encode text prompts into embeddings 47 | ''' 48 | prompts = [prompt] * time_repeat 49 | null_prompts = [''] * time_repeat 50 | 51 | # Prompts -> text embeds 52 | cond_embeds = diffusion_guidance.get_text_embeds(prompts, device) # [B, 77, 768] 53 | uncond_embeds = diffusion_guidance.get_text_embeds(null_prompts, device) # [B, 77, 768] 54 | text_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0) # [2 * B, 77, 768] 55 | return text_embeds 56 | 57 | 58 | 59 | @hydra.main(version_base="1.3", config_path="../configs/main_sds", config_name="main.yaml") 60 | def main(cfg: DictConfig) -> Optional[float]: 61 | """Main function for training 62 | """ 63 | 64 | if cfg.extras.get("ignore_warnings"): 65 | log.info("Disabling python warnings! ") 66 | warnings.filterwarnings("ignore") 67 | logging.set_verbosity_error() 68 | 69 | if cfg.extras.get("print_config"): 70 | print_config_tree(cfg, resolve=True) 71 | 72 | # set seed for random number generators in pytorch, numpy and python.random 73 | if cfg.get("seed"): 74 | seed_everything(cfg.seed, workers=True) 75 | 76 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 77 | 78 | log.info(f"Instantiating Image Learner model <{cfg.image_learner._target_}>") 79 | image_learner = hydra.utils.instantiate(cfg.image_learner).to(device) 80 | 81 | log.info(f"Instantiating Image Diffusion model <{cfg.image_diffusion_guidance._target_}>") 82 | image_diffusion_guidance = hydra.utils.instantiate(cfg.image_diffusion_guidance).to(device) 83 | 84 | log.info(f"Instantiating Audio Diffusion guidance model <{cfg.audio_diffusion_guidance._target_}>") 85 | audio_diffusion_guidance = hydra.utils.instantiate(cfg.audio_diffusion_guidance).to(device) 86 | 87 | # create optimizer 88 | log.info(f"Instantiating optimizer <{cfg.optimizer._target_}>") 89 | optimizer = hydra.utils.instantiate(cfg.optimizer) 90 | optimizer = optimizer(params=image_learner.parameters()) 91 | 92 | # create transformation 93 | log.info(f"Instantiating image transformation <{cfg.image_transformation._target_}>") 94 | image_transformation = hydra.utils.instantiate(cfg.image_transformation).to(device) 95 | 96 | log.info(f"Instantiating audio transformation <{cfg.audio_transformation._target_}>") 97 | audio_transformation = hydra.utils.instantiate(cfg.audio_transformation).to(device) 98 | 99 | log.info(f"Starting training!") 100 | trainer(cfg, image_learner, image_diffusion_guidance, audio_diffusion_guidance, optimizer, image_transformation, audio_transformation, device) 101 | 102 | 103 | def trainer(cfg, image_learner, 
image_diffusion, audio_diffusion, optimizer, image_transformation, audio_transformation, device): 104 | image_guidance_scale, audio_guidance_scale = cfg.trainer.image_guidance_scale, cfg.trainer.audio_guidance_scale 105 | # image_weight, audio_weight = cfg.trainer.image_weight, cfg.trainer.audio_weight 106 | image_start_step = cfg.trainer.get("image_start_step", 0) 107 | audio_start_step = cfg.trainer.get("audio_start_step", 0) 108 | 109 | accumulate_grad_batches = cfg.trainer.get("accumulate_grad_batches", 1) 110 | use_colormap = cfg.trainer.get("use_colormap", False) 111 | crop_image = cfg.trainer.get("crop_image", False) 112 | 113 | image_text_embeds = encode_prompt(cfg.trainer.image_prompt, image_diffusion, device, time_repeat=cfg.trainer.batch_size) 114 | audio_text_embeds = encode_prompt(cfg.trainer.audio_prompt, audio_diffusion, device, time_repeat=1) 115 | 116 | image_learner.train() 117 | for step in tqdm(range(cfg.trainer.num_iteration), desc="Training"): 118 | # import pdb; pdb.set_trace() 119 | image = image_learner() # [C, H, W] 120 | images = image.unsqueeze(0) # (1, C, H, W) 121 | 122 | # perform image guidance 123 | rgb_images = image_transformation(images) # (B, C, h, w) 124 | 125 | # perform audio guidance 126 | spec_images = audio_transformation(images) # (1, 1, H, W) 127 | 128 | if step >= image_start_step: 129 | image_weight = cfg.trainer.image_weight 130 | image_loss = image_diffusion.train_step(image_text_embeds, rgb_images, guidance_scale=image_guidance_scale, grad_scale=1) 131 | else: 132 | image_weight = 0.0 133 | image_loss = torch.tensor(0.0).to(device) 134 | 135 | if step >= audio_start_step: 136 | audio_weight = cfg.trainer.audio_weight 137 | audio_loss = audio_diffusion.train_step(audio_text_embeds, spec_images, guidance_scale=audio_guidance_scale, grad_scale=1) 138 | else: 139 | audio_weight = 0.0 140 | audio_loss = torch.tensor(0.0).to(device) 141 | 142 | loss = image_weight * image_loss + audio_weight * audio_loss 143 | loss = loss / accumulate_grad_batches 144 | loss.backward() 145 | 146 | # apply gradient accumulation 147 | if (step + 1) % accumulate_grad_batches == 0 or (step + 1) == cfg.trainer.num_iteration: 148 | optimizer.step() 149 | optimizer.zero_grad() 150 | 151 | tqdm.write(f"Iteration: {step+1}/{cfg.trainer.num_iteration}, loss: {loss.item():.4f} | visual loss: {image_loss.item():.4f} | audio loss: {audio_loss.item():.4f}") 152 | 153 | if (step + 1) % cfg.trainer.save_step == 0: 154 | save_model(cfg.output_dir, step, image_learner) 155 | 156 | if (step + 1) % cfg.trainer.visualize_step == 0: 157 | img_save_dir = os.path.join(cfg.output_dir, 'image_results') 158 | os.makedirs(img_save_dir, exist_ok=True) 159 | img_save_path = os.path.join(img_save_dir, f'img_{str(step+1).zfill(6)}.png') 160 | save_image(image, img_save_path) 161 | 162 | spec_save_dir = os.path.join(cfg.output_dir, 'spec_results') 163 | os.makedirs(spec_save_dir, exist_ok=True) 164 | spec_save_path = os.path.join(spec_save_dir, f'spec_{str(step+1).zfill(6)}.png') 165 | save_image(spec_images.squeeze(0), spec_save_path) 166 | 167 | audio_save_dir = os.path.join(cfg.output_dir, 'audio_results') 168 | os.makedirs(audio_save_dir, exist_ok=True) 169 | audio_save_path = os.path.join(audio_save_dir, f'audio_{str(step+1).zfill(6)}.wav') 170 | 171 | audio = audio_diffusion.spec_to_audio(spec_images.squeeze(0)) 172 | audio = np.ravel(audio) 173 | save_audio(audio, audio_save_path) 174 | 175 | # obtain final results 176 | img = image_learner() # [C, H, W] 177 | spec = 
audio_transformation(img.unsqueeze(0)).squeeze(0) # (1, H, W) 178 | audio = audio_diffusion.spec_to_audio(spec) 179 | audio = np.ravel(audio) 180 | 181 | if crop_image: 182 | pixel = 32 183 | audio_length = int(pixel / image_learner.width * audio.shape[0]) 184 | img = img[..., :-pixel] 185 | spec = spec[..., :-pixel] 186 | audio = audio[:-audio_length] 187 | 188 | # save the final results 189 | sample_dir = os.path.join(cfg.output_dir, 'results', f'final') 190 | os.makedirs(sample_dir, exist_ok=True) 191 | 192 | # save config 193 | cfg_path = os.path.join(cfg.output_dir, '.hydra', 'config.yaml') 194 | cfg_save_path = os.path.join(sample_dir, 'config.yaml') 195 | shutil.copyfile(cfg_path, cfg_save_path) 196 | 197 | # save image 198 | img_save_path = os.path.join(sample_dir, f'img.png') 199 | save_image(img, img_save_path) 200 | 201 | # save audio 202 | audio_save_path = os.path.join(sample_dir, f'audio.wav') 203 | save_audio(audio, audio_save_path) 204 | 205 | # save spec 206 | spec_save_path = os.path.join(sample_dir, f'spec.png') 207 | save_image(spec.mean(dim=0, keepdim=True), spec_save_path) 208 | 209 | if use_colormap: 210 | spec_save_path = os.path.join(sample_dir, f'spec_colormap.png') 211 | spec_colormap = spec.mean(dim=0).detach().cpu().numpy() 212 | plt.imsave(spec_save_path, spec_colormap, cmap='gray') 213 | 214 | # log.info("Generating video ...") 215 | # save video 216 | video_output_path = os.path.join(cfg.output_dir, 'video.mp4') 217 | 218 | if image_learner.num_channels == 1: 219 | create_single_image_animation_with_text(spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt) 220 | else: 221 | create_animation_with_text(img_save_path, spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt) 222 | # log.info("Generated video.") 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | -------------------------------------------------------------------------------- /src/guidance/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.utils import save_image 7 | 8 | 9 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline 10 | from diffusers.utils.import_utils import is_xformers_available 11 | 12 | from lightning import seed_everything 13 | 14 | 15 | 16 | class StableDiffusionGuidance(nn.Module): 17 | def __init__( 18 | self, 19 | repo_id='runwayml/stable-diffusion-v1-5', 20 | fp16=True, 21 | t_range=[0.02, 0.98], 22 | t_consistent=False, 23 | **kwargs 24 | ): 25 | super().__init__() 26 | 27 | self.repo_id = repo_id 28 | 29 | self.precision_t = torch.float16 if fp16 else torch.float32 30 | 31 | # Create model 32 | self.vae, self.tokenizer, self.text_encoder, self.unet = self.create_model_from_pipe(repo_id, self.precision_t) 33 | 34 | self.scheduler = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler", torch_dtype=self.precision_t) 35 | 36 | self.register_buffer('alphas_cumprod', self.scheduler.alphas_cumprod) 37 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps 38 | self.min_step = int(self.num_train_timesteps * t_range[0]) 39 | self.max_step = int(self.num_train_timesteps * t_range[1]) 40 | self.t_consistent = t_consistent 41 | 42 | def create_model_from_pipe(self, repo_id, dtype): 43 | pipe = 
StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype) 44 | vae = pipe.vae 45 | tokenizer = pipe.tokenizer 46 | text_encoder = pipe.text_encoder 47 | unet = pipe.unet 48 | return vae, tokenizer, text_encoder, unet 49 | 50 | @torch.no_grad() 51 | def get_text_embeds(self, prompt, device): 52 | # prompt: [str] 53 | # import pdb; pdb.set_trace() 54 | inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') 55 | prompt_embeds = self.text_encoder(inputs.input_ids.to(device))[0] 56 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 57 | return prompt_embeds 58 | 59 | def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, t=None, grad_scale=1, save_guidance_path:Path=None): 60 | # import pdb; pdb.set_trace() 61 | pred_rgb = pred_rgb.to(self.vae.dtype) 62 | if as_latent: 63 | # latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 64 | latents = pred_rgb 65 | else: 66 | # interp to 512x512 to be fed into vae. 67 | # pred_rgb_512 = pred_rgb 68 | pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) 69 | # encode image into latents with vae, requires grad! 70 | latents = self.encode_imgs(pred_rgb_512) 71 | 72 | if t is None: 73 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level 74 | if self.t_consistent: 75 | t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=latents.device) 76 | t = t.repeat(latents.shape[0]) 77 | else: 78 | t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=latents.device) 79 | else: 80 | t = t.to(dtype=torch.long, device=latents.device) 81 | 82 | # predict the noise residual with unet, NO grad! 83 | with torch.no_grad(): 84 | # add noise 85 | noise = torch.randn_like(latents) 86 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 87 | # pred noise 88 | latent_model_input = torch.cat([latents_noisy] * 2) 89 | tt = torch.cat([t] * 2) 90 | noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample 91 | 92 | # perform guidance (high scale from paper!) 93 | noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) 94 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) 95 | 96 | # w(t), sigma_t^2 97 | w = (1 - self.alphas_cumprod[t]) 98 | grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) 99 | grad = torch.nan_to_num(grad) 100 | 101 | if save_guidance_path: 102 | with torch.no_grad(): 103 | if as_latent: 104 | pred_rgb_512 = self.decode_latents(latents) 105 | 106 | # visualize predicted denoised image 107 | # The following block of code is equivalent to `predict_start_from_noise`... 108 | # see zero123_utils.py's version for a simpler implementation. 
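# For reference, with abar_t = self.alphas_cumprod[t] and eps_hat = noise_pred, the
# standard closed-form estimate of the clean latent is:
#     pred_x0 = (latents_noisy - (1 - abar_t).sqrt() * eps_hat) / abar_t.sqrt()
# The index-based variant below is kept as in the reference implementation; it only
# affects the saved debug visualization, not the SDS gradient computed above.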
109 | alphas = self.scheduler.alphas.to(latents.device) 110 | total_timesteps = self.max_step - self.min_step + 1 111 | index = total_timesteps - t.to(latents.device) - 1 112 | b = len(noise_pred) 113 | a_t = alphas[index].reshape(b,1,1,1).to(latents.device) 114 | sqrt_one_minus_alphas = torch.sqrt(1 - alphas) 115 | sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(latents.device) 116 | pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0 117 | result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t))) 118 | 119 | # visualize noisier image 120 | result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t)) 121 | 122 | # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512] 123 | viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0) 124 | save_image(viz_images, save_guidance_path) 125 | 126 | targets = (latents - grad).detach() 127 | loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] 128 | 129 | return loss 130 | 131 | 132 | @torch.no_grad() 133 | def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, generator=None): 134 | 135 | if latents is None: 136 | latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=self.unet.dtype).to(text_embeddings.device) 137 | 138 | self.scheduler.set_timesteps(num_inference_steps) 139 | 140 | for i, t in enumerate(self.scheduler.timesteps): 141 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 142 | latent_model_input = torch.cat([latents] * 2) 143 | # predict the noise residual 144 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] 145 | 146 | # perform guidance 147 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 148 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 149 | 150 | # compute the previous noisy sample x_t -> x_t-1 151 | latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] 152 | 153 | return latents 154 | 155 | def decode_latents(self, latents): 156 | latents = latents.to(self.vae.dtype) 157 | latents = 1 / self.vae.config.scaling_factor * latents 158 | 159 | imgs = self.vae.decode(latents).sample 160 | imgs = (imgs / 2 + 0.5).clamp(0, 1) 161 | 162 | return imgs 163 | 164 | def encode_imgs(self, imgs, generator=None): 165 | # imgs: [B, 3, H, W] 166 | imgs = 2 * imgs - 1 167 | posterior = self.vae.encode(imgs).latent_dist 168 | latents = posterior.sample(generator) * self.vae.config.scaling_factor 169 | return latents 170 | 171 | def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, device=None, generator=None): 172 | if isinstance(prompts, str): 173 | prompts = [prompts] 174 | 175 | if isinstance(negative_prompts, str): 176 | negative_prompts = [negative_prompts] 177 | 178 | # Prompts -> text embeds 179 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768] 180 | neg_embeds = self.get_text_embeds(negative_prompts, device) 181 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] 182 | 183 | # Text embeds -> img latents 184 | latents = self.produce_latents(text_embeds, height=height, width=width, 
latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator) # [1, 4, 64, 64] 185 | 186 | # Img latents -> imgs 187 | imgs = self.decode_latents(latents) # [1, 3, 512, 512] 188 | 189 | # # Img to Numpy 190 | # imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() 191 | # imgs = (imgs * 255).round().astype('uint8') 192 | 193 | return imgs 194 | 195 | 196 | if __name__ == '__main__': 197 | 198 | import argparse 199 | from PIL import Image 200 | import os 201 | 202 | parser = argparse.ArgumentParser() 203 | parser.add_argument('--prompt', type=str, default='') 204 | parser.add_argument('--negative', default='', type=str) 205 | parser.add_argument('--repo_id', type=str, default='runwayml/stable-diffusion-v1-5', help="stable diffusion version") 206 | parser.add_argument('--fp16', action='store_true', help="use float16 for training") 207 | parser.add_argument('--H', type=int, default=512) 208 | parser.add_argument('--W', type=int, default=512) 209 | parser.add_argument('--seed', type=int, default=0) 210 | parser.add_argument('--steps', type=int, default=50) 211 | opt = parser.parse_args() 212 | 213 | seed_everything(opt.seed) 214 | 215 | device = torch.device('cuda') 216 | 217 | sd = StableDiffusionGuidance(repo_id=opt.repo_id, fp16=opt.fp16) 218 | sd = sd.to(device) 219 | 220 | imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, device=device) 221 | 222 | # visualize image (prompt_to_img returns a float tensor in [0, 1], so convert to uint8 HWC before PIL) 223 | save_path = 'logs/test' 224 | os.makedirs(save_path, exist_ok=True) 225 | image = Image.fromarray((imgs[0].detach().cpu().permute(1, 2, 0).float().numpy() * 255).round().astype('uint8'), mode='RGB') 226 | image.save(os.path.join(save_path, f'{opt.prompt}.png')) -------------------------------------------------------------------------------- /src/main_denoise.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from tqdm import tqdm 4 | import warnings 5 | import soundfile as sf 6 | import numpy as np 7 | import shutil 8 | import glob 9 | import copy 10 | import matplotlib.pyplot as plt 11 | 12 | import hydra 13 | from omegaconf import OmegaConf, DictConfig, open_dict 14 | import torch 15 | import torch.nn as nn 16 | from transformers import logging 17 | from lightning import seed_everything 18 | 19 | from torchvision.utils import save_image 20 | 21 | 22 | import rootutils 23 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 24 | 25 | from src.utils.rich_utils import print_config_tree 26 | from src.utils.animation_with_text import create_animation_with_text, create_single_image_animation_with_text 27 | from src.utils.re_ranking import select_top_k_ranking, select_top_k_clip_ranking 28 | from src.utils.pylogger import RankedLogger 29 | log = RankedLogger(__name__, rank_zero_only=True) 30 | 31 | 32 | def save_audio(audio, save_path): 33 | sf.write(save_path, audio, samplerate=16000) 34 | 35 | 36 | def encode_prompt(prompt, diffusion_guidance, device, negative_prompt='', time_repeat=1): 37 | '''Encode text prompts into embeddings 38 | ''' 39 | prompts = [prompt] * time_repeat 40 | negative_prompts = [negative_prompt] * time_repeat 41 | 42 | # Prompts -> text embeds 43 | cond_embeds = diffusion_guidance.get_text_embeds(prompts, device) # [B, 77, 768] 44 | uncond_embeds = diffusion_guidance.get_text_embeds(negative_prompts, device) # [B, 77, 768] 45 | text_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0) # [2 * B, 77, 768] 46 | return text_embeds 47 | 48 | def estimate_noise(diffusion, latents, t, text_embeddings,
guidance_scale): 49 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 50 | latent_model_input = torch.cat([latents] * 2) 51 | # predict the noise residual 52 | noise_pred = diffusion.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] 53 | 54 | # perform guidance 55 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 56 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 57 | return noise_pred 58 | 59 | 60 | @hydra.main(version_base="1.3", config_path="../configs/main_denoise", config_name="main.yaml") 61 | def main(cfg: DictConfig) -> Optional[float]: 62 | """Main function for training 63 | """ 64 | 65 | if cfg.extras.get("ignore_warnings"): 66 | log.info("Disabling python warnings! ") 67 | warnings.filterwarnings("ignore") 68 | logging.set_verbosity_error() 69 | 70 | if cfg.extras.get("print_config"): 71 | print_config_tree(cfg, resolve=True) 72 | 73 | # set seed for random number generators in pytorch, numpy and python.random 74 | if cfg.get("seed"): 75 | seed_everything(cfg.seed, workers=True) 76 | 77 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 78 | 79 | log.info(f"Instantiating Image Diffusion model <{cfg.image_diffusion_guidance._target_}>") 80 | image_diffusion_guidance = hydra.utils.instantiate(cfg.image_diffusion_guidance).to(device) 81 | 82 | log.info(f"Instantiating Audio Diffusion guidance model <{cfg.audio_diffusion_guidance._target_}>") 83 | audio_diffusion_guidance = hydra.utils.instantiate(cfg.audio_diffusion_guidance).to(device) 84 | 85 | # create the shared noise scheduler 86 | log.info(f"Instantiating joint diffusion scheduler <{cfg.diffusion_scheduler._target_}>") 87 | scheduler = hydra.utils.instantiate(cfg.diffusion_scheduler) 88 | 89 | # create transformation 90 | log.info(f"Instantiating latent transformation <{cfg.latent_transformation._target_}>") 91 | latent_transformation = hydra.utils.instantiate(cfg.latent_transformation).to(device) 92 | 93 | # create audio evaluator 94 | if cfg.audio_evaluator: 95 | log.info(f"Instantiating audio evaluator <{cfg.audio_evaluator._target_}>") 96 | audio_evaluator = hydra.utils.instantiate(cfg.audio_evaluator).to(device) 97 | else: 98 | audio_evaluator = None 99 | 100 | if cfg.visual_evaluator: 101 | log.info(f"Instantiating visual evaluator <{cfg.visual_evaluator._target_}>") 102 | visual_evaluator = hydra.utils.instantiate(cfg.visual_evaluator).to(device) 103 | else: 104 | visual_evaluator = None 105 | 106 | log.info(f"Starting sampling!") 107 | clip_scores = [] 108 | clap_scores = [] 109 | 110 | for idx in tqdm(range(cfg.trainer.num_samples), desc='Sampling'): 111 | clip_score, clap_score = denoise(cfg, image_diffusion_guidance, audio_diffusion_guidance, scheduler, latent_transformation, visual_evaluator, audio_evaluator, idx, device) 112 | clip_scores.append(clip_score) 113 | clap_scores.append(clap_score) 114 | 115 | # re-ranking by metrics 116 | enable_rank = cfg.trainer.get("enable_rank", False) 117 | if enable_rank: 118 | log.info(f"Starting re-ranking and selection!") 119 | select_top_k_ranking(cfg, clip_scores, clap_scores) 120 | 121 | enable_clip_rank = cfg.trainer.get("enable_clip_rank", False) 122 | if enable_clip_rank: 123 | log.info(f"Starting re-ranking and selection by CLIP score!") 124 | select_top_k_clip_ranking(cfg, clip_scores) 125 | 126 | log.info(f"Finished!") 127 | 128 | 129 | @torch.no_grad() 130 | def denoise(cfg, image_diffusion, 
audio_diffusion, scheduler, latent_transformation, visual_evaluator, audio_evaluator, idx, device): 131 | image_guidance_scale, audio_guidance_scale = cfg.trainer.image_guidance_scale, cfg.trainer.audio_guidance_scale 132 | height, width = cfg.trainer.img_height, cfg.trainer.img_width 133 | image_start_step = cfg.trainer.get("image_start_step", 0) 134 | audio_start_step = cfg.trainer.get("audio_start_step", 0) 135 | audio_weight = cfg.trainer.get("audio_weight", 0.5) 136 | use_colormap = cfg.trainer.get("use_colormap", False) 137 | 138 | cutoff_latent = cfg.trainer.get("cutoff_latent", False) 139 | crop_image = cfg.trainer.get("crop_image", False) 140 | 141 | generator = torch.manual_seed(cfg.seed + idx) 142 | 143 | # obtain the text embeddings for each modality's diffusion process 144 | image_text_embeds = encode_prompt(cfg.trainer.image_prompt, image_diffusion, device, negative_prompt=cfg.trainer.image_neg_prompt, time_repeat=1) 145 | audio_text_embeds = encode_prompt(cfg.trainer.audio_prompt, audio_diffusion, device, negative_prompt=cfg.trainer.audio_neg_prompt, time_repeat=1) 146 | 147 | scheduler.set_timesteps(cfg.trainer.num_inference_steps) 148 | 149 | # init random latents 150 | latents = torch.randn((image_text_embeds.shape[0] // 2, image_diffusion.unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=image_diffusion.precision_t).to(device) 151 | 152 | for i, t in enumerate(scheduler.timesteps): 153 | if i >= image_start_step: 154 | image_noise = estimate_noise(image_diffusion, latents, t, image_text_embeds, image_guidance_scale) 155 | else: 156 | image_noise = None 157 | 158 | if i >= audio_start_step: 159 | transform_latents = latent_transformation(latents, inverse=False) 160 | audio_noise = estimate_noise(audio_diffusion, transform_latents, t, audio_text_embeds, audio_guidance_scale) 161 | audio_noise = latent_transformation(audio_noise, inverse=True) 162 | else: 163 | audio_noise = None 164 | 165 | if image_noise is not None and audio_noise is not None: 166 | noise_pred = (1.0 - audio_weight) * image_noise + audio_weight * audio_noise 167 | elif image_noise is not None and audio_noise is None: 168 | noise_pred = image_noise 169 | elif image_noise is None and audio_noise is not None: 170 | noise_pred = audio_noise 171 | else: 172 | log.info("No estimated noise! 
Exit.") 173 | raise NotImplementedError 174 | 175 | # compute the previous noisy sample x_t -> x_t-1 176 | latents = scheduler.step(noise_pred, t, latents)['prev_sample'] 177 | 178 | if cutoff_latent and not crop_image: 179 | latents = latents[..., :-4] # we cut off 4 latents so that we can directly remove the black region 180 | 181 | # Img latents -> imgs 182 | img = image_diffusion.decode_latents(latents) # [1, 3, H, W] 183 | 184 | # Img latents -> audio 185 | audio_latents = latent_transformation(latents, inverse=False) 186 | spec = audio_diffusion.decode_latents(audio_latents).squeeze(0) # [3, 256, 1024] 187 | audio = audio_diffusion.spec_to_audio(spec) 188 | audio = np.ravel(audio) 189 | 190 | if crop_image and not cutoff_latent: 191 | pixel = 32 192 | audio_length = int(pixel / width * audio.shape[0]) 193 | img = img[..., :-pixel] 194 | spec = spec[..., :-pixel] 195 | audio = audio[:-audio_length] 196 | 197 | # evaluate with CLIP 198 | if visual_evaluator is not None: 199 | clip_score = visual_evaluator(img, cfg.trainer.image_prompt) 200 | else: 201 | clip_score = None 202 | 203 | # evaluate with CLAP 204 | if audio_evaluator is not None: 205 | clap_score = audio_evaluator(cfg.trainer.audio_prompt, audio) 206 | else: 207 | clap_score = None 208 | 209 | sample_dir = os.path.join(cfg.output_dir, 'results', f'example_{str(idx+1).zfill(3)}') 210 | os.makedirs(sample_dir, exist_ok=True) 211 | 212 | # save config with example-specific information 213 | cfg_save_path = os.path.join(sample_dir, 'config.yaml') 214 | current_cfg = copy.deepcopy(cfg) 215 | current_cfg.seed = cfg.seed + idx 216 | with open_dict(current_cfg): 217 | current_cfg.clip_score = clip_score 218 | current_cfg.clap_score = clap_score 219 | OmegaConf.save(current_cfg, cfg_save_path) 220 | 221 | # save image 222 | img_save_path = os.path.join(sample_dir, f'img.png') 223 | save_image(img, img_save_path) 224 | 225 | # save audio 226 | audio_save_path = os.path.join(sample_dir, f'audio.wav') 227 | save_audio(audio, audio_save_path) 228 | 229 | # save spec 230 | spec_save_path = os.path.join(sample_dir, f'spec.png') 231 | save_image(spec.mean(dim=0, keepdim=True), spec_save_path) 232 | 233 | # save spec with colormap (renormalize the spectrogram range) 234 | if use_colormap: 235 | spec_save_path = os.path.join(sample_dir, f'spec_colormap.png') 236 | spec_colormap = spec.mean(dim=0).cpu().numpy() 237 | plt.imsave(spec_save_path, spec_colormap, cmap='gray') 238 | 239 | # save video 240 | video_output_path = os.path.join(sample_dir, f'video.mp4') 241 | if img.shape[-2:] == spec.shape[-2:]: 242 | create_single_image_animation_with_text(spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt) 243 | else: 244 | create_animation_with_text(img_save_path, spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt) 245 | 246 | return clip_score, clap_score 247 | 248 | 249 | if __name__ == "__main__": 250 | main() 251 | -------------------------------------------------------------------------------- /src/utils/animation_with_text.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | import soundfile as sf 5 | import subprocess 6 | from omegaconf import OmegaConf, DictConfig 7 | 8 | 9 | from moviepy.editor import VideoClip, ImageClip, TextClip 10 | 11 | # import pdb; pdb.set_trace() 12 | 13 | 14 | # Run the ffmpeg command to get the path 15 | ffmpeg_path 
= subprocess.check_output(['which', 'ffmpeg']).decode().strip() 16 | # Set the environment variable 17 | os.environ["IMAGEIO_FFMPEG_EXE"] = ffmpeg_path 18 | # Optionally, print the path for verification 19 | print("IMAGEIO_FFMPEG_EXE set to:", os.environ["IMAGEIO_FFMPEG_EXE"]) 20 | 21 | 22 | # import moviepy.config_defaults 23 | # magick_path = subprocess.check_output(['which', 'magick']).decode().strip() 24 | # moviepy.config_defaults.IMAGEMAGICK_BINARY = magick_path 25 | # print("IMAGEMAGICK_BINARY Path set to:", moviepy.config_defaults.IMAGEMAGICK_BINARY) 26 | 27 | 28 | 29 | 30 | def create_animation_with_text(image_path1, image_path2, audio_path, output_path, image_prompt, audio_prompt): 31 | # set up video hyperparameter 32 | space_between_images = 30 33 | padding = 40 34 | space_between_text_image = 8 35 | fontsize = 26 36 | font = 'Ubuntu-Mono' 37 | text_method = 'caption' 38 | 39 | # Load images 40 | image1 = np.array(Image.open(image_path1).convert('RGB')) 41 | image2 = np.array(Image.open(image_path2).convert('RGB')) 42 | image1_height, image1_width, _ = image1.shape 43 | image2_height, image2_width, _ = image2.shape 44 | 45 | max_image_width = max(image1_width, image2_width) 46 | # Add text prompts 47 | text_size = [max_image_width, None] 48 | image_prompt = TextClip('Image prompt: ' + image_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size) 49 | audio_prompt = TextClip('Audio prompt: ' + audio_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size) 50 | image_text_height = np.array(image_prompt.get_frame(0)).shape[0] 51 | audio_text_height = np.array(audio_prompt.get_frame(0)).shape[0] 52 | # Calculate total height for the video 53 | total_height = image1_height + image2_height + space_between_images + 2 * padding + image_text_height + audio_text_height + 2 * space_between_text_image 54 | 55 | # Create a white slider image without shadow 56 | slider_width = 5 # Increased width 57 | slider_height = image1_height + image2_height + space_between_images # Adjusted height 58 | slider_color = (240, 240, 240) # White 59 | 60 | # Create slider without shadow 61 | slider = np.zeros((slider_height, slider_width, 3), dtype=np.uint8) 62 | slider[:, :] = slider_color # Main slider area 63 | 64 | # import pdb; pdb.set_trace() 65 | # Load audio using soundfile 66 | audio_data, sample_rate = sf.read(audio_path) 67 | 68 | # Calculate video duration based on audio length 69 | video_duration = len(audio_data) / sample_rate 70 | 71 | # Define video dimensions 72 | video_width = max_image_width + 2 * padding 73 | 74 | # Function to generate frame at time t 75 | def make_frame(t): 76 | # Calculate slider position 77 | # import pdb; pdb.set_trace() 78 | slider_position = int(t * (max_image_width - slider_width//2) / video_duration) 79 | 80 | # Create a white blank frame 81 | frame = np.ones((total_height, video_width, 3), dtype=np.uint8) * 255 82 | 83 | # Calculate positions for image prompt and add it 84 | image_text = np.array(image_prompt.get_frame(t)) 85 | image_prompt_start_pos = (padding, (video_width - image_text.shape[1]) // 2 ) 86 | image_prompt_end_pos = (padding + image_text.shape[0], image_prompt_start_pos[1] + image_text.shape[1]) 87 | frame[image_prompt_start_pos[0]: image_prompt_end_pos[0], image_prompt_start_pos[1]: image_prompt_end_pos[1]] = image_text 88 | 89 | # Add images to the frame 90 | image1_start_pos = (image_prompt_end_pos[0] + space_between_text_image, padding) 91 | 
frame[image1_start_pos[0]: (image1_start_pos[0] + image1_height), image1_start_pos[1]:(image1_start_pos[1] + image1_width)] = image1 92 | 93 | image2_start_pos = (image1_start_pos[0] + image1_height + space_between_images, padding) 94 | frame[image2_start_pos[0]: (image2_start_pos[0] + image2_height), image2_start_pos[1]:(image2_start_pos[1] + image2_width)] = image2 95 | 96 | # Calculate positions for audio prompt and add it 97 | audio_text = np.array(audio_prompt.get_frame(t)) 98 | audio_prompt_start_pos = (image2_start_pos[0] + image2_height + space_between_text_image, (video_width - audio_text.shape[1]) // 2) 99 | audio_prompt_end_pos = (audio_prompt_start_pos[0] + audio_text.shape[0], audio_prompt_start_pos[1] + audio_text.shape[1]) 100 | frame[audio_prompt_start_pos[0]: audio_prompt_end_pos[0], audio_prompt_start_pos[1]: audio_prompt_end_pos[1]] = audio_text 101 | 102 | # Add slider to the frame 103 | frame[image1_start_pos[0]:(slider_height+image1_start_pos[0]), (padding+slider_position):(padding+slider_position+slider_width)] = slider 104 | 105 | return frame 106 | 107 | # Create a VideoClip 108 | video_clip = VideoClip(make_frame, duration=video_duration) 109 | 110 | # Write the final video 111 | temp_path = output_path[:-4] + '-temp.mp4' 112 | video_clip.write_videofile(temp_path, codec='libx264', fps=60, logger=None) 113 | 114 | # the reason we do this is that changing the audio codec degrades the audio quality a lot. 115 | # So we copy the original audio to ensure we have best audio quality 116 | os.system(f"ffmpeg -v quiet -y -i \"{temp_path}\" -i {audio_path} -c:v copy -c:a aac {output_path}") 117 | os.system(f"rm {temp_path}") 118 | 119 | 120 | def create_single_image_animation_with_text(image_path, audio_path, output_path, image_prompt, audio_prompt): 121 | # set up video hyperparameter 122 | padding = 40 123 | space_between_text_image = 8 124 | fontsize = 26 125 | font = 'Ubuntu-Mono' 126 | text_method = 'caption' 127 | 128 | # Load images 129 | image = np.array(Image.open(image_path).convert('RGB')) 130 | image_height, image_width, _ = image.shape 131 | 132 | max_image_width = image_width 133 | 134 | # Add text prompts 135 | text_size = [max_image_width, None] 136 | image_prompt = TextClip('Image prompt: ' + image_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size) 137 | audio_prompt = TextClip('Audio prompt: ' + audio_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size) 138 | image_text_height = np.array(image_prompt.get_frame(0)).shape[0] 139 | audio_text_height = np.array(audio_prompt.get_frame(0)).shape[0] 140 | 141 | # Calculate total height for the video 142 | total_height = image_height + 2 * padding + image_text_height + audio_text_height + 2 * space_between_text_image 143 | 144 | # Create a white slider image without shadow 145 | slider_width = 5 # Increased width 146 | slider_height = image_height # Adjusted height 147 | slider_color = (240, 240, 240) # White 148 | border_color = (200, 200, 200) # gray 149 | 150 | 151 | # Create slider without shadow 152 | slider = np.zeros((slider_height, slider_width, 3), dtype=np.uint8) 153 | slider[:, :] = slider_color # Main slider area 154 | 155 | # import pdb; pdb.set_trace() 156 | # Load audio using soundfile 157 | audio_data, sample_rate = sf.read(audio_path) 158 | 159 | # Calculate video duration based on audio length 160 | video_duration = len(audio_data) / sample_rate 161 | 162 | # Define video
dimensions 163 | video_width = max_image_width + 2 * padding 164 | 165 | # Function to generate frame at time t 166 | def make_frame(t): 167 | # Calculate slider position 168 | # import pdb; pdb.set_trace() 169 | slider_position = int(t * (max_image_width - slider_width // 2) / video_duration) 170 | 171 | # Create a white blank frame 172 | frame = np.ones((total_height, video_width, 3), dtype=np.uint8) * 255 173 | 174 | # Calculate positions for image prompt and add it 175 | image_text = np.array(image_prompt.get_frame(t)) 176 | image_prompt_start_pos = (padding, (video_width - image_text.shape[1]) // 2 ) 177 | image_prompt_end_pos = (padding + image_text.shape[0], image_prompt_start_pos[1] + image_text.shape[1]) 178 | frame[image_prompt_start_pos[0]: image_prompt_end_pos[0], image_prompt_start_pos[1]: image_prompt_end_pos[1]] = image_text 179 | 180 | # Add image to the frame 181 | image_start_pos = (image_prompt_end_pos[0] + space_between_text_image, padding) 182 | frame[image_start_pos[0]: (image_start_pos[0] + image_height), image_start_pos[1]:(image_start_pos[1] + image_width)] = image 183 | 184 | # Calculate positions for audio prompt and add it 185 | audio_text = np.array(audio_prompt.get_frame(t)) 186 | audio_prompt_start_pos = (image_start_pos[0] + image_height + space_between_text_image, (video_width - audio_text.shape[1]) // 2) 187 | audio_prompt_end_pos = (audio_prompt_start_pos[0] + audio_text.shape[0], audio_prompt_start_pos[1] + audio_text.shape[1]) 188 | frame[audio_prompt_start_pos[0]: audio_prompt_end_pos[0], audio_prompt_start_pos[1]: audio_prompt_end_pos[1]] = audio_text 189 | 190 | # Add slider to the frame 191 | frame[image_start_pos[0]:(slider_height+image_start_pos[0]), (padding+slider_position):(padding+slider_position+slider_width)] = slider 192 | 193 | return frame 194 | 195 | # Create a VideoClip 196 | video_clip = VideoClip(make_frame, duration=video_duration) 197 | 198 | # Write the final video 199 | temp_path = output_path[:-4] + '-temp.mp4' 200 | video_clip.write_videofile(temp_path, codec='libx264', fps=60, logger=None) 201 | 202 | # the reason we do this is that changing the audio codec degrades the audio quality a lot.
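# (In the ffmpeg call below, '-c:v copy' muxes the already-encoded H.264 stream untouched,
#  while the WAV audio is encoded once to AAC, since raw PCM audio inside an MP4 container
#  is not widely supported by players.)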
203 | # So we copy the original audio to ensure we have best audio quality 204 | # os.system(f"ffmpeg -v quiet -y -i \"{temp_path}\" -i {audio_path} -c:v copy -c:a copy {output_path}") 205 | os.system(f"ffmpeg -v quiet -y -i \"{temp_path}\" -i {audio_path} -c:v copy -c:a aac {output_path}") 206 | 207 | os.system(f"rm {temp_path}") 208 | 209 | 210 | # Example usage: 211 | if __name__ == '__main__': 212 | # rgb = '/home/czyang/Workspace/images-that-sound/logs/soundify/kitten/image_results/img_030000.png' 213 | # spec = '/home/czyang/Workspace/images-that-sound/logs/soundify/kitten/spec_results/spec_030000.png' 214 | # audio = '/home/czyang/Workspace/images-that-sound/logs/soundify/kitten/audio_results/audio_030000.wav' 215 | # image_prompt = 'an oil paint of playground with cats chasing and playing' 216 | # audio_prompt = 'A kitten mewing for attention' 217 | # audio = 'audio_050000.wav' 218 | 219 | # example_path = '/home/czyang/Workspace/images-that-sound/logs/soundify-denoise/good-examples/example_02' 220 | example_path = '/home/czyang/Workspace/images-that-sound/logs/soundify-denoise/colorization/tiger_example_06' 221 | 222 | rgb = f'{example_path}/img_rgb.png' 223 | spec = f'{example_path}/spec.png' 224 | audio = f'{example_path}/audio.wav' 225 | config_path = f'{example_path}/config.yaml' 226 | cfg = OmegaConf.load(config_path) 227 | # image_prompt = cfg.trainer.colorization_prompt 228 | image_prompt = cfg.trainer.image_prompt 229 | audio_prompt = cfg.trainer.audio_prompt 230 | # output_path = f'{example_path}/video_rgb.mp4' 231 | output_path = f'test.mp4' 232 | 233 | # create_animation_with_text(rgb, spec, audio, output_path, image_prompt, audio_prompt) 234 | create_single_image_animation_with_text(spec, audio, output_path, image_prompt, audio_prompt) 235 | -------------------------------------------------------------------------------- /src/guidance/auffusion.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.utils import save_image 7 | 8 | 9 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline 10 | from diffusers.utils.import_utils import is_xformers_available 11 | 12 | # from .perpneg_utils import weighted_perpendicular_aggregator 13 | 14 | from lightning import seed_everything 15 | 16 | import rootutils 17 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 18 | 19 | from src.models.components.auffusion_converter import Generator, denormalize_spectrogram 20 | 21 | 22 | 23 | class AuffusionGuidance(nn.Module): 24 | def __init__( 25 | self, 26 | repo_id='auffusion/auffusion-full-no-adapter', 27 | fp16=True, 28 | t_range=[0.02, 0.98], 29 | **kwargs 30 | ): 31 | super().__init__() 32 | 33 | self.repo_id = repo_id 34 | 35 | self.precision_t = torch.float16 if fp16 else torch.float32 36 | 37 | # Create model 38 | self.vae, self.tokenizer, self.text_encoder, self.unet = self.create_model_from_pipe(repo_id, self.precision_t) 39 | self.scheduler = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler", torch_dtype=self.precision_t) 40 | self.vocoder = Generator.from_pretrained(repo_id, subfolder="vocoder").to(dtype=self.precision_t) 41 | 42 | self.register_buffer('alphas_cumprod', self.scheduler.alphas_cumprod) 43 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps 44 | self.min_step = int(self.num_train_timesteps * 
t_range[0]) 45 | self.max_step = int(self.num_train_timesteps * t_range[1]) 46 | 47 | def create_model_from_pipe(self, repo_id, dtype): 48 | pipe = StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype) 49 | vae = pipe.vae 50 | tokenizer = pipe.tokenizer 51 | text_encoder = pipe.text_encoder 52 | unet = pipe.unet 53 | return vae, tokenizer, text_encoder, unet 54 | 55 | @torch.no_grad() 56 | def get_text_embeds(self, prompt, device): 57 | # prompt: [str] 58 | # import pdb; pdb.set_trace() 59 | inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') 60 | prompt_embeds = self.text_encoder(inputs.input_ids.to(device))[0] 61 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 62 | return prompt_embeds 63 | 64 | def train_step(self, text_embeddings, pred_spec, guidance_scale=100, as_latent=False, t=None, grad_scale=1, save_guidance_path:Path=None): 65 | # import pdb; pdb.set_trace() 66 | pred_spec = pred_spec.to(self.vae.dtype) 67 | 68 | if as_latent: 69 | latents = pred_spec 70 | else: 71 | if pred_spec.shape[1] != 3: 72 | pred_spec = pred_spec.repeat(1, 3, 1, 1) 73 | 74 | # encode image into latents with vae, requires grad! 75 | latents = self.encode_imgs(pred_spec) 76 | 77 | if t is None: 78 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level 79 | t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=latents.device) 80 | else: 81 | t = t.to(dtype=torch.long, device=latents.device) 82 | 83 | # predict the noise residual with unet, NO grad! 84 | with torch.no_grad(): 85 | # add noise 86 | noise = torch.randn_like(latents) 87 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 88 | # pred noise 89 | latent_model_input = torch.cat([latents_noisy] * 2) 90 | tt = torch.cat([t] * 2) 91 | noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample 92 | 93 | # perform guidance (high scale from paper!) 94 | noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) 95 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) 96 | 97 | # w(t), sigma_t^2 98 | w = (1 - self.alphas_cumprod[t]) 99 | grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) 100 | grad = torch.nan_to_num(grad) 101 | 102 | if save_guidance_path: 103 | with torch.no_grad(): 104 | if as_latent: 105 | pred_rgb_512 = self.decode_latents(latents) 106 | 107 | # visualize predicted denoised image 108 | # The following block of code is equivalent to `predict_start_from_noise`... 109 | # see zero123_utils.py's version for a simpler implementation. 110 | alphas = self.scheduler.alphas.to(latents.device) 111 | total_timesteps = self.max_step - self.min_step + 1 112 | index = total_timesteps - t.to(latents.device) - 1 113 | b = len(noise_pred) 114 | a_t = alphas[index].reshape(b,1,1,1).to(latents.device) 115 | sqrt_one_minus_alphas = torch.sqrt(1 - alphas) 116 | sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(latents.device) 117 | pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0 118 | result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t))) 119 | 120 | # visualize noisier image 121 | result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t)) 122 | 123 | # all 3 input images are [1, 3, H, W], e.g. 
[1, 3, 512, 512] 124 | viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0) 125 | save_image(viz_images, save_guidance_path) 126 | 127 | targets = (latents - grad).detach() 128 | loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] 129 | 130 | return loss 131 | 132 | 133 | @torch.no_grad() 134 | def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, generator=None): 135 | 136 | if latents is None: 137 | latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=self.unet.dtype).to(text_embeddings.device) 138 | 139 | self.scheduler.set_timesteps(num_inference_steps) 140 | 141 | for i, t in enumerate(self.scheduler.timesteps): 142 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 143 | latent_model_input = torch.cat([latents] * 2) 144 | # predict the noise residual 145 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] 146 | 147 | # perform guidance 148 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 149 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 150 | 151 | # compute the previous noisy sample x_t -> x_t-1 152 | latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] 153 | 154 | return latents 155 | 156 | def decode_latents(self, latents): 157 | latents = latents.to(self.vae.dtype) 158 | latents = 1 / self.vae.config.scaling_factor * latents 159 | 160 | imgs = self.vae.decode(latents).sample 161 | imgs = (imgs / 2 + 0.5).clamp(0, 1) 162 | 163 | return imgs 164 | 165 | def encode_imgs(self, imgs): 166 | # imgs: [B, 3, H, W] 167 | imgs = 2 * imgs - 1 168 | posterior = self.vae.encode(imgs).latent_dist 169 | latents = posterior.sample() * self.vae.config.scaling_factor 170 | 171 | return latents 172 | 173 | def spec_to_audio(self, spec): 174 | spec = spec.to(dtype=self.precision_t) 175 | denorm_spec = denormalize_spectrogram(spec) 176 | audio = self.vocoder.inference(denorm_spec) 177 | return audio 178 | 179 | def prompt_to_audio(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, device=None, generator=None): 180 | if isinstance(prompts, str): 181 | prompts = [prompts] 182 | 183 | if isinstance(negative_prompts, str): 184 | negative_prompts = [negative_prompts] 185 | 186 | # Prompts -> text embeds 187 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768] 188 | neg_embeds = self.get_text_embeds(negative_prompts, device) 189 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] 190 | 191 | # Text embeds -> img latents 192 | latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator) # [1, 4, 64, 64] 193 | 194 | # Img latents -> imgs 195 | imgs = self.decode_latents(latents) # [1, 3, 512, 512] 196 | spec = imgs[0] 197 | denorm_spec = denormalize_spectrogram(spec) 198 | audio = self.vocoder.inference(denorm_spec) 199 | # # Img to Numpy 200 | # imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() 201 | # imgs = (imgs * 255).round().astype('uint8') 202 | 203 | return audio 204 | 205 | def prompt_to_spec(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, 
guidance_scale=7.5, latents=None, device=None, generator=None): 206 | if isinstance(prompts, str): 207 | prompts = [prompts] 208 | 209 | if isinstance(negative_prompts, str): 210 | negative_prompts = [negative_prompts] 211 | 212 | # Prompts -> text embeds 213 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768] 214 | neg_embeds = self.get_text_embeds(negative_prompts, device) 215 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] 216 | 217 | # Text embeds -> img latents 218 | latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator) # [1, 4, 64, 64] 219 | 220 | # Img latents -> imgs 221 | imgs = self.decode_latents(latents) # [1, 3, 512, 512] 222 | return imgs 223 | 224 | if __name__ == '__main__': 225 | import numpy as np 226 | import argparse 227 | from PIL import Image 228 | import os 229 | import soundfile as sf 230 | 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument('--prompt', type=str, default='A kitten mewing for attention') 233 | parser.add_argument('--negative', default='', type=str) 234 | parser.add_argument('--repo_id', type=str, default='auffusion/auffusion-full-no-adapter', help="stable diffusion version") 235 | parser.add_argument('--fp16', action='store_true', help="use float16 for training") 236 | parser.add_argument('--H', type=int, default=256) 237 | parser.add_argument('--W', type=int, default=1024) 238 | parser.add_argument('--seed', type=int, default=0) 239 | parser.add_argument('--steps', type=int, default=100) 240 | parser.add_argument('--out_dir', type=str, default='logs/test') 241 | 242 | opt = parser.parse_args() 243 | 244 | seed_everything(opt.seed) 245 | 246 | device = torch.device('cuda') 247 | 248 | sd = AuffusionGuidance(repo_id=opt.repo_id, fp16=opt.fp16) 249 | sd = sd.to(device) 250 | 251 | audio = sd.prompt_to_audio(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, device=device) 252 | # import pdb; pdb.set_trace() 253 | # visualize audio 254 | save_folder = opt.out_dir 255 | os.makedirs(save_folder, exist_ok=True) 256 | save_path = os.path.join(save_folder, f'{opt.prompt}.wav') 257 | sf.write(save_path, np.ravel(audio), samplerate=16000) -------------------------------------------------------------------------------- /src/colorization/samplers.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision.transforms.functional as TF 8 | 9 | from diffusers.utils.torch_utils import randn_tensor 10 | 11 | @torch.no_grad() 12 | def sample_stage_1( 13 | model, 14 | prompt_embeds, 15 | negative_prompt_embeds, 16 | views, 17 | height=None, 18 | width=None, 19 | fixed_im=None, 20 | num_inference_steps=100, 21 | guidance_scale=7.0, 22 | reduction='mean', 23 | generator=None, 24 | num_recurse=1, 25 | start_diffusion_step=0 26 | ): 27 | 28 | # Params 29 | num_images_per_prompt = 1 30 | device = torch.device('cuda') # Sometimes model device is cpu??? 31 | height = model.unet.config.sample_size if height is None else height 32 | width = model.unet.config.sample_size if width is None else width 33 | batch_size = 1 # TODO: Support larger batch sizes, maybe 34 | num_prompts = prompt_embeds.shape[0] 35 | assert num_prompts == len(views), \ 36 | "Number of prompts must match number of views!" 
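# Sketch of the denoising loop below: at every timestep the shared noisy canvas is
# (1) rendered into each view with view_fn.view(), (2) passed through model.unet in a
# single batched call covering the negative and positive prompts of every view, and the
# per-view noise estimates are (3) mapped back with view_fn.inverse_view() before
# classifier-free guidance and (4) combined according to `reduction` ('mean' averages
# them, 'alternate' uses one view per timestep).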
37 | 38 | # Resize image to correct size 39 | if fixed_im is not None: 40 | fixed_im = TF.resize(fixed_im, height, antialias=False) 41 | 42 | # For CFG 43 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 44 | 45 | # Setup timesteps 46 | model.scheduler.set_timesteps(num_inference_steps, device=device) 47 | timesteps = model.scheduler.timesteps 48 | 49 | # Make intermediate_images 50 | noisy_images = model.prepare_intermediate_images( 51 | batch_size * num_images_per_prompt, 52 | model.unet.config.in_channels, 53 | height, 54 | width, 55 | prompt_embeds.dtype, 56 | device, 57 | generator, 58 | ) 59 | 60 | for i, t in enumerate(timesteps): 61 | for j in range(num_recurse): 62 | # Logic to keep a component fixed to reference image 63 | if fixed_im is not None: 64 | # Inject noise 65 | alpha_cumprod = model.scheduler.alphas_cumprod[t] 66 | im_noisy = torch.sqrt(alpha_cumprod) * fixed_im + torch.sqrt(1 - alpha_cumprod) * torch.randn_like(fixed_im) 67 | 68 | # Replace component in noisy images with component from fixed image 69 | im_noisy_component = views[0].inverse_view(im_noisy).to(noisy_images.device).to(noisy_images.dtype) 70 | noisy_images_component = views[1].inverse_view(noisy_images[0]) 71 | noisy_images = im_noisy_component + noisy_images_component 72 | 73 | # Correct for factor of 2 from view TODO: Fix this.... 74 | noisy_images = noisy_images[None] / 2. 75 | 76 | # "Reset" diffusion by replacing noisy images with noisy version 77 | # of grayscale image. All diffusion steps before this one are "thrown away" 78 | if i == start_diffusion_step: 79 | noisy_images = im_noisy.to(noisy_images.device).to(noisy_images.dtype)[None] 80 | 81 | # Apply views to noisy_image 82 | viewed_noisy_images = [] 83 | for view_fn in views: 84 | viewed_noisy_images.append(view_fn.view(noisy_images[0])) 85 | viewed_noisy_images = torch.stack(viewed_noisy_images) 86 | 87 | # Duplicate inputs for CFG 88 | # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... 
] 89 | model_input = torch.cat([viewed_noisy_images] * 2) 90 | model_input = model.scheduler.scale_model_input(model_input, t) 91 | 92 | # Predict noise estimate 93 | noise_pred = model.unet( 94 | model_input, 95 | t, 96 | encoder_hidden_states=prompt_embeds, 97 | cross_attention_kwargs=None, 98 | return_dict=False, 99 | )[0] 100 | 101 | # Extract uncond (neg) and cond noise estimates 102 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 103 | 104 | # Invert the unconditional (negative) estimates 105 | inverted_preds = [] 106 | for pred, view in zip(noise_pred_uncond, views): 107 | inverted_pred = view.inverse_view(pred) 108 | inverted_preds.append(inverted_pred) 109 | noise_pred_uncond = torch.stack(inverted_preds) 110 | 111 | # Invert the conditional estimates 112 | inverted_preds = [] 113 | for pred, view in zip(noise_pred_text, views): 114 | inverted_pred = view.inverse_view(pred) 115 | inverted_preds.append(inverted_pred) 116 | noise_pred_text = torch.stack(inverted_preds) 117 | 118 | # Split into noise estimate and variance estimates 119 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) 120 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) 121 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 122 | 123 | # Reduce predicted noise and variances 124 | noise_pred = noise_pred.view(-1,num_prompts,3,height,width) 125 | predicted_variance = predicted_variance.view(-1,num_prompts,3,height,width) 126 | if reduction == 'mean': 127 | noise_pred = noise_pred.mean(1) 128 | predicted_variance = predicted_variance.mean(1) 129 | elif reduction == 'alternate': 130 | noise_pred = noise_pred[:,i%num_prompts] 131 | predicted_variance = predicted_variance[:,i%num_prompts] 132 | else: 133 | raise ValueError('Reduction must be either `mean` or `alternate`') 134 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) 135 | 136 | # compute the previous noisy sample x_t -> x_t-1 137 | noisy_images = model.scheduler.step( 138 | noise_pred, t, noisy_images, generator=generator, return_dict=False 139 | )[0] 140 | 141 | if j != (num_recurse - 1): 142 | beta = model.scheduler.betas[t] 143 | noisy_images = (1 - beta).sqrt() * noisy_images + beta.sqrt() * torch.randn_like(noisy_images) 144 | 145 | # Return denoised images 146 | return noisy_images 147 | 148 | 149 | 150 | 151 | 152 | @torch.no_grad() 153 | def sample_stage_2( 154 | model, 155 | image, 156 | prompt_embeds, 157 | negative_prompt_embeds, 158 | views, 159 | height=None, 160 | width=None, 161 | fixed_im=None, 162 | num_inference_steps=100, 163 | guidance_scale=7.0, 164 | reduction='mean', 165 | noise_level=50, 166 | generator=None 167 | ): 168 | 169 | # Params 170 | batch_size = 1 # TODO: Support larger batch sizes, maybe 171 | num_prompts = prompt_embeds.shape[0] 172 | height = model.unet.config.sample_size if height is None else height 173 | width = model.unet.config.sample_size if width is None else width 174 | device = image.device 175 | num_images_per_prompt = 1 176 | 177 | # Resize fixed image to correct size 178 | if fixed_im is not None: 179 | fixed_im = TF.resize(fixed_im, height, antialias=False) 180 | 181 | # For CFG 182 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 183 | 184 | # Get timesteps 185 | model.scheduler.set_timesteps(num_inference_steps, device=device) 186 | timesteps = model.scheduler.timesteps 187 | 188 | num_channels = model.unet.config.in_channels // 2 189 | noisy_images = 
model.prepare_intermediate_images( 190 | batch_size * num_images_per_prompt, 191 | num_channels, 192 | height, 193 | width, 194 | prompt_embeds.dtype, 195 | device, 196 | generator, 197 | ) 198 | 199 | # Prepare upscaled image and noise level 200 | image = model.preprocess_image(image, num_images_per_prompt, device) 201 | upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) 202 | 203 | noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) 204 | noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) 205 | upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) 206 | 207 | # Condition on noise level, for each model input 208 | noise_level = torch.cat([noise_level] * num_prompts * 2) 209 | 210 | # Denoising Loop 211 | for i, t in enumerate(timesteps): 212 | # Logic to keep a component fixed to reference image 213 | if fixed_im is not None: 214 | # Inject noise 215 | alpha_cumprod = model.scheduler.alphas_cumprod[t] 216 | im_noisy = torch.sqrt(alpha_cumprod) * fixed_im + torch.sqrt(1 - alpha_cumprod) * torch.randn_like(fixed_im) 217 | 218 | # Replace component in noisy images with component from fixed image 219 | im_noisy_component = views[0].inverse_view(im_noisy).to(noisy_images.device).to(noisy_images.dtype) 220 | noisy_images_component = views[1].inverse_view(noisy_images[0]) 221 | noisy_images = im_noisy_component + noisy_images_component 222 | 223 | # Correct for factor of 2 TODO: Fix this.... 224 | noisy_images = noisy_images[None] / 2. 225 | 226 | # Cat noisy image with upscaled conditioning image 227 | model_input = torch.cat([noisy_images, upscaled], dim=1) 228 | 229 | # Apply views to noisy_image 230 | viewed_inputs = [] 231 | for view_fn in views: 232 | viewed_inputs.append(view_fn.view(model_input[0])) 233 | viewed_inputs = torch.stack(viewed_inputs) 234 | 235 | # Duplicate inputs for CFG 236 | # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ...
] 237 | model_input = torch.cat([viewed_inputs] * 2) 238 | model_input = model.scheduler.scale_model_input(model_input, t) 239 | 240 | # predict the noise residual 241 | noise_pred = model.unet( 242 | model_input, 243 | t, 244 | encoder_hidden_states=prompt_embeds, 245 | class_labels=noise_level, 246 | cross_attention_kwargs=None, 247 | return_dict=False, 248 | )[0] 249 | 250 | # Extract uncond (neg) and cond noise estimates 251 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 252 | 253 | # Invert the unconditional (negative) estimates 254 | # TODO: pretty sure you can combine these into one loop 255 | inverted_preds = [] 256 | for pred, view in zip(noise_pred_uncond, views): 257 | inverted_pred = view.inverse_view(pred) 258 | inverted_preds.append(inverted_pred) 259 | noise_pred_uncond = torch.stack(inverted_preds) 260 | 261 | # Invert the conditional estimates 262 | inverted_preds = [] 263 | for pred, view in zip(noise_pred_text, views): 264 | inverted_pred = view.inverse_view(pred) 265 | inverted_preds.append(inverted_pred) 266 | noise_pred_text = torch.stack(inverted_preds) 267 | 268 | # Split predicted noise and predicted variances 269 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) 270 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) 271 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 272 | 273 | # Combine noise estimates (and variance estimates) 274 | noise_pred = noise_pred.view(-1,num_prompts,3,height,width) 275 | predicted_variance = predicted_variance.view(-1,num_prompts,3,height,width) 276 | if reduction == 'mean': 277 | noise_pred = noise_pred.mean(1) 278 | predicted_variance = predicted_variance.mean(1) 279 | elif reduction == 'alternate': 280 | noise_pred = noise_pred[:,i%num_prompts] 281 | predicted_variance = predicted_variance[:,i%num_prompts] 282 | 283 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) 284 | 285 | # compute the previous noisy sample x_t -> x_t-1 286 | noisy_images = model.scheduler.step( 287 | noise_pred, t, noisy_images, generator=generator, return_dict=False 288 | )[0] 289 | 290 | # Return denoised images 291 | return noisy_images -------------------------------------------------------------------------------- /src/models/components/auffusion_converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import math 5 | import os 6 | import random 7 | import torch 8 | import json 9 | import torch.utils.data 10 | import numpy as np 11 | import librosa 12 | from librosa.util import normalize 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | import torch.nn.functional as F 17 | import torch.nn as nn 18 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 19 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 20 | from huggingface_hub import snapshot_download 21 | 22 | MAX_WAV_VALUE = 32768.0 23 | 24 | 25 | def load_wav(full_path): 26 | sampling_rate, data = read(full_path) 27 | return data, sampling_rate 28 | 29 | 30 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 31 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 32 | 33 | 34 | def dynamic_range_decompression(x, C=1): 35 | return np.exp(x) / C 36 | 37 | 38 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 39 | return torch.log(torch.clamp(x, 
min=clip_val) * C) 40 | 41 | 42 | def dynamic_range_decompression_torch(x, C=1): 43 | return torch.exp(x) / C 44 | 45 | 46 | def spectral_normalize_torch(magnitudes): 47 | output = dynamic_range_compression_torch(magnitudes) 48 | return output 49 | 50 | 51 | def spectral_de_normalize_torch(magnitudes): 52 | output = dynamic_range_decompression_torch(magnitudes) 53 | return output 54 | 55 | 56 | mel_basis = {} 57 | hann_window = {} 58 | 59 | 60 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 61 | if torch.min(y) < -1.: 62 | print('min value is ', torch.min(y)) 63 | if torch.max(y) > 1.: 64 | print('max value is ', torch.max(y)) 65 | 66 | global mel_basis, hann_window 67 | if fmax not in mel_basis: 68 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 69 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 70 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 71 | 72 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 73 | y = y.squeeze(1) 74 | 75 | # complex tensor as default, then use view_as_real for future pytorch compatibility 76 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 77 | spec = torch.view_as_real(spec) 78 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 79 | 80 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 81 | spec = spectral_normalize_torch(spec) 82 | 83 | return spec 84 | 85 | 86 | def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 87 | if torch.min(y) < -1.: 88 | print('min value is ', torch.min(y)) 89 | if torch.max(y) > 1.: 90 | print('max value is ', torch.max(y)) 91 | 92 | global hann_window 93 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 94 | 95 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 96 | y = y.squeeze(1) 97 | 98 | # complex tensor as default, then use view_as_real for future pytorch compatibility 99 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 100 | spec = torch.view_as_real(spec) 101 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 102 | 103 | return spec 104 | 105 | 106 | def normalize_spectrogram( 107 | spectrogram: torch.Tensor, 108 | max_value: float = 200, 109 | min_value: float = 1e-5, 110 | power: float = 1., 111 | inverse: bool = False, 112 | flip: bool = True 113 | ) -> torch.Tensor: 114 | 115 | # Rescale to 0-1 116 | max_value = np.log(max_value) # 5.298317366548036 117 | min_value = np.log(min_value) # -11.512925464970229 118 | 119 | assert spectrogram.max() <= max_value and spectrogram.min() >= min_value 120 | 121 | data = (spectrogram - min_value) / (max_value - min_value) 122 | 123 | # Invert 124 | if inverse: 125 | data = 1 - data 126 | 127 | # Apply the power curve 128 | data = torch.pow(data, power) 129 | 130 | # 1D -> 3D 131 | data = data.repeat(3, 1, 1) 132 | 133 | # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner 134 | if flip: 135 | data = torch.flip(data, [1]) 136 | 137 | return data 138 | 139 | 140 | 141 | def 
denormalize_spectrogram( 142 | data: torch.Tensor, 143 | max_value: float = 200, 144 | min_value: float = 1e-5, 145 | power: float = 1, 146 | inverse: bool = False, 147 | ) -> torch.Tensor: 148 | 149 | max_value = np.log(max_value) 150 | min_value = np.log(min_value) 151 | 152 | # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner 153 | data = torch.flip(data, [1]) 154 | 155 | assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape)) 156 | 157 | # if data.shape[0] == 1: 158 | # data = data.repeat(3, 1, 1) 159 | 160 | # assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0]) 161 | 162 | # data = data[0] 163 | data = data.mean(dim=0) 164 | 165 | # Reverse the power curve 166 | data = torch.pow(data, 1 / power) 167 | 168 | # Invert 169 | if inverse: 170 | data = 1 - data 171 | 172 | # Rescale to max value 173 | spectrogram = data * (max_value - min_value) + min_value 174 | 175 | return spectrogram 176 | 177 | 178 | def get_mel_spectrogram_from_audio(audio, device="cuda"): 179 | audio = audio / MAX_WAV_VALUE 180 | audio = librosa.util.normalize(audio) * 0.95 181 | 182 | audio = torch.FloatTensor(audio) 183 | audio = audio.unsqueeze(0) 184 | 185 | waveform = audio.to(device) 186 | spec = mel_spectrogram(waveform, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False) 187 | return audio, spec 188 | 189 | 190 | 191 | LRELU_SLOPE = 0.1 192 | MAX_WAV_VALUE = 32768.0 193 | 194 | 195 | class AttrDict(dict): 196 | def __init__(self, *args, **kwargs): 197 | super(AttrDict, self).__init__(*args, **kwargs) 198 | self.__dict__ = self 199 | 200 | 201 | def get_config(config_path): 202 | config = json.loads(open(config_path).read()) 203 | config = AttrDict(config) 204 | return config 205 | 206 | def init_weights(m, mean=0.0, std=0.01): 207 | classname = m.__class__.__name__ 208 | if classname.find("Conv") != -1: 209 | m.weight.data.normal_(mean, std) 210 | 211 | 212 | def apply_weight_norm(m): 213 | classname = m.__class__.__name__ 214 | if classname.find("Conv") != -1: 215 | weight_norm(m) 216 | 217 | 218 | def get_padding(kernel_size, dilation=1): 219 | return int((kernel_size*dilation - dilation)/2) 220 | 221 | 222 | class ResBlock1(torch.nn.Module): 223 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 224 | super(ResBlock1, self).__init__() 225 | self.h = h 226 | self.convs1 = nn.ModuleList([ 227 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), 228 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), 229 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))) 230 | ]) 231 | self.convs1.apply(init_weights) 232 | 233 | self.convs2 = nn.ModuleList([ 234 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), 235 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), 236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))) 237 | ]) 238 | self.convs2.apply(init_weights) 239 | 240 | def forward(self, x): 241 | for c1, c2 in zip(self.convs1, self.convs2): 242 | xt = F.leaky_relu(x, LRELU_SLOPE) 243 | xt = c1(xt) 244 | xt = F.leaky_relu(xt, LRELU_SLOPE) 245 | xt = 
c2(xt) 246 | x = xt + x 247 | return x 248 | 249 | def remove_weight_norm(self): 250 | for l in self.convs1: 251 | remove_weight_norm(l) 252 | for l in self.convs2: 253 | remove_weight_norm(l) 254 | 255 | 256 | class ResBlock2(torch.nn.Module): 257 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 258 | super(ResBlock2, self).__init__() 259 | self.h = h 260 | self.convs = nn.ModuleList([ 261 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), 262 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))) 263 | ]) 264 | self.convs.apply(init_weights) 265 | 266 | def forward(self, x): 267 | for c in self.convs: 268 | xt = F.leaky_relu(x, LRELU_SLOPE) 269 | xt = c(xt) 270 | x = xt + x 271 | return x 272 | 273 | def remove_weight_norm(self): 274 | for l in self.convs: 275 | remove_weight_norm(l) 276 | 277 | 278 | 279 | class Generator(torch.nn.Module): 280 | def __init__(self, h): 281 | super(Generator, self).__init__() 282 | self.h = h 283 | self.num_kernels = len(h.resblock_kernel_sizes) 284 | self.num_upsamples = len(h.upsample_rates) 285 | self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) # change: 80 --> 512 286 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 287 | 288 | self.ups = nn.ModuleList() 289 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 290 | if (k-u) % 2 == 0: 291 | self.ups.append(weight_norm( 292 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 293 | k, u, padding=(k-u)//2))) 294 | else: 295 | self.ups.append(weight_norm( 296 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 297 | k, u, padding=(k-u)//2+1, output_padding=1))) 298 | 299 | # self.ups.append(weight_norm( 300 | # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 301 | # k, u, padding=(k-u)//2))) 302 | 303 | 304 | self.resblocks = nn.ModuleList() 305 | for i in range(len(self.ups)): 306 | ch = h.upsample_initial_channel//(2**(i+1)) 307 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 308 | self.resblocks.append(resblock(h, ch, k, d)) 309 | 310 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 311 | self.ups.apply(init_weights) 312 | self.conv_post.apply(init_weights) 313 | 314 | def forward(self, x): 315 | x = self.conv_pre(x) 316 | for i in range(self.num_upsamples): 317 | x = F.leaky_relu(x, LRELU_SLOPE) 318 | x = self.ups[i](x) 319 | xs = None 320 | for j in range(self.num_kernels): 321 | if xs is None: 322 | xs = self.resblocks[i*self.num_kernels+j](x) 323 | else: 324 | xs += self.resblocks[i*self.num_kernels+j](x) 325 | x = xs / self.num_kernels 326 | x = F.leaky_relu(x) 327 | x = self.conv_post(x) 328 | x = torch.tanh(x) 329 | 330 | return x 331 | 332 | def remove_weight_norm(self): 333 | for l in self.ups: 334 | remove_weight_norm(l) 335 | for l in self.resblocks: 336 | l.remove_weight_norm() 337 | remove_weight_norm(self.conv_pre) 338 | remove_weight_norm(self.conv_post) 339 | 340 | @classmethod 341 | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None): 342 | if not os.path.isdir(pretrained_model_name_or_path): 343 | pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path) 344 | 345 | if subfolder is not None: 346 | 
pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder) 347 | config_path = os.path.join(pretrained_model_name_or_path, "config.json") 348 | ckpt_path = os.path.join(pretrained_model_name_or_path, "vocoder.pt") 349 | 350 | config = get_config(config_path) 351 | vocoder = cls(config) 352 | 353 | state_dict_g = torch.load(ckpt_path) 354 | vocoder.load_state_dict(state_dict_g["generator"]) 355 | vocoder.eval() 356 | vocoder.remove_weight_norm() 357 | return vocoder 358 | 359 | 360 | @torch.no_grad() 361 | def inference(self, mels, lengths=None): 362 | self.eval() 363 | with torch.no_grad(): 364 | wavs = self(mels).squeeze(1) 365 | 366 | # wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype("int16") 367 | wavs = (wavs.cpu().numpy()).astype("float32") # changed from int16 to float32 so the vocoder returns float waveforms 368 | 369 | if lengths is not None: 370 | wavs = wavs[:, :lengths] 371 | 372 | return wavs --------------------------------------------------------------------------------
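Note: the converter above pairs denormalize_spectrogram with the HiFi-GAN-style Generator vocoder to turn generated spectrogram images back into audio. The following is a minimal round-trip sketch, not code from this repository: the checkpoint id "auffusion/auffusion-full", the "vocoder" subfolder, the 256-mel / 16 kHz settings, the random tensor standing in for a generated spectrogram, and the soundfile dependency are all assumptions.

import torch
import soundfile as sf  # assumption: any writer that accepts float32 PCM works here

# assumes the project root is on PYTHONPATH
from src.models.components.auffusion_converter import Generator, denormalize_spectrogram

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for a generated 3 x 256 x T RGB spectrogram image scaled to [0, 1]
spec_image = torch.rand(3, 256, 1024, device=device)

# Undo the [0, 1] scaling and vertical flip to recover a natural-log mel spectrogram
log_mel = denormalize_spectrogram(spec_image)  # shape: (256, T)

# Load the vocoder (checkpoint id and subfolder are hypothetical) and synthesize audio
vocoder = Generator.from_pretrained("auffusion/auffusion-full", subfolder="vocoder").to(device)
wav = vocoder.inference(log_mel.unsqueeze(0))  # float32 numpy array, shape (1, num_samples)

sf.write("sample.wav", wav[0], samplerate=16000)  # 16 kHz matches get_mel_spectrogram_from_audio

Since inference already returns float32 samples in roughly [-1, 1], no int16 rescaling is needed before writing the file.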