├── src
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   └── components
│   │       ├── __init__.py
│   │       ├── learnable_image.py
│   │       └── auffusion_converter.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── rich_utils.py
│   │   ├── re_ranking.py
│   │   ├── pylogger.py
│   │   ├── consistency_check.py
│   │   └── animation_with_text.py
│   ├── colorization
│   │   ├── __init__.py
│   │   ├── views.py
│   │   ├── create_color_video.py
│   │   ├── colorizer.py
│   │   └── samplers.py
│   ├── evaluator
│   │   ├── __init__.py
│   │   ├── clap.py
│   │   ├── clip.py
│   │   └── eval.py
│   ├── guidance
│   │   ├── __init__.py
│   │   ├── deepfloyd.py
│   │   ├── stable_diffusion.py
│   │   └── auffusion.py
│   ├── transformation
│   │   ├── __init__.py
│   │   ├── identity.py
│   │   ├── block_rearrange.py
│   │   ├── random_crop.py
│   │   └── img_to_spec.py
│   ├── main_imprint.py
│   ├── main_sds.py
│   └── main_denoise.py
├── huggingface_login.py
├── assets
│   └── teaser.jpg
├── .project-root
├── configs
│   ├── main_denoise
│   │   ├── debug
│   │   │   └── default.yaml
│   │   ├── hydra
│   │   │   └── default.yaml
│   │   ├── experiment
│   │   │   └── examples
│   │   │       ├── dog.yaml
│   │   │       ├── tiger.yaml
│   │   │       ├── train.yaml
│   │   │       ├── bell.yaml
│   │   │       ├── corgi.yaml
│   │   │       ├── pond.yaml
│   │   │       ├── garden-v2.yaml
│   │   │       ├── horse.yaml
│   │   │       ├── kitten.yaml
│   │   │       ├── race.yaml
│   │   │       └── garden.yaml
│   │   └── main.yaml
│   ├── main_imprint
│   │   ├── debug
│   │   │   └── default.yaml
│   │   ├── hydra
│   │   │   └── default.yaml
│   │   ├── experiment
│   │   │   └── examples
│   │   │       ├── dog.yaml
│   │   │       ├── bell.yaml
│   │   │       ├── tiger.yaml
│   │   │       ├── train.yaml
│   │   │       ├── kitten.yaml
│   │   │       ├── garden.yaml
│   │   │       └── race.yaml
│   │   └── main.yaml
│   └── main_sds
│       ├── debug
│       │   └── default.yaml
│       ├── hydra
│       │   └── default.yaml
│       ├── experiment
│       │   └── examples
│       │       ├── dog.yaml
│       │       ├── bell.yaml
│       │       ├── kitten.yaml
│       │       └── garden.yaml
│       └── main.yaml
├── LICENSE
├── environment.yml
├── .gitignore
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/colorization/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/evaluator/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/guidance/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/transformation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/huggingface_login.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import login
2 | login()
--------------------------------------------------------------------------------
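
Running this script opens an interactive prompt for a Hugging Face access token, which is needed to download gated checkpoints such as DeepFloyd IF. As a minimal non-interactive sketch (the HF_TOKEN environment variable is an assumed name, not something this repo defines), `login` also accepts the token directly:

    import os
    from huggingface_hub import login

    # Non-interactive variant: read the token from an environment variable
    # instead of typing it at the prompt.
    login(token=os.environ["HF_TOKEN"])
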
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IFICL/images-that-sound/HEAD/assets/teaser.jpg
--------------------------------------------------------------------------------
/.project-root:
--------------------------------------------------------------------------------
1 | # this file is required for inferring the project root directory
2 | # do not delete
3 |
--------------------------------------------------------------------------------
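
For reference, this is the indicator file that `rootutils` looks for when resolving the repository root; the snippet below mirrors the call already used in `src/utils/re_ranking.py`:

    import rootutils

    # Walks up from the current file until it finds ".project-root", then registers
    # that directory as the project root and adds it to PYTHONPATH.
    rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
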
/configs/main_denoise/debug/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default debugging setup: verbose DEBUG logging, warnings not suppressed
4 | # other debugging configs can inherit from this one
5 |
6 | # overwrite task name so debugging logs are stored in separate folder
7 | task_name: "debug"
8 |
9 | hydra:
10 | job_logging:
11 | root:
12 | level: DEBUG
13 |
14 | extras:
15 | ignore_warnings: false
16 |
17 | trainer:
18 | num_inference_steps: 100
--------------------------------------------------------------------------------
/configs/main_imprint/debug/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default debugging setup: verbose DEBUG logging, warnings not suppressed
4 | # other debugging configs can inherit from this one
5 |
6 | # overwrite task name so debugging logs are stored in separate folder
7 | task_name: "debug"
8 |
9 | hydra:
10 | job_logging:
11 | root:
12 | level: DEBUG
13 |
14 | extras:
15 | ignore_warnings: false
16 |
17 | trainer:
18 | num_inference_steps: 100
--------------------------------------------------------------------------------
/configs/main_sds/debug/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default debugging setup: a short run (10 iterations) with verbose DEBUG logging
4 | # other debugging configs can inherit from this one
5 |
6 | # overwrite task name so debugging logs are stored in separate folder
7 | task_name: "debug"
8 |
9 | hydra:
10 | job_logging:
11 | root:
12 | level: DEBUG
13 |
14 | extras:
15 | ignore_warnings: false
16 |
17 | trainer:
18 | num_iteration: 10
19 | save_step: 1
20 | visualize_step: 1
--------------------------------------------------------------------------------
/configs/main_sds/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # https://hydra.cc/docs/configure_hydra/intro/
2 |
3 | # enable color logging
4 | defaults:
5 | - override hydra_logging: colorlog
6 | - override job_logging: colorlog
7 |
8 | # output directory, generated dynamically on each run
9 | run:
10 | dir: ${output_dir}
11 |
12 | job_logging:
13 | handlers:
14 | file:
15 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
16 | filename: ${output_dir}/main.log
--------------------------------------------------------------------------------
/src/transformation/identity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class NaiveIdentity(nn.Module):
7 | def __init__(
8 | self,
9 | **kwargs
10 | ):
11 |         '''An identity transformation, used as a no-op default in our pipeline.
12 | '''
13 | super().__init__()
14 |
15 | def forward(self, x, **kwargs):
16 | '''
17 | Input: x
18 | Output: x
19 | '''
20 | return x
21 |
--------------------------------------------------------------------------------
/configs/main_denoise/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # https://hydra.cc/docs/configure_hydra/intro/
2 |
3 | # enable color logging
4 | defaults:
5 | - override hydra_logging: colorlog
6 | - override job_logging: colorlog
7 |
8 | # output directory, generated dynamically on each run
9 | run:
10 | dir: ${output_dir}
11 |
12 | job_logging:
13 | handlers:
14 | file:
15 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
16 | filename: ${output_dir}/main.log
--------------------------------------------------------------------------------
/configs/main_imprint/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # https://hydra.cc/docs/configure_hydra/intro/
2 |
3 | # enable color logging
4 | defaults:
5 | - override hydra_logging: colorlog
6 | - override job_logging: colorlog
7 |
8 | # output directory, generated dynamically on each run
9 | run:
10 | dir: ${output_dir}
11 |
12 | job_logging:
13 | handlers:
14 | file:
15 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
16 | filename: ${output_dir}/main.log
--------------------------------------------------------------------------------
/src/colorization/views.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ColorLView():
5 | def __init__(self):
6 | pass
7 |
8 | def view(self, im):
9 | return im
10 |
11 | def inverse_view(self, noise):
12 | # Get L color by averaging color channels
13 | noise[:3] = 2 * torch.stack([noise[:3].mean(0)] * 3)
14 |
15 | return noise
16 |
17 |
18 | class ColorABView():
19 | def __init__(self):
20 | pass
21 |
22 | def view(self, im):
23 | return im
24 |
25 | def inverse_view(self, noise):
26 | # Get AB color by taking residual
27 | noise[:3] = 2 * (noise[:3] - torch.stack([noise[:3].mean(0)] * 3))
28 |
29 | return noise
30 |
31 |
--------------------------------------------------------------------------------
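
A quick self-check of how the two views fit together (a sketch; it assumes the L and AB estimates are recombined by simple averaging, which is what the 2x scaling above implies). Note that `inverse_view` modifies its input in place, so each view gets its own clone:

    import torch
    from src.colorization.views import ColorLView, ColorABView

    noise = torch.randn(4, 64, 64)  # e.g. a 4-channel latent; only the first 3 channels are touched
    l_part = ColorLView().inverse_view(noise.clone())    # 2 * per-pixel channel mean, broadcast to 3 channels
    ab_part = ColorABView().inverse_view(noise.clone())  # 2 * (noise - channel mean), i.e. the chroma residual

    # Averaging the two components recovers the original color channels:
    # 0.5 * (2*mean + 2*(noise - mean)) == noise
    assert torch.allclose(0.5 * (l_part[:3] + ab_part[:3]), noise[:3], atol=1e-5)
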
/src/transformation/block_rearrange.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from kornia.augmentation import RandomCrop
6 |
7 |
8 | class BlockRearranger(nn.Module):
9 | def __init__(
10 | self,
11 | **kwargs
12 | ):
13 |         '''We implement a simple block-rearrangement operation.
14 | '''
15 | super().__init__()
16 |
17 | def forward(self, latents, inverse=False):
18 | '''
19 |         Input: (B, C, H, W)
20 |         Output: (B, C, H//2, 2W) if not inverse, else (B, C, 2H, W//2)
21 | '''
22 | B, C, H, W = latents.shape
23 | if not inverse: # convert square to rectangle
24 | transformed_latents = torch.cat([latents[:, :, :H//2, :], latents[:, :, H//2:, :]], dim=-1)
25 | else: # convert rectangle to square
26 | transformed_latents = torch.cat([latents[:, :, :, :W//2 ], latents[:, :, :, W//2:]], dim=-2)
27 | return transformed_latents
28 |
--------------------------------------------------------------------------------
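
A shape round-trip sketch of the rearrangement above: the forward call folds a tall latent into a wide one by placing the bottom half next to the top half, and `inverse=True` undoes it exactly:

    import torch
    from src.transformation.block_rearrange import BlockRearranger

    rearranger = BlockRearranger()
    latents = torch.randn(1, 4, 128, 64)    # (B, C, H, W)

    wide = rearranger(latents)               # (1, 4, 64, 128): top and bottom halves placed side by side
    back = rearranger(wide, inverse=True)    # (1, 4, 128, 64): left and right halves stacked vertically again

    assert back.shape == latents.shape
    assert torch.equal(back, latents)        # the rearrangement is exactly invertible
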
/configs/main_denoise/experiment/examples/dog.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/dog
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/selected/dog'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of cute dogs, grayscale'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'dog barking'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_imprint/experiment/examples/dog.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_imprint.py experiment=examples/dog
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-imprint/examples/dog'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 40
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.1
25 |
26 | # image guidance
27 | image_prompt: 'a painting of cute dogs, grayscale'
28 | image_guidance_scale: 7.5
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'dog barking'
33 | audio_guidance_scale: 7.5
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/tiger.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/tiger
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/tiger'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of tigers, grayscale'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'tiger growling'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/train
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/train'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of trains, grayscale'
28 | image_guidance_scale: 10
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'train whistling'
33 | audio_guidance_scale: 10
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/bell.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/bell
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/bell'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of castle towers, grayscale'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'bell ringing'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_imprint/experiment/examples/bell.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_imprint.py experiment=examples/bell
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-imprint/examples/bell'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 40
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 |
23 | enable_clip_rank: false
24 | # enable_rank: True
25 | top_ranks: 0.1
26 |
27 | # image guidance
28 | image_prompt: 'a painting of castle towers, grayscale'
29 | image_guidance_scale: 7.5
30 | image_start_step: 10
31 |
32 | # audio guidance
33 | audio_prompt: 'bell ringing'
34 | audio_guidance_scale: 7.5
35 | audio_start_step: 0
36 | audio_weight: 0.5
37 |
38 | latent_transformation:
39 | _target_: src.transformation.identity.NaiveIdentity
40 |
--------------------------------------------------------------------------------
/configs/main_imprint/experiment/examples/tiger.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_imprint.py experiment=examples/tiger
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-imprint/examples/tiger-long'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 40
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.1
25 |
26 | # image guidance
27 | image_prompt: 'a painting of tigers, grayscale'
28 | image_guidance_scale: 7.5
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'tiger growling'
33 | audio_guidance_scale: 7.5
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_imprint/experiment/examples/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_imprint.py experiment=examples/train
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-imprint/examples/train-long'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 40
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 |
19 | cutoff_latent: false
20 | crop_image: true
21 | use_colormap: true
22 |
23 | enable_clip_rank: false
24 | # enable_rank: True
25 | top_ranks: 0.1
26 |
27 | # image guidance
28 | image_prompt: 'a painting of trains, grayscale'
29 | image_guidance_scale: 7.5
30 | image_start_step: 10
31 |
32 | # audio guidance
33 | audio_prompt: 'train whistling'
34 | audio_guidance_scale: 7.5
35 | audio_start_step: 0
36 | audio_weight: 0.5
37 |
38 | latent_transformation:
39 | _target_: src.transformation.identity.NaiveIdentity
40 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/corgi.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/corgi
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/corgi'
9 |
10 | seed: 2034
11 |
12 | trainer:
13 | num_samples: 50
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of corgis, grayscale, black background'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'dog barking'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/pond.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/pond
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/pond'
9 |
10 | seed: 2024
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a pond full of water lilies, grayscale, lithograph style'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'frog croaking'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/garden-v2.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/garden-v2
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/selected/garden'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of a blooming garden, grayscale'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'birds singing sweetly'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/horse.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/horse
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/horse'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 1000
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of horse heads, grayscale, black background'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'horse neighing'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/kitten.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/kitten
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/kitten'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of furry kittens, grayscale'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'a kitten meowing for attention'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/race.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/race
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/examples/race'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of auto racing game, grayscale'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: "a race car passing by and disappearing"
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_imprint/experiment/examples/kitten.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_imprint.py experiment=examples/kitten
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-imprint/examples/kitten-long'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 40
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 |
19 | cutoff_latent: false
20 | crop_image: true
21 | use_colormap: true
22 |
23 | enable_clip_rank: false
24 | # enable_rank: True
25 | top_ranks: 0.1
26 |
27 | # image guidance
28 | image_prompt: 'a painting of furry kittens, grayscale'
29 | image_guidance_scale: 7.5
30 | image_start_step: 10
31 |
32 | # audio guidance
33 | audio_prompt: 'a kitten meowing for attention'
34 | audio_guidance_scale: 7.5
35 | audio_start_step: 0
36 | audio_weight: 0.5
37 |
38 | latent_transformation:
39 | _target_: src.transformation.identity.NaiveIdentity
40 |
--------------------------------------------------------------------------------
/configs/main_denoise/experiment/examples/garden.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_denoise.py experiment=examples/garden
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-denoise/selected/garden'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 100
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 | cutoff_latent: false
19 | crop_image: true
20 | use_colormap: true
21 |
22 | enable_clip_rank: false
23 | # enable_rank: True
24 | top_ranks: 0.2
25 |
26 | # image guidance
27 | image_prompt: 'a painting of a blooming garden with many birds, grayscale'
28 | image_guidance_scale: 10.0
29 | image_start_step: 10
30 |
31 | # audio guidance
32 | audio_prompt: 'birds singing sweetly'
33 | audio_guidance_scale: 10.0
34 | audio_start_step: 0
35 | audio_weight: 0.5
36 |
37 | latent_transformation:
38 | _target_: src.transformation.identity.NaiveIdentity
39 |
--------------------------------------------------------------------------------
/configs/main_imprint/experiment/examples/garden.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_imprint.py experiment=examples/garden
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-imprint/examples/garden'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 40
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 |
19 | cutoff_latent: false
20 | crop_image: true
21 | use_colormap: true
22 |
23 | enable_clip_rank: false
24 | # enable_rank: True
25 | top_ranks: 0.1
26 |
27 | # image guidance
28 | image_prompt: 'a painting of a blooming garden with many birds, grayscale'
29 | image_guidance_scale: 7.5
30 | image_start_step: 10
31 |
32 | # audio guidance
33 | audio_prompt: 'birds singing sweetly'
34 | audio_guidance_scale: 7.5
35 | audio_start_step: 0
36 | audio_weight: 0.5
37 |
38 | latent_transformation:
39 | _target_: src.transformation.identity.NaiveIdentity
40 |
--------------------------------------------------------------------------------
/configs/main_imprint/experiment/examples/race.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_imprint.py experiment=examples/race
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-imprint/examples/racing-long'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_samples: 40
14 | num_inference_steps: 100
15 | img_height: 256
16 | img_width: 1024
17 |
18 |
19 | cutoff_latent: false
20 | crop_image: true
21 | use_colormap: true
22 |
23 | enable_clip_rank: false
24 | # enable_rank: True
25 | top_ranks: 0.1
26 |
27 | # image guidance
28 | image_prompt: 'a painting of auto racing game, grayscale'
29 | image_guidance_scale: 7.5
30 | image_start_step: 10
31 |
32 | # audio guidance
33 | audio_prompt: "a race car passing by and disappearing"
34 | audio_guidance_scale: 7.5
35 | audio_start_step: 0
36 | audio_weight: 0.5
37 |
38 | latent_transformation:
39 | _target_: src.transformation.identity.NaiveIdentity
40 |
--------------------------------------------------------------------------------
/src/transformation/random_crop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from kornia.augmentation import RandomCrop
6 |
7 |
8 | class ImageRandomCropper(nn.Module):
9 | def __init__(
10 | self,
11 | size,
12 | n_view=1,
13 | padding=None,
14 | cropping_mode="slice",
15 | p=1.0,
16 | **kwargs
17 | ):
18 |         '''We implement a simple random-cropping operation.
19 | '''
20 | super().__init__()
21 |
22 | self.transformation = RandomCrop(
23 | size=size,
24 | padding=padding,
25 | p=p,
26 | cropping_mode=cropping_mode
27 | )
28 | self.n_view = n_view
29 |
30 | def forward(self, x):
31 | '''
32 | Input: (1, C, H, W)
33 | Output: (n_view, C, h, w)
34 | '''
35 | if x.shape[0] == 1:
36 | x = x.repeat(self.n_view, 1, 1, 1)
37 |
38 | if x.shape[1] == 1:
39 | x = x.repeat(1, 3, 1, 1)
40 |
41 | x = self.transformation(x)
42 | return x
43 |
--------------------------------------------------------------------------------
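
A small shape sketch of the cropper as it is configured in configs/main_sds/main.yaml (`size: [256, 256]`, `n_view` set to the batch size): a single grayscale canvas is repeated, promoted to 3 channels, and cropped into a batch of views:

    import torch
    from src.transformation.random_crop import ImageRandomCropper

    cropper = ImageRandomCropper(size=(256, 256), n_view=8)
    canvas = torch.rand(1, 1, 256, 1024)  # one single-channel image, e.g. the learnable spectrogram canvas

    views = cropper(canvas)               # repeated to 8 views, 3 channels, then randomly cropped
    print(views.shape)                    # torch.Size([8, 3, 256, 256])
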
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Ziyang Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/main_sds/experiment/examples/dog.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_sds.py experiment=examples/dog
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-sds/examples/dog'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_iteration: 40000
14 | batch_size: 8
15 | save_step: 50000
16 | visualize_step: 10000
17 | accumulate_grad_batches: 1
18 |
19 | use_colormap: true
20 | crop_image: true
21 |
22 | # image guidance
23 | image_prompt: 'a painting of cute dogs, grayscale'
24 | image_start_step: 5000
25 | image_guidance_scale: 80
26 | image_weight: 0.4
27 |
28 | # audio guidance
29 | audio_prompt: 'dog barking'
30 | audio_guidance_scale: 10
31 | audio_weight: 1
32 |
33 | image_learner:
34 | _target_: src.models.components.learnable_image.LearnableImageFourier
35 | height: 256
36 | width: 1024
37 | num_channels: 1
38 |
39 | audio_transformation:
40 | _target_: src.transformation.img_to_spec.ImageToSpec
41 | inverse: false
42 | flip: false
43 | rgb2gray: mean
44 |
45 | image_diffusion_guidance:
46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance
47 | repo_id: DeepFloyd/IF-I-M-v1.0
48 | fp16: true
49 | t_consistent: true
50 | t_range: [0.02, 0.98]
--------------------------------------------------------------------------------
/configs/main_sds/experiment/examples/bell.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_sds.py experiment=examples/bell
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-sds/examples/bell'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_iteration: 40000
14 | batch_size: 8
15 | save_step: 50000
16 | visualize_step: 10000
17 | accumulate_grad_batches: 1
18 |
19 | use_colormap: true
20 | crop_image: true
21 |
22 | # image guidance
23 | image_prompt: 'a painting of a castle towers, grayscale'
24 | image_start_step: 5000
25 | image_guidance_scale: 80
26 | image_weight: 0.4
27 |
28 | # audio guidance
29 | audio_prompt: 'Bell ringing'
30 | audio_guidance_scale: 10
31 | audio_weight: 1
32 |
33 | image_learner:
34 | _target_: src.models.components.learnable_image.LearnableImageFourier
35 | height: 256
36 | width: 1024
37 | num_channels: 1
38 |
39 | audio_transformation:
40 | _target_: src.transformation.img_to_spec.ImageToSpec
41 | inverse: false
42 | flip: false
43 | rgb2gray: mean
44 |
45 | image_diffusion_guidance:
46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance
47 | repo_id: DeepFloyd/IF-I-M-v1.0
48 | fp16: true
49 | t_consistent: true
50 | t_range: [0.02, 0.98]
--------------------------------------------------------------------------------
/configs/main_sds/experiment/examples/kitten.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_sds.py experiment=examples/kitten
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-sds/examples/kitten'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_iteration: 40000
14 | batch_size: 8
15 | save_step: 50000
16 | visualize_step: 10000
17 | accumulate_grad_batches: 1
18 |
19 | use_colormap: true
20 | crop_image: true
21 |
22 | # image guidance
23 | image_prompt: 'a painting of furry kittens, grayscale'
24 | image_start_step: 5000
25 | image_guidance_scale: 80
26 | image_weight: 0.4
27 |
28 | # audio guidance
29 | audio_prompt: 'a kitten meowing for attention'
30 | audio_guidance_scale: 10
31 | audio_weight: 1
32 |
33 | image_learner:
34 | _target_: src.models.components.learnable_image.LearnableImageFourier
35 | height: 256
36 | width: 1024
37 | num_channels: 1
38 |
39 | audio_transformation:
40 | _target_: src.transformation.img_to_spec.ImageToSpec
41 | inverse: false
42 | flip: false
43 | rgb2gray: mean
44 |
45 | image_diffusion_guidance:
46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance
47 | repo_id: DeepFloyd/IF-I-M-v1.0
48 | fp16: true
49 | t_consistent: true
50 | t_range: [0.02, 0.98]
--------------------------------------------------------------------------------
/configs/main_sds/experiment/examples/garden.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/main_sds.py experiment=examples/garden
5 |
6 | # this allows you to overwrite only specified parameters
7 |
8 | task_name: 'soundify-sds/examples/garden'
9 |
10 | seed: 1234
11 |
12 | trainer:
13 | num_iteration: 40000
14 | batch_size: 8
15 | save_step: 50000
16 | visualize_step: 10000
17 | accumulate_grad_batches: 1
18 |
19 | use_colormap: true
20 | crop_image: true
21 |
22 | # image guidance
23 | image_prompt: 'a painting of a blooming garden with many birds, grayscale'
24 | image_start_step: 5000
25 | image_guidance_scale: 80
26 | image_weight: 0.4
27 |
28 | # audio guidance
29 | audio_prompt: 'birds singing sweetly'
30 | audio_guidance_scale: 10
31 | audio_weight: 1
32 |
33 | image_learner:
34 | _target_: src.models.components.learnable_image.LearnableImageFourier
35 | height: 256
36 | width: 1024
37 | num_channels: 1
38 |
39 | audio_transformation:
40 | _target_: src.transformation.img_to_spec.ImageToSpec
41 | inverse: false
42 | flip: false
43 | rgb2gray: mean
44 |
45 | image_diffusion_guidance:
46 | _target_: src.guidance.deepfloyd.DeepfloydGuidance
47 | repo_id: DeepFloyd/IF-I-M-v1.0
48 | fp16: true
49 | t_consistent: true
50 | t_range: [0.02, 0.98]
--------------------------------------------------------------------------------
/configs/main_imprint/main.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - hydra: default
4 | # experiment configs allow for version control of specific hyperparameters
5 | # e.g. best hyperparameters for given model and datamodule
6 | - experiment: null
7 | - debug: null
8 |
9 | seed: 2024
10 | log_dir: 'logs'
11 | task_name: 'soundify-imprint'
12 | output_dir: ${log_dir}/${task_name}
13 |
14 | trainer:
15 | num_samples: 5
16 | num_inference_steps: 100
17 | img_height: 256
18 | img_width: 1024
19 |
20 | mag_ratio: 0.5
21 | inverse_image: true
22 | crop_image: false
23 | use_colormap: true
24 |
25 | enable_rank: False
26 | enable_clip_rank: False
27 | top_ranks: 0.2
28 |
29 | # image guidance
30 | image_prompt: 'a castle with bell towers, grayscale, lithograph style'
31 | image_neg_prompt: ''
32 | image_guidance_scale: 7.5
33 | image_start_step: 0
34 |
35 | # audio guidance
36 | audio_prompt: 'bell ringing'
37 | audio_neg_prompt: ''
38 | audio_guidance_scale: 7.5
39 | audio_start_step: 0
40 | audio_weight: 0.5
41 |
42 | audio_diffusion_guidance:
43 | _target_: src.guidance.auffusion.AuffusionGuidance
44 | repo_id: auffusion/auffusion-full-no-adapter
45 | fp16: True
46 | t_range: [0.02, 0.98]
47 |
48 | image_diffusion_guidance:
49 | _target_: src.guidance.stable_diffusion.StableDiffusionGuidance
50 | repo_id: runwayml/stable-diffusion-v1-5
51 | fp16: True
52 | t_consistent: True
53 | t_range: [0.02, 0.98]
54 |
55 |
56 | latent_transformation:
57 | _target_: src.transformation.identity.NaiveIdentity
58 |
59 | audio_evaluator:
60 | _target_: src.evaluator.clap.CLAPEvaluator
61 |
62 | visual_evaluator:
63 | _target_: src.evaluator.clip.CLIPEvaluator
64 |
65 | extras:
66 | ignore_warnings: true
67 | print_config: true
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/src/transformation/img_to_spec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from kornia.augmentation import RandomCrop
6 |
7 |
8 | class ImageToSpec(nn.Module):
9 | def __init__(
10 | self,
11 | inverse=False,
12 | flip=False,
13 | rgb2gray='mean',
14 | **kwargs
15 | ):
16 |         '''We implement a simple image-to-spectrogram transformation.
17 | '''
18 | super().__init__()
19 |
20 | self.inverse = inverse
21 | self.flip = flip
22 | self.rgb2gray = rgb2gray
23 |
24 | if self.rgb2gray == 'mean':
25 | self.coefficients = torch.ones(3).float() / 3.0
26 | elif self.rgb2gray in ['NTSC', 'ntsc']:
27 | self.coefficients = torch.tensor([0.299, 0.587, 0.114])
28 | elif self.rgb2gray == 'luminance':
29 | self.coefficients = torch.tensor([0.2126, 0.7152, 0.0722])
30 | elif self.rgb2gray == 'r_channel':
31 | self.coefficients = torch.tensor([1.0, 0.0, 0.0])
32 | elif self.rgb2gray == 'g_channel':
33 | self.coefficients = torch.tensor([0.0, 1.0, 0.0])
34 | elif self.rgb2gray == 'b_channel':
35 | self.coefficients = torch.tensor([0.0, 0.0, 1.0])
36 |
37 | def forward(self, x):
38 | '''
39 | Input: (1, C, H, W)
40 | Output: (1, 1, H, W)
41 | '''
42 |
43 | if x.shape[1] == 1:
44 | x = x.repeat(1, 3, 1, 1)
45 |
46 | coefficients = self.coefficients.view(1, -1, 1, 1).to(dtype=x.dtype, device=x.device)
47 | x = torch.sum(x * coefficients, dim=1, keepdim=True)
48 |
49 | if self.inverse:
50 | x = 1.0 - x
51 |
52 | if self.flip:
53 | x = torch.flip(x, [2])
54 |
55 | return x
56 |
--------------------------------------------------------------------------------
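
A usage sketch with the `rgb2gray: mean` setting used in the configs: the RGB channels are collapsed to a single channel that is then treated as the spectrogram image.

    import torch
    from src.transformation.img_to_spec import ImageToSpec

    to_spec = ImageToSpec(inverse=False, flip=False, rgb2gray='mean')
    image = torch.rand(1, 3, 256, 1024)   # (1, C, H, W) image in [0, 1]

    spec = to_spec(image)                 # (1, 1, 256, 1024): channel-wise mean of RGB
    assert torch.allclose(spec, image.mean(dim=1, keepdim=True), atol=1e-6)
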
/configs/main_sds/main.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - hydra: default
4 | # experiment configs allow for version control of specific hyperparameters
5 | # e.g. best hyperparameters for given model and datamodule
6 | - experiment: null
7 | - debug: null
8 |
9 | seed: 2024
10 | log_dir: 'logs'
11 | output_dir: ${log_dir}/${task_name}
12 | task_name: 'soundify-sds'
13 |
14 | trainer:
15 | num_iteration: 40000
16 | batch_size: 8
17 | save_step: 50000
18 | visualize_step: 10000
19 | accumulate_grad_batches: 1
20 |
21 | use_colormap: false
22 | crop_image: false
23 |
24 | # image guidance
25 | image_prompt: 'a painting of castle towers, grayscale'
26 | image_start_step: 5000
27 | image_guidance_scale: 80
28 | image_weight: 0.4
29 |
30 | # audio guidance
31 | audio_prompt: 'bell ringing'
32 | audio_guidance_scale: 10
33 | audio_weight: 1
34 |
35 |
36 | image_learner:
37 | _target_: src.models.components.learnable_image.LearnableImageFourier
38 | height: 256
39 | width: 1024
40 | num_channels: 3
41 |
42 |
43 | audio_diffusion_guidance:
44 | _target_: src.guidance.auffusion.AuffusionGuidance
45 | repo_id: auffusion/auffusion-full-no-adapter
46 | fp16: True
47 | t_range: [0.02, 0.98]
48 |
49 | audio_transformation:
50 | _target_: src.transformation.img_to_spec.ImageToSpec
51 | inverse: false
52 | flip: false
53 | rgb2gray: mean
54 |
55 |
56 | image_diffusion_guidance:
57 | _target_: src.guidance.deepfloyd.DeepfloydGuidance
58 | repo_id: DeepFloyd/IF-I-M-v1.0
59 | fp16: true
60 | t_consistent: true
61 | t_range: [0.02, 0.98]
62 |
63 | image_transformation:
64 | _target_: src.transformation.random_crop.ImageRandomCropper
65 | size: [256, 256]
66 | n_view: ${trainer.batch_size}
67 |
68 |
69 | optimizer:
70 | _target_: torch.optim.AdamW
71 | _partial_: true
72 | lr: 0.0001
73 | weight_decay: 0.001
74 |
75 | extras:
76 | ignore_warnings: true
77 | print_config: true
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
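
A minimal sketch of how Hydra materializes the `_target_` blocks in this config, including the partially-applied optimizer. It is not the repo's actual entry point (that code is not shown here), and it assumes `image_learner`, `audio_transformation`, and `optimizer` are top-level config groups and that the script lives in src/:

    import hydra
    from hydra.utils import instantiate
    from omegaconf import DictConfig


    @hydra.main(version_base=None, config_path="../configs/main_sds", config_name="main")
    def main(cfg: DictConfig) -> None:
        image_learner = instantiate(cfg.image_learner)                # LearnableImageFourier(height=256, width=1024, ...)
        audio_transformation = instantiate(cfg.audio_transformation)  # ImageToSpec(rgb2gray='mean', ...)
        optimizer_factory = instantiate(cfg.optimizer)                # _partial_: true -> functools.partial(AdamW, lr=1e-4, ...)
        optimizer = optimizer_factory(params=image_learner.parameters())


    if __name__ == "__main__":
        main()
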
/configs/main_denoise/main.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - hydra: default
4 | # experiment configs allow for version control of specific hyperparameters
5 | # e.g. best hyperparameters for given model and datamodule
6 | - experiment: null
7 | - debug: null
8 |
9 | seed: 2024
10 | log_dir: 'logs'
11 | task_name: 'soundify-denoise'
12 | output_dir: ${log_dir}/${task_name}
13 |
14 | trainer:
15 | num_samples: 50
16 | num_inference_steps: 100
17 | img_height: 256
18 | img_width: 1024
19 | cutoff_latent: false
20 | crop_image: false
21 | use_colormap: true
22 |
23 | enable_rank: False
24 | enable_clip_rank: False
25 | top_ranks: 0.2
26 |
27 | # image guidance
28 | image_prompt: 'a castle with bell towers, grayscale, lithograph style'
29 | image_neg_prompt: ''
30 | image_guidance_scale: 10.0
31 | image_start_step: 10
32 |
33 | # audio guidance
34 | audio_prompt: 'bell ringing'
35 | audio_neg_prompt: ''
36 | audio_guidance_scale: 10.0
37 | audio_start_step: 0
38 | audio_weight: 0.5
39 |
40 | audio_diffusion_guidance:
41 | _target_: src.guidance.auffusion.AuffusionGuidance
42 | repo_id: auffusion/auffusion-full-no-adapter
43 | fp16: True
44 | t_range: [0.02, 0.98]
45 |
46 | image_diffusion_guidance:
47 | _target_: src.guidance.stable_diffusion.StableDiffusionGuidance
48 | repo_id: runwayml/stable-diffusion-v1-5
49 | fp16: True
50 | t_consistent: True
51 | t_range: [0.02, 0.98]
52 |
53 | diffusion_scheduler:
54 | _target_: diffusers.DDIMScheduler.from_pretrained
55 | pretrained_model_name_or_path: runwayml/stable-diffusion-v1-5
56 | subfolder: "scheduler"
57 | torch_dtype: torch.float16
58 |
59 | latent_transformation:
60 | _target_: src.transformation.identity.NaiveIdentity
61 |
62 | audio_evaluator:
63 | _target_: src.evaluator.clap.CLAPEvaluator
64 |
65 | visual_evaluator:
66 | _target_: src.evaluator.clip.CLIPEvaluator
67 |
68 | extras:
69 | ignore_warnings: true
70 | print_config: true
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/src/evaluator/clap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import soundfile as sf
4 | import librosa
5 | from lightning import seed_everything
6 | from transformers import AutoProcessor, ClapModel
7 |
8 | class CLAPEvaluator(nn.Module):
9 | def __init__(
10 | self,
11 | repo_id="laion/clap-htsat-unfused",
12 | **kwargs
13 | ):
14 | super(CLAPEvaluator, self).__init__()
15 | self.repo_id = repo_id
16 |
17 | # create model and load pretrained weights from huggingface
18 | self.model = ClapModel.from_pretrained(self.repo_id)
19 | self.processor = AutoProcessor.from_pretrained(self.repo_id)
20 |
21 | def forward(self, text, audio, sampling_rate=16000):
22 | return self.calc_score(text, audio, sampling_rate=sampling_rate)
23 |
24 | def calc_score(self, text, audio, sampling_rate=16000):
25 | if sampling_rate != 48000:
26 | audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=48000)
27 | audio = librosa.util.normalize(audio)
28 | inputs = self.processor(text=text, audios=audio, return_tensors="pt", sampling_rate=48000, padding=True).to(self.model.device)
29 | outputs = self.model(**inputs)
30 | logits_per_audio = outputs.logits_per_audio # this is the audio-text similarity score
31 | score = logits_per_audio / self.model.logit_scale_a.exp()
32 | score = score.squeeze().item()
33 | score = max(score, 0.0)
34 | return score
35 |
36 |
37 | if __name__ == "__main__":
38 | # import pdb; pdb.set_trace()
39 | seed_everything(2024, workers=True)
40 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
41 |
42 | text = 'Birds singing sweetly'
43 | audio_path = '/home/czyang/Workspace/images-that-sound/data/audios/audio_02.wav'
44 | # import pdb; pdb.set_trace()
45 | audio, sr = sf.read(audio_path)
46 |
47 | clap = CLAPEvaluator().to(device)
48 | score = clap.calc_score(text, audio, sampling_rate=sr)
49 | print(score)
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # reasons you might want to use `environment.yml` instead of `requirements.txt`:
2 | # - pip installs packages in a loop, without ensuring dependencies across all packages
3 | # are fulfilled simultaneously, but conda achieves proper dependency control across
4 | # all packages
5 | # - conda allows for installing packages without requiring certain compilers or
6 | # libraries to be available in the system, since it installs precompiled binaries
7 |
8 | name: soundify
9 |
10 | channels:
11 | - pytorch
12 | - conda-forge
13 | - defaults
14 | - huggingface
15 | - nvidia
16 |
17 | # it is strongly recommended to specify versions of packages installed through conda
18 | # to avoid situations where version-unspecified packages install their latest major
19 | # versions which can sometimes break things
20 |
21 | # current approach below keeps the dependencies in the same major versions across all
22 | # users, but allows for different minor and patch versions of packages where backwards
23 | # compatibility is usually guaranteed
24 |
25 | dependencies:
26 | - python=3.10
27 | - pytorch=2.1.2
28 | - torchvision=0.16
29 | - torchaudio=2.1.2
30 | - pytorch-cuda=12.1
31 | - lightning=2.2.0
32 | - torchmetrics=1.*
33 | - hydra-core=1.3
34 | - rich=13.*
35 | # - pre-commit=3.*
36 | # - pytest=7.*
37 | - diffusers=0.25.*
38 | - transformers=4.36.*
39 | - pysoundfile=0.12.1
40 | - kornia=0.7.0
41 | - tensorboard=2.15.1
42 | - accelerate
43 | - librosa
44 |
45 | - matplotlib # must be listed here, otherwise conda cannot resolve the dependency conflict; it slows environment solving considerably, so remove it if you don't need it
46 | - moviepy
47 | - imagemagick
48 |
49 |
50 | # --------- loggers --------- #
51 | # - wandb
52 | # - neptune-client
53 | # - mlflow
54 | # - comet-ml
55 | # - aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
56 |
57 | - pip>=23
58 | - pip:
59 | - hydra-optuna-sweeper
60 | - hydra-colorlog
61 | - rootutils
62 | - gpustat
63 | - nvitop
64 | - sentencepiece==0.2.0
65 | - bitsandbytes==0.43.1
66 |
--------------------------------------------------------------------------------
/src/evaluator/clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms.functional as TF
4 |
5 | from PIL import Image
6 |
7 | from torchmetrics.multimodal.clip_score import CLIPScore
8 | from torchmetrics.functional.multimodal import clip_score
9 | from functools import partial
10 |
11 |
12 | class CLIPEvaluator(nn.Module):
13 | def __init__(
14 | self,
15 | repo_id='openai/clip-vit-base-patch16',
16 | **kwargs
17 | ):
18 | super(CLIPEvaluator, self).__init__()
19 | self.repo_id = repo_id
20 |
21 | # create model and load pretrained weights from huggingface
22 | # self.clip = partial(clip_score, model_name_or_path=repo_id)
23 | self.clip = CLIPScore(model_name_or_path=repo_id)
24 | self.clip.reset()
25 |
26 | def forward(self, image, text):
27 | image, text = self.processing(image, text)
28 | image_int = (image * 255).to(torch.uint8)
29 | score = self.clip(image_int, text)
30 | score = score.item() / 100
31 | return score
32 |
33 | def processing(self, image, text):
34 | bsz, C, H, W = image.shape
35 | # import pdb; pdb.set_trace()
36 | if H != W:
37 | image = image.unfold(dimension=3, size=H, step=H//8) # cut image into several HxH images
38 | image = image.permute(0, 3, 1, 2, 4).squeeze(0)
39 | text = [text] * image.shape[0]
40 | return image, text
41 |
42 | if __name__ == "__main__":
43 | # import pdb; pdb.set_trace()
44 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
45 | # Load image
46 | # im_path = "logs/soundify-denoise/colorization/bell_example_29/img.png"
47 | # text = "a castle with bell towers, grayscale, lithograph style"
48 | im_path = "/home/czyang/Workspace/images-that-sound/logs/soundify-denoise/colorization/bell_example_29/img.png"
49 | text = "a castle with bell towers, grayscale, lithograph style"
50 |
51 | im = Image.open(im_path)
52 | im = TF.to_tensor(im).to(device)
53 | im = im.unsqueeze(0)
54 | # import pdb; pdb.set_trace()
55 | clip = CLIPEvaluator().to(device)
56 | score = clip(im, text)
57 | print(score)
58 |
59 | score = clip(im, text)
60 | print(score)
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/src/utils/rich_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Sequence
3 | import os, sys
4 |
5 | import rich
6 | import rich.syntax
7 | import rich.tree
8 | from hydra.core.hydra_config import HydraConfig
9 | from pytorch_lightning.utilities.rank_zero import rank_zero_only
10 | from omegaconf import DictConfig, OmegaConf, open_dict
11 | from rich.prompt import Prompt
12 |
13 |
14 | from src.utils import pylogger
15 | log = pylogger.RankedLogger(__name__, rank_zero_only=True)
16 |
17 | @rank_zero_only
18 | def print_config_tree(
19 | cfg: DictConfig,
20 | print_order: Sequence[str] = (
21 | "extras",
22 | ),
23 | resolve: bool = False,
24 | save_to_file: bool = False,
25 | ) -> None:
26 | """Prints the contents of a DictConfig as a tree structure using the Rich library.
27 |
28 | :param cfg: A DictConfig composed by Hydra.
29 |     :param print_order: Determines in what order config components are printed. Default is
30 |         ``("extras",)``.
31 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
32 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
33 | """
34 | style = "dim"
35 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
36 |
37 | queue = []
38 |
39 | # add fields from `print_order` to queue
40 | for field in print_order:
41 | queue.append(field) if field in cfg else log.warning(
42 | f"Field '{field}' not found in config. Skipping '{field}' config printing..."
43 | )
44 |
45 | # add all the other fields to queue (not specified in `print_order`)
46 | for field in cfg:
47 | if field not in queue:
48 | queue.append(field)
49 |
50 | # generate config tree from queue
51 | for field in queue:
52 | branch = tree.add(field, style=style, guide_style=style)
53 |
54 | config_group = cfg[field]
55 | if isinstance(config_group, DictConfig):
56 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
57 | else:
58 | branch_content = str(config_group)
59 |
60 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
61 |
62 | # print config tree
63 | rich.print(tree)
64 |
--------------------------------------------------------------------------------
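
A small usage sketch (the config below is a made-up stand-in, not one of the repo's configs): `extras` is printed first, then the remaining top-level groups.

    from omegaconf import OmegaConf
    from src.utils.rich_utils import print_config_tree

    cfg = OmegaConf.create({
        "extras": {"ignore_warnings": True, "print_config": True},
        "trainer": {"num_samples": 5, "num_inference_steps": 100},
    })
    print_config_tree(cfg, resolve=True)
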
/src/utils/re_ranking.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import shutil
5 |
6 | import rootutils
7 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
8 | from src.utils.pylogger import RankedLogger
9 | log = RankedLogger(__name__, rank_zero_only=True)
10 |
11 |
12 | def select_top_k_ranking(cfg, clip_scores, clap_scores):
13 | top_ranks = cfg.trainer.get("top_ranks", 0.2)
14 | clap_scores = np.array(clap_scores)
15 | clip_scores = np.array(clip_scores)
16 | top_ranks = int(clip_scores.shape[0] * top_ranks)
17 |
18 | # Get top-k indices for each array
19 | top_k_clap_indices = np.argsort(clap_scores)[::-1][:top_ranks]
20 | top_k_clip_indices = np.argsort(clip_scores)[::-1][:top_ranks]
21 |
22 | # Find the joint indices of top-k indices
23 | joint_indices = set(top_k_clap_indices).intersection(set(top_k_clip_indices))
24 | joint_indices = list(joint_indices)
25 |
26 | ori_dir = os.path.join(cfg.output_dir, 'results')
27 | examples = glob.glob(f'{ori_dir}/*')
28 | examples.sort()
29 | selected_dir = os.path.join(cfg.output_dir, 'results_selected')
30 | os.makedirs(selected_dir, exist_ok=True)
31 |
32 | log.info(f"Selected {len(joint_indices)} examples.")
33 | for ind in joint_indices:
34 | example_dir_path = examples[ind]
35 | dir_name = example_dir_path.split('/')[-1]
36 | save_dir_path = os.path.join(selected_dir, dir_name)
37 | shutil.copytree(example_dir_path, save_dir_path)
38 | return
39 |
40 | def select_top_k_clip_ranking(cfg, clip_scores):
41 | # import pdb; pdb.set_trace()
42 | top_ranks = cfg.trainer.get("top_ranks", 0.1)
43 | clip_scores = np.array(clip_scores)
44 | top_ranks = int(clip_scores.shape[0] * top_ranks)
45 |
46 | # Get top-k indices
47 | top_k_clip_indices = np.argsort(clip_scores)[::-1][:top_ranks]
48 |
49 | ori_dir = os.path.join(cfg.output_dir, 'results')
50 | examples = glob.glob(f'{ori_dir}/*')
51 | examples.sort()
52 | selected_dir = os.path.join(cfg.output_dir, 'results_selected')
53 | os.makedirs(selected_dir, exist_ok=True)
54 |
55 | log.info(f"Selected {len(top_k_clip_indices)} examples.")
56 | for ind in top_k_clip_indices:
57 | example_dir_path = examples[ind]
58 | dir_name = example_dir_path.split('/')[-1]
59 | save_dir_path = os.path.join(selected_dir, dir_name)
60 | shutil.copytree(example_dir_path, save_dir_path)
61 | return
62 |
--------------------------------------------------------------------------------
/src/utils/pylogger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Mapping, Optional
3 |
4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
5 |
6 |
7 | class RankedLogger(logging.LoggerAdapter):
8 | """A multi-GPU-friendly python command line logger."""
9 |
10 | def __init__(
11 | self,
12 | name: str = __name__,
13 | rank_zero_only: bool = False,
14 | extra: Optional[Mapping[str, object]] = None,
15 | ) -> None:
16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes
17 | with their rank prefixed in the log message.
18 |
19 | :param name: The name of the logger. Default is ``__name__``.
20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
22 | """
23 | logger = logging.getLogger(name)
24 | logging.getLogger('PIL').setLevel(logging.WARNING)
25 | super().__init__(logger=logger, extra=extra)
26 | self.rank_zero_only = rank_zero_only
27 |
28 | def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
29 | """Delegate a log call to the underlying logger, after prefixing its message with the rank
30 | of the process it's being logged from. If `'rank'` is provided, then the log will only
31 | occur on that rank/process.
32 |
33 | :param level: The level to log at. Look at `logging.__init__.py` for more information.
34 | :param msg: The message to log.
35 | :param rank: The rank to log at.
36 | :param args: Additional args to pass to the underlying logging function.
37 | :param kwargs: Any additional keyword args to pass to the underlying logging function.
38 | """
39 | if self.isEnabledFor(level):
40 | msg, kwargs = self.process(msg, kwargs)
41 | current_rank = getattr(rank_zero_only, "rank", None)
42 | if current_rank is None:
43 | raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
44 | msg = rank_prefixed_message(msg, current_rank)
45 | if self.rank_zero_only:
46 | if current_rank == 0:
47 | self.logger.log(level, msg, *args, **kwargs)
48 | else:
49 | if rank is None:
50 | self.logger.log(level, msg, *args, **kwargs)
51 | elif current_rank == rank:
52 | self.logger.log(level, msg, *args, **kwargs)
53 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | ### VisualStudioCode
131 | .vscode/*
132 | !.vscode/settings.json
133 | !.vscode/tasks.json
134 | !.vscode/launch.json
135 | !.vscode/extensions.json
136 | *.code-workspace
137 | **/.vscode
138 |
139 | # JetBrains
140 | .idea/
141 |
142 | # Data & Models
143 | *.h5
144 | *.tar
145 | *.tar.gz
146 |
147 | # Lightning-Hydra-Template
148 | configs/local/default.yaml
149 | # /data/
150 | /logs/
151 | .env
152 |
153 | # Aim logging
154 | .aim
155 |
156 | /data
157 | /data/*
158 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Images that Sound
2 |
3 | [Ziyang Chen](https://ificl.github.io/), [Daniel Geng](https://dangeng.github.io/), [Andrew Owens](https://andrewowens.com/)
4 |
5 | University of Michigan, Ann Arbor
6 |
7 | arXiv 2024
8 |
9 | [[Paper](https://arxiv.org/abs/2405.12221)] [[Project Page](https://ificl.github.io/images-that-sound/)]
10 |
11 |
12 | This repository contains the code to generate *images that sound*: spectrograms that can simultaneously be viewed as images and played as sound.
13 |
14 |

15 |
16 |
17 |
18 | ## Environment
19 | To set up the environment, simply run:
20 | ```bash
21 | conda env create -f environment.yml
22 | conda activate soundify
23 | ```
24 | ***Pro tip***: *we highly recommend using [mamba](https://github.com/conda-forge/miniforge) instead of conda for much faster environment solving and installation.*
25 |
26 | **DeepFloyd**: Our repo also uses [DeepFloyd IF](https://huggingface.co/docs/diffusers/api/pipelines/deepfloyd_if). To use DeepFloyd IF, you must accept its usage conditions. To do so:
27 |
28 | 1. Sign up or log in to [Hugging Face account](https://huggingface.co/join).
29 | 2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0).
30 | 3. Log in locally by running `python huggingface_login.py` and entering your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens) when prompted. It does not matter how you answer the `Add token as git credential? (Y/n)` question.
31 |
32 |
33 | ## Usage
34 |
35 | We use the pretrained image latent diffusion model [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and the pretrained audio latent diffusion model [Auffusion](https://huggingface.co/auffusion/auffusion-full-no-adapter), which is fine-tuned from Stable Diffusion. We provide the code (including visualization) and instructions for our approach (multimodal denoising) and our two proposed baselines: Imprint and SDS. Note that our code is built on [hydra](https://github.com/facebookresearch/hydra), so you can override any config parameter from the command line.
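
For example, any field in the composed config can be overridden from the command line. A minimal sketch (the keys below follow the config usage visible in `src/main_imprint.py`; for the other entry points, check the YAML files under `configs/`):
```bash
# illustrative overrides: change the seed and generate 100 samples for hand-picking
python src/main_imprint.py experiment=examples/bell seed=1234 trainer.num_samples=100
```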
36 |
37 | ### Multimodal denoising
38 | To create *images that sound* using our multimodal denoising method, run the code with config files under `configs/main_denoise/experiment`:
39 | ```bash
40 | python src/main_denoise.py experiment=examples/bell
41 | ```
42 | **Note:** our method does not have a high success rate, since it is zero-shot and depends heavily on the initial random noise. We recommend generating a larger batch of samples (e.g., N=100) and hand-picking the high-quality results.
43 |
44 |
45 | ### Imprint baseline
46 | To create *images that sound* using our proposed imprint baseline method, run the code with config files under `configs/main_imprint/experiment`:
47 | ```bash
48 | python src/main_imprint.py experiment=examples/bell
49 | ```
50 |
51 | ### SDS baseline
52 | To create *images that sound* using our proposed multimodal SDS baseline method, run the code with config files under `configs/main_sds/experiment`:
53 | ```bash
54 | python src/main_sds.py experiment=examples/bell
55 | ```
56 | **Note:** we find that Audio SDS does not work well for many audio prompts. We hypothesize that this is because latent diffusion models do not perform as well as pixel-based diffusion models for SDS.
57 |
58 | ### Colorization
59 | We also provide the colorization code under `src/colorization`, which is adapted from [Factorized Diffusion](https://github.com/dangeng/visual_anagrams). To directly generate colorized videos with audio, run:
60 | ```bash
61 | python src/colorization/create_color_video.py \
62 | --sample_dir /path/to/generated/sample/dir \
63 | --prompt "a colorful photo of [object]" \
64 | --num_samples 16 --guidance_scale 10 \
65 | --num_inference_steps 30 --start_diffusion_step 7
66 | ```
67 | **Note:** since our generated images fall outside the distribution the colorization model was trained on, we recommend running more trials (e.g., `--num_samples 16`) and selecting the best colorized results.
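
For reference, `src/colorization/create_color_video.py` expects each sample directory to contain `img.png`, `spec.png` (or `spec_colormap.png`), `audio.wav`, and `config.yaml`, and it writes its outputs alongside them, roughly as follows:
```
<sample_dir>/
├── img.png                 # grayscale image used for colorization
├── spec.png                # spectrogram (spec_colormap.png is preferred if present)
├── audio.wav
├── config.yaml             # updated with the colorization prompt
├── colorized_imgs/         # 0000.png, 0001.png, ...
└── colorized_videos/       # 0000.mp4, 0001.mp4, ...
```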
68 |
69 |
70 |
71 | ## Acknowledgement
72 | Our code is based on [Lightning-Hydra-Template](https://github.com/ashleve/lightning-hydra-template), [diffusers](https://github.com/huggingface/diffusers), [stable-dreamfusion](https://github.com/ashawkey/stable-dreamfusion), [Diffusion-Illusions](https://github.com/RyannDaGreat/Diffusion-Illusions), [Auffusion](https://github.com/happylittlecat2333/Auffusion), and [visual-anagrams](https://github.com/dangeng/visual_anagrams). We appreciate their open-source code.
--------------------------------------------------------------------------------
/src/evaluator/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from PIL import Image
3 | from tqdm import tqdm
4 | import os
5 | import glob
6 | import soundfile as sf
7 | from omegaconf import OmegaConf, DictConfig
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torchvision.transforms.functional as TF
13 |
14 | from lightning import seed_everything
15 |
16 | import rootutils
17 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
18 |
19 | from src.evaluator.clap import CLAPEvaluator
20 | from src.evaluator.clip import CLIPEvaluator
21 | from src.utils.consistency_check import griffin_lim
22 |
23 |
24 | def bootstrap_confidence_intervals(data, num_bootstraps=10000):
25 | # Bootstrap resampling
26 | bootstrap_samples = np.random.choice(data, size=(num_bootstraps, data.shape[0]), replace=True)
27 | bootstrap_means = np.mean(bootstrap_samples, axis=1)
28 |
29 | # Compute confidence interval
30 | confidence_interval = np.percentile(bootstrap_means, [2.5, 97.5])
31 |
32 | # Calculate point estimate (mean)
33 | sample_mean = np.mean(data)
34 |
35 | # Calculate margin of error
36 | margin_of_error = (confidence_interval[1] - confidence_interval[0]) / 2
37 |
38 | return sample_mean, margin_of_error
39 |
40 |
41 | def eval(args):
42 | seed_everything(2024, workers=True)
43 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
44 |
45 | tqdm.write("Preparing samples to be evaluated...")
46 | eval_samples = glob.glob(f'{args.eval_path}/*/results/*')
47 | eval_samples.sort()
48 |
49 | if args.max_sample != -1:
50 | eval_samples = eval_samples[:args.max_sample]
51 |
52 | tqdm.write("Instantiating audio evaluator...")
53 | audio_evaluator = CLAPEvaluator().to(device)
54 |
55 | tqdm.write("Instantiating visual evaluator...")
56 | visual_evaluator = CLIPEvaluator().to(device)
57 |
58 | clip_scores = []
59 | clap_scores = []
60 |
61 | for sample in tqdm(eval_samples, desc="Evaluating"):
62 | clip_score, clap_score = evaluate_single_sample(sample, audio_evaluator, visual_evaluator, device, use_griffin_lim=args.use_griffin_lim)
63 |
64 | clip_scores.append(clip_score)
65 | clap_scores.append(clap_score)
66 |
67 | clip_scores = np.array(clip_scores)
68 | clap_scores = np.array(clap_scores)
69 |
70 | # Choose a multiplier (e.g., for 95% confidence interval, multiplier is approximately 1.96)
71 | confidence_multiplier = 1.96
72 | n = clip_scores.shape[0]
73 |
74 | avg_clip_score = clip_scores.mean()
75 | # _, clip_error = bootstrap_confidence_intervals(clip_scores)
76 | clip_error = confidence_multiplier * np.std(clip_scores) / np.sqrt(n)
77 | tqdm.write(f"Averaged CLIP score: {avg_clip_score * 100} | margin of error: {clip_error * 100}")
78 |
79 | avg_clap_score = clap_scores.mean()
80 | # _, clap_error = bootstrap_confidence_intervals(clap_scores)
81 | clap_error = confidence_multiplier * np.std(clap_scores) / np.sqrt(n)
82 | tqdm.write(f"Averaged CLAP score: {avg_clap_score * 100} | margin of error: {clap_error * 100}")
83 |
84 |
85 | def evaluate_single_sample(sample_dir, audio_evaluator, visual_evaluator, device, use_griffin_lim=False):
86 | # import pdb; pdb.set_trace()
87 | # read sample dir
88 | gray_im_path = f'{sample_dir}/spec.png'
89 | audio_path = f'{sample_dir}/audio.wav'
90 | config_path = f'{sample_dir}/config.yaml'
91 | cfg = OmegaConf.load(config_path)
92 | image_prompt = cfg.trainer.image_prompt
93 | audio_prompt = cfg.trainer.audio_prompt
94 |
95 | # Load gray image and evaluate
96 | gray_im = Image.open(gray_im_path)
97 | gray_im = TF.to_tensor(gray_im).to(device)
98 | spec = gray_im.detach().cpu()
99 | gray_im = gray_im.mean(dim=0, keepdim=True).repeat(3, 1, 1)
100 | gray_im = gray_im.unsqueeze(0)
101 |
102 | clip_score = visual_evaluator(gray_im, image_prompt)
103 |
104 | # load audio waveform and evaluate
105 | audio, sr = sf.read(audio_path)
106 |
107 | if use_griffin_lim:
108 | # import pdb; pdb.set_trace()
109 | audio = griffin_lim(spec, audio)
110 |
111 | clap_score = audio_evaluator(audio_prompt, audio, sampling_rate=sr)
112 |
113 | return clip_score, clap_score
114 |
115 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/auffusion"
116 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/AV-IF-SDS-V2"
117 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/AV-Denoise-cfg7.5"
118 | # python src/evaluator/eval.py --eval_path "logs/Evaluation/AV-Denoise-notime"
119 |
120 |
121 | if __name__ == '__main__':
122 | # Parse args
123 | parser = argparse.ArgumentParser()
124 | parser.add_argument("--eval_path", required=True, type=str)
125 | parser.add_argument('--use_griffin_lim', default=False, action='store_true')
126 | parser.add_argument("--max_sample", type=int, default=-1)
127 |
128 | args = parser.parse_args()
129 |
130 | eval(args)
--------------------------------------------------------------------------------
/src/utils/consistency_check.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | from PIL import Image
5 | import shutil
6 |
7 | import numpy as np
8 | import torch
9 | from torchvision.utils import save_image
10 | import torchvision.transforms.functional as TF
11 | import torchaudio
12 |
13 | import soundfile as sf
14 | import librosa
15 |
16 | import rootutils
17 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
18 |
19 | from src.models.components.auffusion_converter import mel_spectrogram, normalize_spectrogram, denormalize_spectrogram
20 |
21 |
22 |
23 | def wav2spec(audio, sr):
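    """Compute a normalized mel spectrogram from a waveform, using the same mel settings
    as the Auffusion converter (16 kHz, 256 mel bins, hop size 160)."""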
24 | audio = torch.FloatTensor(audio)
25 | audio = audio.unsqueeze(0)
26 | spec = mel_spectrogram(audio, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
27 | spec = normalize_spectrogram(spec)
28 | return spec
29 |
30 |
31 | def griffin_lim(mel_spec, ori_audio):
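    """Reconstruct a waveform from a normalized mel spectrogram via librosa's
    mel-to-audio (Griffin-Lim) inversion, then trim/pad and clip it to match the
    original audio's length and the [-1, 1] range."""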
32 | mel_spec = denormalize_spectrogram(mel_spec)
33 | mel_spec = torch.exp(mel_spec)
34 |
35 | audio = librosa.feature.inverse.mel_to_audio(
36 | mel_spec.numpy(),
37 | sr=16000,
38 | n_fft=2048,
39 | hop_length=160,
40 | win_length=1024,
41 | power=1,
42 | center=True,
43 | # length=ori_audio.shape[0]
44 | )
45 |
46 | length = ori_audio.shape[0]
47 |
48 | if audio.shape[0] > length:
49 | audio = audio[:length]
50 | elif audio.shape[0] < length:
51 | audio = np.pad(audio, (0, length - audio.shape[0]), mode='constant')
52 |
53 | audio = np.clip(audio, a_min=-1, a_max=1)
54 | return audio
55 |
56 | def inverse_stft(mel_spec, ori_audio):
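    """Invert a normalized mel spectrogram by estimating the STFT magnitude with
    librosa's mel_to_stft and reusing the phase of the original audio's STFT,
    then running an ISTFT back to a waveform clipped to [-1, 1]."""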
57 | mel_spec = denormalize_spectrogram(mel_spec)
58 | mel_spec = torch.exp(mel_spec)
59 |
60 | n_fft = 2048
61 | hop_length = 160
62 | win_length = 1024
63 | power = 1
64 | center = False
65 |
66 | spec_mag = librosa.feature.inverse.mel_to_stft(mel_spec.numpy(), sr=16000, n_fft=n_fft, power=power)
67 | spec_mag = torch.tensor(spec_mag).float()
68 |
69 | audio_length = ori_audio.shape[0]
70 | ori_audio = torch.tensor(ori_audio)
71 | ori_audio = ori_audio.unsqueeze(0)
72 | ori_audio = torch.nn.functional.pad(ori_audio.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect')
73 | ori_audio = ori_audio.squeeze(1)
74 | vocoder_spec = torch.stft(ori_audio, n_fft, hop_length=hop_length, win_length=win_length, window=torch.hann_window(win_length), center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
75 | # Get the phase from the complex spectrogram
76 | vocoder_phase = torch.angle(vocoder_spec).float()
77 | # import pdb; pdb.set_trace()
78 |
79 | # Combine the new magnitude with the original phase
80 | # We use polar coordinates to transform magnitude and phase into a complex number
81 | reconstructed_complex_spec = torch.polar(spec_mag.unsqueeze(0), vocoder_phase)
82 |
83 | # Perform the ISTFT to convert the spectrogram back to time domain audio signal
84 | reconstructed_audio = torch.istft(
85 | reconstructed_complex_spec,
86 | n_fft,
87 | hop_length=hop_length,
88 | win_length=win_length,
89 | window=torch.hann_window(win_length),
90 | center=True,
91 | normalized=False,
92 | onesided=True,
93 | length=audio_length # Ensure output audio length matches original audio
94 | )
95 |
96 | reconstructed_audio = reconstructed_audio.squeeze(0).numpy()
97 | reconstructed_audio = np.clip(reconstructed_audio, a_min=-1, a_max=1)
98 |
99 | return reconstructed_audio
100 |
101 |
102 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/colorization/bell_example_005"
103 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/colorization/tiger_example_002"
104 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/colorization/dog_example_06"
105 | # python src/utils/consistency_check.py --dir "logs/soundify-denoise/debug/results/example_015"
106 |
107 |
108 |
109 | if __name__ == '__main__':
110 | parser = argparse.ArgumentParser()
111 | parser.add_argument("--dir", required=False, type=str, default="logs/soundify-denoise/colorization/bell_example_29")
112 |
113 | args = parser.parse_args()
114 | save_dir = f"logs/consistency-check/{args.dir.split('/')[-1]}"
115 | os.makedirs(save_dir, exist_ok=True)
116 |
117 | # import pdb; pdb.set_trace()
118 |
119 | # audio from vocoder to spectrogram
120 | audio_path = os.path.join(args.dir, "audio.wav")
121 | audio_data, sampling_rate = sf.read(audio_path)
122 | spec = wav2spec(audio_data, sampling_rate)
123 |
124 | save_path = os.path.join(save_dir, f"respec-hifi.png")
125 | save_image(spec, save_path, padding=0)
126 |
127 | # spectrogram to audio using griffin-lim
128 | spec_path = os.path.join(args.dir, "spec.png")
129 | spec = Image.open(spec_path)
130 | spec = TF.to_tensor(spec)
131 |
132 | save_path = os.path.join(save_dir, f"spec.png")
133 | shutil.copyfile(spec_path, save_path)
134 |
135 | audio_istft = inverse_stft(spec, audio_data)
136 | save_audio_path = os.path.join(save_dir, f"audio-istft.wav")
137 | sf.write(save_audio_path, audio_istft, samplerate=16000)
138 | spec_istft = wav2spec(audio_istft, sampling_rate)
139 |
140 | save_path = os.path.join(save_dir, f"respec-istft.png")
141 | save_image(spec_istft, save_path, padding=0)
142 |
143 |
144 | audio_gl = griffin_lim(spec, audio_data)
145 | save_audio_path = os.path.join(save_dir, f"audio-griffin-lim.wav")
146 | sf.write(save_audio_path, audio_gl, samplerate=16000)
147 |
148 | # audio_data, sampling_rate = sf.read(save_audio_path)
149 |
150 | spec_gl = wav2spec(audio_gl, sampling_rate)
151 |
152 | save_path = os.path.join(save_dir, f"respec-gl.png")
153 | save_image(spec_gl, save_path, padding=0)
154 |
155 |
156 |
157 |
158 |
159 |
--------------------------------------------------------------------------------
/src/colorization/create_color_video.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import glob
6 | import os
7 | from omegaconf import OmegaConf, DictConfig, open_dict
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torchvision.transforms.functional as TF
12 | from torchvision.utils import save_image
13 |
14 | import rootutils
15 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16 |
17 | from src.colorization.colorizer import FactorizedColorization
18 | from src.utils.animation_with_text import create_animation_with_text
19 |
20 |
21 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/bell_example_29 --prompt "a colorful photo of a castle with bell towers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
22 |
23 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/bell_example_005 --prompt "a colorful photo of a castle with bell towers" --num_samples 12 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
24 |
25 |
26 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/tiger_example_02 --prompt "a colorful photo of a tigers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
27 |
28 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/tiger_example_06 --prompt "a colorful photo of a tigers" --num_samples 8 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
29 |
30 |
31 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/race_example_002 --prompt "a colorful photo of a auto racing game" --num_samples 12 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
32 |
33 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/bird_example_40 --prompt "a blooming garden with many birds" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
34 |
35 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/kitten_example_08 --prompt "a colorful photo of kittens" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
36 |
37 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/kitten_example_08_v2 --prompt "a colorful photo of kittens with blue eyes and pink noses" --num_samples 16 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
38 |
39 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/dog_example_06 --prompt "a colorful photo of dogs" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
40 |
41 | # python src/colorization/create_color_video.py --sample_dir logs/soundify-denoise/colorization/train_example_02 --prompt "a colorful photo of a long train" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
42 |
43 | if __name__ == '__main__':
44 | # Parse args
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--sample_dir", required=True, type=str)
47 | parser.add_argument("--prompt", required=True, type=str, help='Prompts to use for colorization')
48 | parser.add_argument("--num_samples", type=int, default=4)
49 | parser.add_argument("--depth", type=int, default=0)
50 | # parser.add_argument('--no_colormap', default=False, action='store_true')
51 | parser.add_argument("--guidance_scale", type=float, default=10.0)
52 | parser.add_argument("--num_inference_steps", type=int, default=30)
53 | parser.add_argument("--seed", type=int, default=0)
54 | parser.add_argument("--device", type=str, default='cuda')
55 | parser.add_argument("--noise_level", type=int, default=50, help='Noise level for stage 2')
56 | parser.add_argument("--start_diffusion_step", type=int, default=7, help='What step to start the diffusion process')
57 |
58 |
59 | args = parser.parse_args()
60 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
61 |
62 | # create diffusion colorization instance
63 | colorizer = FactorizedColorization(
64 | inverse_color=False,
65 | num_inference_steps=args.num_inference_steps,
66 | guidance_scale=args.guidance_scale,
67 | start_diffusion_step=args.start_diffusion_step,
68 | noise_level=args.noise_level,
69 | ).to(device)
70 |
71 |
72 | sample_dirs = glob.glob(f"{args.sample_dir}" + "/*" * args.depth)
73 | sample_dirs.sort()
74 |
75 | # read sample dir
76 | for sample_dir in sample_dirs:
77 | gray_im_path = f'{sample_dir}/img.png'
78 | spec = f'{sample_dir}/spec_colormap.png'
79 | if not os.path.exists(spec):
80 | spec = f'{sample_dir}/spec.png'
81 |
82 | audio = f'{sample_dir}/audio.wav'
83 | config_path = f'{sample_dir}/config.yaml'
84 | cfg = OmegaConf.load(config_path)
85 | image_prompt = args.prompt
86 | audio_prompt = cfg.trainer.audio_prompt
87 |
88 | with open_dict(cfg):
89 | cfg.trainer.colorization_prompt = args.prompt
90 | OmegaConf.save(cfg, config_path)
91 |
92 | # Load gray image
93 | gray_im = Image.open(gray_im_path)
94 | gray_im = TF.to_tensor(gray_im).to(device)
95 |
96 | img_save_dir = os.path.join(sample_dir, 'colorized_imgs')
97 | os.makedirs(img_save_dir, exist_ok=True)
98 |
99 | video_save_dir = os.path.join(sample_dir, 'colorized_videos')
100 | os.makedirs(video_save_dir, exist_ok=True)
101 |
102 | # Sample illusions
103 | for i in tqdm(range(args.num_samples), desc="Sampling images"):
104 | generator = torch.manual_seed(args.seed + i)
105 | image = colorizer(gray_im, args.prompt, generator=generator)
106 | img_save_path = f'{img_save_dir}/{i:04}.png'
107 | save_image(image, img_save_path, padding=0)
108 |
109 | video_save_path = f'{video_save_dir}/{i:04}.mp4'
110 | create_animation_with_text(img_save_path, spec, audio, video_save_path, image_prompt, audio_prompt)
111 |
112 |
--------------------------------------------------------------------------------
/src/models/components/learnable_image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class LearnableImage(nn.Module):
7 | def __init__(
8 | self,
9 | height :int,
10 | width :int,
11 | num_channels:int
12 | ):
13 |         '''This is an abstract class and is meant to be subclassed before use. Calling forward() returns a tensor of shape (num_channels, height, width).
14 | '''
15 | super().__init__()
16 |
17 | self.height = height
18 | self.width = width
19 | self.num_channels = num_channels
20 |
21 | def as_numpy_image(self):
22 | image = self.forward()
23 | image = image.cpu().numpy()
24 | image = image.transpose(1, 2, 0)
25 | return image
26 |
27 |
28 | class LearnableImageFourier(LearnableImage):
29 | def __init__(
30 | self,
31 | height :int=256, # Height of the learnable images
32 | width :int=256, # Width of the learnable images
33 | num_channels:int=3 , # Number of channels in the images
34 | hidden_dim :int=256, # Number of dimensions per hidden layer of the MLP
35 | num_features:int=128, # Number of fourier features per coordinate
36 | scale :int=10 , # Magnitude of the initial feature noise
37 | renormalize :bool=False
38 | ):
39 | super().__init__(height, width, num_channels)
40 |
41 | self.hidden_dim = hidden_dim
42 | self.num_features = num_features
43 | self.scale = scale
44 | self.renormalize = renormalize
45 |
46 | # The following objects do NOT have parameters, and are not changed while optimizing this class
47 | self.uv_grid = nn.Parameter(get_uv_grid(height, width, batch_size=1), requires_grad=False)
48 | self.feature_extractor = GaussianFourierFeatureTransform(2, num_features, scale)
49 | self.features = nn.Parameter(self.feature_extractor(self.uv_grid), requires_grad=False) # pre-compute this if we're regressing on images
50 |
51 | H = hidden_dim # Number of hidden features. These 1x1 convolutions act as a per-pixel MLP
52 | C = num_channels # Shorter variable names let us align the code better
53 | M = 2 * num_features
54 | self.model = nn.Sequential(
55 | nn.Conv2d(M, H, kernel_size=1), nn.ReLU(), nn.BatchNorm2d(H),
56 | nn.Conv2d(H, H, kernel_size=1), nn.ReLU(), nn.BatchNorm2d(H),
57 | nn.Conv2d(H, H, kernel_size=1), nn.ReLU(), nn.BatchNorm2d(H),
58 | nn.Conv2d(H, C, kernel_size=1),
59 | nn.Sigmoid(),
60 | )
61 |
62 | def forward(self):
63 | features = self.features
64 |
65 | output = self.model(features).squeeze(0)
66 |
67 | assert output.shape==(self.num_channels, self.height, self.width)
68 |
69 | if self.renormalize:
70 | output = output * 2 - 1.0 # renormalize to [-1, 1]
71 |
72 | return output
73 |
74 |
75 | class LearnableImageParam(LearnableImage):
76 | def __init__(
77 | self,
78 | height :int=256, # Height of the learnable images
79 | width :int=256, # Width of the learnable images
80 | num_channels:int=3 , # Number of channels in the images\
81 | **kwargs
82 | ):
83 | super().__init__(height, width, num_channels)
84 |
85 | self.model = nn.Parameter(torch.randn(num_channels, height, width))
86 |
87 |
88 | def forward(self):
89 | x = self.model
90 | return x
91 |
92 |
93 | ##################################
94 | ######## HELPER FUNCTIONS ########
95 | ##################################
96 |
97 | class GaussianFourierFeatureTransform(nn.Module):
98 | """
99 | Original authors: https://github.com/ndahlquist/pytorch-fourier-feature-networks
100 |
101 | An implementation of Gaussian Fourier feature mapping.
102 |
103 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
104 | https://arxiv.org/abs/2006.10739
105 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
106 |
107 | Given an input of size [batches, num_input_channels, width, height],
108 | returns a tensor of size [batches, num_features*2, width, height].
109 | """
110 | def __init__(self, num_channels, num_features=256, scale=10):
111 |         #It generates fourier components of random frequencies, not all of them.
112 | #The frequencies are determined by a random normal distribution, multiplied by "scale"
113 | #So, when "scale" is higher, the fourier features will have higher frequencies
114 | #In learnable_image_tutorial.ipynb, this translates to higher fidelity images.
115 | #In other words, 'scale' loosely refers to the X,Y scale of the images
116 | #With a high scale, you can learn detailed images with simple MLP's
117 | #If it's too high though, it won't really learn anything but high frequency noise
118 |
119 | super().__init__()
120 |
121 | self.num_channels = num_channels
122 | self.num_features = num_features
123 |
124 | #freqs are n-dimensional spatial frequencies, where n=num_channels
125 | self.freqs = nn.Parameter(torch.randn(num_channels, num_features) * scale, requires_grad=False)
126 |
127 | def forward(self, x):
128 | assert x.dim() == 4, 'Expected 4D input (got {}D input)'.format(x.dim())
129 |
130 | batch_size, num_channels, height, width = x.shape
131 |
132 | assert num_channels == self.num_channels,\
133 | "Expected input to have {} channels (got {} channels)".format(self.num_channels, num_channels)
134 |
135 | # Make shape compatible for matmul with freqs.
136 | # From [B, C, H, W] to [(B*H*W), C].
137 | x = x.permute(0, 2, 3, 1).reshape(batch_size * height * width, num_channels)
138 |
139 | # [(B*H*W), C] x [C, F] = [(B*H*W), F]
140 | x = x @ self.freqs
141 |
142 | # From [(B*H*W), F] to [B, H, W, F]
143 | x = x.view(batch_size, height, width, self.num_features)
144 |         # From [B, H, W, F] to [B, F, H, W]
145 | x = x.permute(0, 3, 1, 2)
146 |
147 | x = 2 * torch.pi * x
148 |
149 | output = torch.cat([torch.sin(x), torch.cos(x)], dim=1)
150 |
151 | assert output.shape==(batch_size, 2*self.num_features, height, width)
152 |
153 | return output
154 |
155 |
156 | def get_uv_grid(height:int, width:int, batch_size:int=1)->torch.Tensor:
157 | #Returns a torch cpu tensor of shape (batch_size,2,height,width)
158 | #Note: batch_size can probably be removed from this function after refactoring this file. It's always 1 in all usages.
159 | #The second dimension is (x,y) coordinates, which go from [0 to 1) from edge to edge
160 | #(In other words, it will include x=y=0, but instead of x=y=1 the other corner will be x=y=.999)
161 | #(this is so it doesn't wrap around the texture 360 degrees)
162 |
163 | # import pdb; pdb.set_trace()
164 | assert height>0 and width>0 and batch_size>0,'All dimensions must be positive integers'
165 |
166 | y_coords = np.linspace(0, 1, height, endpoint=False)
167 | x_coords = np.linspace(0, 1, width , endpoint=False)
168 |
169 | uv_grid = np.stack(np.meshgrid(y_coords, x_coords), -1)
170 | uv_grid = torch.tensor(uv_grid).unsqueeze(0).permute(0, 3, 2, 1).float().contiguous()
171 | uv_grid = uv_grid.repeat(batch_size,1,1,1)
172 |
173 | assert tuple(uv_grid.shape)==(batch_size,2,height,width)
174 |
175 | return uv_grid
176 |
177 |
--------------------------------------------------------------------------------
/src/colorization/colorizer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import os
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torchvision.transforms.functional as TF
10 | from torchvision.utils import save_image
11 |
12 | from diffusers import DiffusionPipeline
13 |
14 |
15 | import rootutils
16 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
17 |
18 | from src.colorization.views import ColorLView, ColorABView
19 | from src.colorization.samplers import sample_stage_1, sample_stage_2
20 |
21 |
22 |
23 | class FactorizedColorization(nn.Module):
24 |     '''Colorization diffusion model based on Factorized Diffusion.
25 | '''
26 | def __init__(
27 | self,
28 | inverse_color=False,
29 | **kwargs
30 | ):
31 | super().__init__()
32 |
33 | # Make DeepFloyd IF stage I
34 | self.stage_1 = DiffusionPipeline.from_pretrained(
35 | "DeepFloyd/IF-I-M-v1.0",
36 | variant="fp16",
37 | torch_dtype=torch.float16
38 | )
39 | self.stage_1.enable_model_cpu_offload()
40 |
41 | # Make DeepFloyd IF stage II
42 | self.stage_2 = DiffusionPipeline.from_pretrained(
43 | "DeepFloyd/IF-II-M-v1.0",
44 | text_encoder=None,
45 | variant="fp16",
46 | torch_dtype=torch.float16,
47 | )
48 | self.stage_2.enable_model_cpu_offload()
49 |
50 | # if inverse the gray scale
51 | self.inverse_color = inverse_color
52 |
53 | # get views
54 | self.views = [ColorLView(), ColorABView()]
55 |
56 | self.num_inference_steps = kwargs.get("num_inference_steps", 30)
57 | self.guidance_scale = kwargs.get("guidance_scale", 10.0)
58 | self.start_diffusion_step = kwargs.get("start_diffusion_step", 0)
59 | self.noise_level = kwargs.get("noise_level", 50)
60 |
61 |
62 | @torch.no_grad()
63 | def get_text_embeds(self, prompt):
64 | # Get prompt embeddings (need two, because code is designed for
65 | # two components: L and ab)
66 | prompts = [prompt] * 2
67 | prompt_embeds = [self.stage_1.encode_prompt(p) for p in prompts]
68 | prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds)
69 | prompt_embeds = torch.cat(prompt_embeds)
70 | negative_prompt_embeds = torch.cat(negative_prompt_embeds) # These are just null embeds
71 | return prompt_embeds, negative_prompt_embeds
72 |
73 | def forward(
74 | self,
75 | gray_im,
76 | prompt,
77 | num_inference_steps=None,
78 | guidance_scale=None,
79 | start_diffusion_step=None,
80 | noise_level=None,
81 | generator=None
82 | ):
83 |         # 1. overwrite the hyperparameters if provided
84 | num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
85 | guidance_scale = self.guidance_scale if guidance_scale is None else guidance_scale
86 | start_diffusion_step = self.start_diffusion_step if start_diffusion_step is None else start_diffusion_step
87 | noise_level = self.noise_level if noise_level is None else noise_level
88 |
89 | # 2. prepare the text embeddings
90 | prompt_embeds, negative_prompt_embeds = self.get_text_embeds(prompt)
91 |
92 | # import pdb; pdb.set_trace()
93 |
94 | # 3. prepare grayscale image
95 | _, height, width = gray_im.shape
96 | if self.inverse_color:
97 | gray_im = 1.0 - gray_im
98 |
99 | gray_im = gray_im * 2.0 - 1 # normalize the pixel value
100 |
101 | # 4. Sample 64x64 image
102 | image = sample_stage_1(
103 | self.stage_1,
104 | prompt_embeds,
105 | negative_prompt_embeds,
106 | self.views,
107 | height=height // 4,
108 | width=width // 4,
109 | fixed_im=gray_im,
110 | num_inference_steps=num_inference_steps,
111 | guidance_scale=guidance_scale,
112 | generator=generator,
113 | start_diffusion_step=start_diffusion_step
114 | )
115 |
116 | # 5. Sample 256x256 image, by upsampling 64x64 image
117 | image = sample_stage_2(
118 | self.stage_2,
119 | image,
120 | prompt_embeds,
121 | negative_prompt_embeds,
122 | self.views,
123 | height=height,
124 | width=width,
125 | fixed_im=gray_im,
126 | num_inference_steps=num_inference_steps,
127 | guidance_scale=guidance_scale,
128 | noise_level=noise_level,
129 | generator=generator
130 | )
131 |
132 | # 6. return the final image
133 | image = image / 2 + 0.5
134 | return image
135 |
136 |
137 | # python colorizer.py --name colorize.castle.full --gray_im_path ./imgs/castle.full.png --prompt "a colorful photo of a white castle with bell towers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
138 |
139 | # python colorizer.py --name colorize.racing.full --gray_im_path ./imgs/racing.full.png --prompt "a colorful photo of a auto racing game" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
140 |
141 | # python colorizer.py --name colorize.tiger.full --gray_im_path ./imgs/tiger.full.png --prompt "a colorful photo of a tigers" --num_samples 4 --guidance_scale 10 --num_inference_steps 30 --start_diffusion_step 7
142 |
143 | # python colorizer.py --name colorize.dog.full --gray_im_path ./imgs/dog.full.png --prompt "a colorful photo of puppies on green grass" --num_samples 4 --guidance_scale 10.0 --num_inference_steps 30 --start_diffusion_step 7
144 |
145 | # python colorizer.py --name colorize.spec.full --gray_im_path ./imgs/spec.full.png --prompt "a colorful photo of kittens" --num_samples 8 --guidance_scale 10.0 --num_inference_steps 30 --start_diffusion_step 0
146 |
147 | if __name__ == '__main__':
148 | # Parse args
149 | parser = argparse.ArgumentParser()
150 | parser.add_argument("--name", required=True, type=str)
151 | parser.add_argument("--gray_im_path", required=True, type=str)
152 | parser.add_argument("--prompt", required=True, type=str, help='Prompts to use for colorization')
153 | parser.add_argument("--num_samples", type=int, default=4)
154 | parser.add_argument("--guidance_scale", type=float, default=10.0)
155 | parser.add_argument("--num_inference_steps", type=int, default=30)
156 | parser.add_argument("--seed", type=int, default=0)
157 | parser.add_argument("--device", type=str, default='cuda')
158 | parser.add_argument("--noise_level", type=int, default=50, help='Noise level for stage 2')
159 | parser.add_argument("--start_diffusion_step", type=int, default=7, help='What step to start the diffusion process')
160 |
161 |
162 | args = parser.parse_args()
163 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
164 |
165 | # create diffusion colorization instance
166 | colorizer = FactorizedColorization(
167 | inverse_color=False,
168 | num_inference_steps=args.num_inference_steps,
169 | guidance_scale=args.guidance_scale,
170 | start_diffusion_step=args.start_diffusion_step,
171 | noise_level=args.noise_level,
172 | ).to(device)
173 |
174 | # Load gray image
175 | gray_im = Image.open(args.gray_im_path)
176 | gray_im = TF.to_tensor(gray_im).to(device)
177 |
178 | save_dir = os.path.join('results', args.name)
179 | os.makedirs(save_dir, exist_ok=True)
180 |
181 | # Sample illusions
182 | for i in tqdm(range(args.num_samples), desc="Sampling images"):
183 | generator = torch.manual_seed(args.seed + i)
184 | image = colorizer(gray_im, args.prompt, generator=generator)
185 | save_image(image, f'{save_dir}/{i:04}.png', padding=0)
186 |
--------------------------------------------------------------------------------
/src/guidance/deepfloyd.py:
--------------------------------------------------------------------------------
1 | from diffusers import IFPipeline, DDPMScheduler
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from lightning import seed_everything
8 |
9 |
10 | class DeepfloydGuidance(nn.Module):
11 | def __init__(
12 | self,
13 | repo_id='DeepFloyd/IF-I-XL-v1.0',
14 | fp16=True,
15 | t_range=[0.02, 0.98],
16 | t_consistent=False,
17 | **kwargs
18 | ):
19 | super().__init__()
20 |
21 | self.repo_id = repo_id
22 | self.precision_t = torch.float16 if fp16 else torch.float32
23 |
24 | # Create model
25 | pipe = IFPipeline.from_pretrained(repo_id, torch_dtype=self.precision_t)
26 |
27 | self.unet = pipe.unet
28 | self.tokenizer = pipe.tokenizer
29 | self.text_encoder = pipe.text_encoder
30 | self.unet = pipe.unet
31 | self.scheduler = pipe.scheduler
32 |
33 | self.pipe = pipe
34 |
35 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps
36 | self.min_step = int(self.num_train_timesteps * t_range[0])
37 | self.max_step = int(self.num_train_timesteps * t_range[1])
38 | self.t_consistent = t_consistent
39 |
40 | self.register_buffer('alphas', self.scheduler.alphas_cumprod) # for convenience
41 |
42 | @torch.no_grad()
43 | def get_text_embeds(self, prompt, device):
44 | # prompt: [str]
45 | prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
46 | inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt')
47 | embeddings = self.text_encoder(inputs.input_ids.to(device))[0]
48 | embeddings = embeddings.to(dtype=self.text_encoder.dtype, device=device)
49 | return embeddings
50 |
51 |
52 | def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, t=None, grad_scale=1):
53 | pred_rgb = pred_rgb.to(self.unet.dtype)
54 | # [0, 1] to [-1, 1] and make sure shape is [64, 64]
55 | images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
56 |
57 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
58 | if t is None:
59 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
60 | if self.t_consistent:
61 | t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=images.device)
62 | t = t.repeat(images.shape[0])
63 | else:
64 | t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=images.device)
65 | else:
66 | t = t.to(dtype=torch.long, device=images.device)
67 |
68 | # predict the noise residual with unet, NO grad!
69 | with torch.no_grad():
70 | # add noise
71 | noise = torch.randn_like(images)
72 | images_noisy = self.scheduler.add_noise(images, noise, t)
73 |
74 | # pred noise
75 | model_input = torch.cat([images_noisy] * 2)
76 | model_input = self.scheduler.scale_model_input(model_input, t)
77 | tt = torch.cat([t] * 2)
78 | noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
79 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
80 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
81 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
82 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
83 |
84 | # w(t), sigma_t^2
85 | w = (1 - self.alphas[t])
86 | grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
87 | grad = torch.nan_to_num(grad)
88 |
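        # SDS trick: an MSE loss against the detached target (images - grad) has a gradient
        # w.r.t. `images` that is proportional to `grad`, so backprop applies the SDS update.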
89 | targets = (images - grad).detach()
90 | loss = 0.5 * F.mse_loss(images.float(), targets, reduction='sum') / images.shape[0]
91 |
92 | return loss
93 |
94 |
95 | @torch.no_grad()
96 | def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):
97 |
98 | images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype)
99 | images = images * self.scheduler.init_noise_sigma
100 |
101 | self.scheduler.set_timesteps(num_inference_steps)
102 |
103 | for i, t in enumerate(self.scheduler.timesteps):
104 | # expand the image if we are doing classifier-free guidance to avoid doing two forward passes.
105 | model_input = torch.cat([images] * 2)
106 | model_input = self.scheduler.scale_model_input(model_input, t)
107 |
108 | # predict the noise residual
109 | noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample
110 |
111 | # perform guidance
112 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
113 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
114 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
115 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
116 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
117 |
118 | # compute the previous noisy sample x_t -> x_t-1
119 | images = self.scheduler.step(noise_pred, t, images).prev_sample
120 |
121 | images = (images + 1) / 2
122 |
123 | return images
124 |
125 | def prompt_to_img(self, prompts, negative_prompts='', height=64, width=64, num_inference_steps=50, guidance_scale=7.5, device=None):
126 |
127 | if isinstance(prompts, str):
128 | prompts = [prompts]
129 |
130 | if isinstance(negative_prompts, str):
131 | negative_prompts = [negative_prompts]
132 |
133 | # Prompts -> text embeds
134 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768]
135 | neg_embeds = self.get_text_embeds(negative_prompts, device)
136 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
137 |
138 | # Text embeds -> img
139 | imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
140 |
141 | # Img to Numpy
142 | imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
143 | imgs = (imgs * 255).round().astype('uint8')
144 |
145 | return imgs
146 |
147 |
148 | if __name__ == '__main__':
149 | import argparse
150 | from PIL import Image
151 | import os
152 |
153 | parser = argparse.ArgumentParser()
154 | parser.add_argument('--prompt', type=str, default='an oil paint of modern city, street view')
155 | parser.add_argument('--negative', default='', type=str)
156 | parser.add_argument('--repo_id', type=str, default='DeepFloyd/IF-I-XL-v1.0', help="stable diffusion version")
157 | parser.add_argument('--fp16', action='store_true', help="use float16 for training")
158 | parser.add_argument('--H', type=int, default=64)
159 | parser.add_argument('--W', type=int, default=64)
160 | parser.add_argument('--seed', type=int, default=0)
161 | parser.add_argument('--steps', type=int, default=50)
162 | opt = parser.parse_args()
163 |
164 | seed_everything(opt.seed)
165 |
166 | device = torch.device('cuda')
167 |
168 | sd = DeepfloydGuidance(repo_id=opt.repo_id, fp16=opt.fp16).to(device)
169 |
170 | imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, device=device)
171 |
172 | # visualize image
173 | save_path = 'logs/test'
174 | os.makedirs(save_path, exist_ok=True)
175 | image = Image.fromarray(imgs[0], mode='RGB')
176 | image.save(os.path.join(save_path, f'{opt.prompt}.png'))
--------------------------------------------------------------------------------
/src/main_imprint.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | from tqdm import tqdm
4 | import warnings
5 | import soundfile as sf
6 | import numpy as np
7 | import shutil
8 | import glob
9 | import copy
10 | import matplotlib.pyplot as plt
11 |
12 | import hydra
13 | from omegaconf import OmegaConf, DictConfig, open_dict
14 | import torch
15 | import torch.nn as nn
16 | from transformers import logging
17 | from lightning import seed_everything
18 |
19 | from torchvision.utils import save_image
20 |
21 |
22 | import rootutils
23 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
24 |
25 | from src.utils.rich_utils import print_config_tree
26 | from src.utils.animation_with_text import create_animation_with_text, create_single_image_animation_with_text
27 | from src.utils.re_ranking import select_top_k_ranking, select_top_k_clip_ranking
28 | from src.utils.pylogger import RankedLogger
29 | log = RankedLogger(__name__, rank_zero_only=True)
30 |
31 |
32 |
33 | def save_audio(audio, save_path):
34 | sf.write(save_path, audio, samplerate=16000)
35 |
36 |
37 | def encode_prompt(prompt, diffusion_guidance, device, negative_prompt='', time_repeat=1):
38 | '''Encode text prompts into embeddings
39 | '''
40 | prompts = [prompt] * time_repeat
41 | negative_prompts = [negative_prompt] * time_repeat
42 |
43 | # Prompts -> text embeds
44 | cond_embeds = diffusion_guidance.get_text_embeds(prompts, device) # [B, 77, 768]
45 | uncond_embeds = diffusion_guidance.get_text_embeds(negative_prompts, device) # [B, 77, 768]
46 | text_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0) # [2 * B, 77, 768]
47 | return text_embeds
48 |
49 | def estimate_noise(diffusion, latents, t, text_embeddings, guidance_scale):
50 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
51 | latent_model_input = torch.cat([latents] * 2)
52 | # predict the noise residual
53 | noise_pred = diffusion.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
54 |
55 | # perform guidance
56 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
57 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
58 | return noise_pred
59 |
60 |
61 |
62 | @hydra.main(version_base="1.3", config_path="../configs/main_imprint", config_name="main.yaml")
63 | def main(cfg: DictConfig) -> Optional[float]:
64 |     """Main function for generating samples with the imprint baseline.
65 | """
66 |
67 | if cfg.extras.get("ignore_warnings"):
68 | log.info("Disabling python warnings! ")
69 | warnings.filterwarnings("ignore")
70 | logging.set_verbosity_error()
71 |
72 | if cfg.extras.get("print_config"):
73 | print_config_tree(cfg, resolve=True)
74 |
75 | # set seed for random number generators in pytorch, numpy and python.random
76 | if cfg.get("seed"):
77 | seed_everything(cfg.seed, workers=True)
78 |
79 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
80 |
81 | log.info(f"Instantiating Image Diffusion model <{cfg.image_diffusion_guidance._target_}>")
82 | image_diffusion_guidance = hydra.utils.instantiate(cfg.image_diffusion_guidance).to(device)
83 |
84 | log.info(f"Instantiating Audio Diffusion guidance model <{cfg.audio_diffusion_guidance._target_}>")
85 | audio_diffusion_guidance = hydra.utils.instantiate(cfg.audio_diffusion_guidance).to(device)
86 |
87 | # create transformation
88 | log.info(f"Instantiating latent transformation <{cfg.latent_transformation._target_}>")
89 | latent_transformation = hydra.utils.instantiate(cfg.latent_transformation).to(device)
90 |
91 | # create audio evaluator
92 | if cfg.audio_evaluator:
93 | log.info(f"Instantiating audio evaluator <{cfg.audio_evaluator._target_}>")
94 | audio_evaluator = hydra.utils.instantiate(cfg.audio_evaluator).to(device)
95 | else:
96 | audio_evaluator = None
97 |
98 | if cfg.visual_evaluator:
99 | log.info(f"Instantiating visual evaluator <{cfg.visual_evaluator._target_}>")
100 | visual_evaluator = hydra.utils.instantiate(cfg.visual_evaluator).to(device)
101 | else:
102 | visual_evaluator = None
103 |
104 | clip_scores = []
105 | clap_scores = []
106 | log.info(f"Starting sampling!")
107 | for idx in tqdm(range(cfg.trainer.num_samples), desc='Sampling'):
108 | clip_score, clap_score = create_sample(cfg, image_diffusion_guidance, audio_diffusion_guidance, latent_transformation, visual_evaluator, audio_evaluator, idx, device)
109 | clip_scores.append(clip_score)
110 | clap_scores.append(clap_score)
111 |
112 | # re-ranking by metrics
113 | enable_rank = cfg.trainer.get("enable_rank", False)
114 | if enable_rank:
115 | log.info(f"Starting re-ranking and selection!")
116 | select_top_k_ranking(cfg, clip_scores, clap_scores)
117 |
118 | enable_clip_rank = cfg.trainer.get("enable_clip_rank", False)
119 | if enable_clip_rank:
120 | log.info(f"Starting re-ranking and selection by CLIP score!")
121 | select_top_k_clip_ranking(cfg, clip_scores)
122 |
123 | log.info(f"Finished!")
124 |
125 |
126 | @torch.no_grad()
127 | def create_sample(cfg, image_diffusion, audio_diffusion, latent_transformation, visual_evaluator, audio_evaluator, idx, device):
128 | image_guidance_scale, audio_guidance_scale = cfg.trainer.image_guidance_scale, cfg.trainer.audio_guidance_scale
129 | height, width = cfg.trainer.img_height, cfg.trainer.img_width
130 | inverse_image = cfg.trainer.get("inverse_image", False)
131 | use_colormap = cfg.trainer.get("use_colormap", False)
132 | crop_image = cfg.trainer.get("crop_image", False)
133 |
134 | generator = torch.manual_seed(cfg.seed + idx)
135 |
136 | # obtain the image and spec for each modality's diffusion process
137 | image = image_diffusion.prompt_to_img(cfg.trainer.image_prompt, negative_prompts=cfg.trainer.image_neg_prompt, height=height, width=width, num_inference_steps=50, guidance_scale=image_guidance_scale, device=device, generator=generator)
138 | image = image.mean(dim=1) # make grayscale image
139 |
140 | spec = audio_diffusion.prompt_to_spec(cfg.trainer.audio_prompt, negative_prompts=cfg.trainer.audio_neg_prompt, height=height, width=width, num_inference_steps=100, guidance_scale=audio_guidance_scale, device=device, generator=generator)
141 | spec = spec.mean(dim=1) # make a single channel
142 |
143 |     # perform the naive imprint baseline: attenuate the spectrogram magnitude by up to mag_ratio where the grayscale image is bright
144 | mag_ratio = cfg.trainer.get("mag_ratio", 0.5)
145 | if inverse_image:
146 | image = 1.0 - image
147 | image_mask = 1 - mag_ratio * image
148 | spec_new = spec * image_mask
149 | img = image
150 | # import pdb; pdb.set_trace()
151 | audio = audio_diffusion.spec_to_audio(spec_new)
152 | audio = np.ravel(audio)
153 |
154 | if crop_image:
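        # crop the last 32 columns of the image/spectrogram (and the corresponding tail of the audio)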
155 | pixel = 32
156 | audio_length = int(pixel / width * audio.shape[0])
157 | img = img[..., :-pixel]
158 | spec = spec[..., :-pixel]
159 | spec_new = spec_new[..., :-pixel]
160 | audio = audio[:-audio_length]
161 |
162 | # evaluate with CLIP
163 | if visual_evaluator is not None:
164 | clip_score = visual_evaluator(img.repeat(3, 1, 1).unsqueeze(0), cfg.trainer.image_prompt)
165 | else:
166 | clip_score = None
167 |
168 | # evaluate with CLAP
169 | if audio_evaluator is not None:
170 | clap_score = audio_evaluator(cfg.trainer.audio_prompt, audio)
171 | else:
172 | clap_score = None
173 |
174 | sample_dir = os.path.join(cfg.output_dir, 'results', f'example_{str(idx+1).zfill(3)}')
175 | os.makedirs(sample_dir, exist_ok=True)
176 |
177 | # import pdb; pdb.set_trace()
178 | # save config with example-specific information
179 | cfg_save_path = os.path.join(sample_dir, 'config.yaml')
180 | current_cfg = copy.deepcopy(cfg)
181 | current_cfg.seed = cfg.seed + idx
182 | with open_dict(current_cfg):
183 | current_cfg.clip_score = clip_score
184 | current_cfg.clap_score = clap_score
185 | OmegaConf.save(current_cfg, cfg_save_path)
186 |
187 | # save image
188 | img_save_path = os.path.join(sample_dir, f'img.png')
189 | save_image(img, img_save_path)
190 |
191 | # save audio
192 | audio_save_path = os.path.join(sample_dir, f'audio.wav')
193 | save_audio(audio, audio_save_path)
194 |
195 | # save spec
196 | spec_save_path = os.path.join(sample_dir, f'spec_ori.png')
197 | save_image(spec, spec_save_path)
198 |
199 | spec_save_path = os.path.join(sample_dir, f'spec.png')
200 | save_image(spec_new, spec_save_path)
201 |
202 | # save spec with colormap (renormalize the spectrogram range)
203 | if use_colormap:
204 | spec_save_path = os.path.join(sample_dir, f'spec_colormap.png')
205 | spec_colormap = spec_new.mean(dim=0).cpu().numpy()
206 | plt.imsave(spec_save_path, spec_colormap, cmap='gray')
207 |
208 |
209 | # save video
210 | video_output_path = os.path.join(sample_dir, f'video.mp4')
211 | if img.shape[-2:] == spec.shape[-2:]:
212 | create_single_image_animation_with_text(spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt)
213 | else:
214 | create_animation_with_text(img_save_path, spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt)
215 | return clip_score, clap_score
216 |
217 |
218 | if __name__ == "__main__":
219 | main()
220 |
--------------------------------------------------------------------------------
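
A quick illustration of the masking step in create_sample above: the grayscale image acts as a soft mask on the spectrogram magnitude, so brighter pixels attenuate the audio more strongly (inverse_image flips this). This is a minimal sketch with toy tensors standing in for the diffusion outputs; the shapes and mag_ratio value are arbitrary.

import torch

image = torch.rand(1, 4, 8)          # stand-in grayscale image in [0, 1]
spec = torch.rand(1, 4, 8)           # stand-in spectrogram in [0, 1], same spatial size

mag_ratio = 0.5                      # how strongly the image darkens the spectrogram
image_mask = 1 - mag_ratio * image   # brighter pixel -> smaller mask value -> more attenuation
spec_new = spec * image_mask         # masked spectrogram, later passed to spec_to_audio

assert spec_new.shape == spec.shape
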
/src/main_sds.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | from tqdm import tqdm
4 | import warnings
5 | import soundfile as sf
6 | import numpy as np
7 | import shutil
8 | import matplotlib.pyplot as plt
9 |
10 | import hydra
11 | from omegaconf import OmegaConf, DictConfig
12 | import torch
13 | import torch.nn as nn
14 | from transformers import logging
15 | from lightning import seed_everything
16 |
17 | from torchvision.utils import save_image
18 |
19 |
20 | import rootutils
21 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
22 |
23 | from src.utils.rich_utils import print_config_tree
24 | from src.utils.animation_with_text import create_animation_with_text, create_single_image_animation_with_text
25 | from src.utils.pylogger import RankedLogger
26 | log = RankedLogger(__name__, rank_zero_only=True)
27 |
28 |
29 | def save_model(output_dir, step, net):
30 | save_dir = os.path.join(output_dir, 'checkpoints')
31 | os.makedirs(save_dir, exist_ok=True)
32 | path = os.path.join(save_dir, 'checkpoint_latest.pth.tar')
33 | torch.save(
34 | {
35 | 'step': step,
36 | 'state_dict': net.state_dict(),
37 | },
38 | path
39 | )
40 |
41 | def save_audio(audio, save_path):
42 | sf.write(save_path, audio, samplerate=16000)
43 |
44 |
45 | def encode_prompt(prompt, diffusion_guidance, device, time_repeat=1):
46 | '''Encode text prompts into embeddings
47 | '''
48 | prompts = [prompt] * time_repeat
49 | null_prompts = [''] * time_repeat
50 |
51 | # Prompts -> text embeds
52 | cond_embeds = diffusion_guidance.get_text_embeds(prompts, device) # [B, 77, 768]
53 | uncond_embeds = diffusion_guidance.get_text_embeds(null_prompts, device) # [B, 77, 768]
54 | text_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0) # [2 * B, 77, 768]
55 | return text_embeds
56 |
57 |
58 |
59 | @hydra.main(version_base="1.3", config_path="../configs/main_sds", config_name="main.yaml")
60 | def main(cfg: DictConfig) -> Optional[float]:
61 | """Main function for training
62 | """
63 |
64 | if cfg.extras.get("ignore_warnings"):
65 | log.info("Disabling python warnings! ")
66 | warnings.filterwarnings("ignore")
67 | logging.set_verbosity_error()
68 |
69 | if cfg.extras.get("print_config"):
70 | print_config_tree(cfg, resolve=True)
71 |
72 | # set seed for random number generators in pytorch, numpy and python.random
73 | if cfg.get("seed"):
74 | seed_everything(cfg.seed, workers=True)
75 |
76 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
77 |
78 | log.info(f"Instantiating Image Learner model <{cfg.image_learner._target_}>")
79 | image_learner = hydra.utils.instantiate(cfg.image_learner).to(device)
80 |
81 | log.info(f"Instantiating Image Diffusion model <{cfg.image_diffusion_guidance._target_}>")
82 | image_diffusion_guidance = hydra.utils.instantiate(cfg.image_diffusion_guidance).to(device)
83 |
84 | log.info(f"Instantiating Audio Diffusion guidance model <{cfg.audio_diffusion_guidance._target_}>")
85 | audio_diffusion_guidance = hydra.utils.instantiate(cfg.audio_diffusion_guidance).to(device)
86 |
87 | # create optimizer
88 | log.info(f"Instantiating optimizer <{cfg.optimizer._target_}>")
89 | optimizer = hydra.utils.instantiate(cfg.optimizer)
90 | optimizer = optimizer(params=image_learner.parameters())
91 |
92 | # create transformation
93 | log.info(f"Instantiating image transformation <{cfg.image_transformation._target_}>")
94 | image_transformation = hydra.utils.instantiate(cfg.image_transformation).to(device)
95 |
96 | log.info(f"Instantiating audio transformation <{cfg.audio_transformation._target_}>")
97 | audio_transformation = hydra.utils.instantiate(cfg.audio_transformation).to(device)
98 |
99 | log.info(f"Starting training!")
100 | trainer(cfg, image_learner, image_diffusion_guidance, audio_diffusion_guidance, optimizer, image_transformation, audio_transformation, device)
101 |
102 |
103 | def trainer(cfg, image_learner, image_diffusion, audio_diffusion, optimizer, image_transformation, audio_transformation, device):
104 | image_guidance_scale, audio_guidance_scale = cfg.trainer.image_guidance_scale, cfg.trainer.audio_guidance_scale
105 | # image_weight, audio_weight = cfg.trainer.image_weight, cfg.trainer.audio_weight
106 | image_start_step = cfg.trainer.get("image_start_step", 0)
107 | audio_start_step = cfg.trainer.get("audio_start_step", 0)
108 |
109 | accumulate_grad_batches = cfg.trainer.get("accumulate_grad_batches", 1)
110 | use_colormap = cfg.trainer.get("use_colormap", False)
111 | crop_image = cfg.trainer.get("crop_image", False)
112 |
113 | image_text_embeds = encode_prompt(cfg.trainer.image_prompt, image_diffusion, device, time_repeat=cfg.trainer.batch_size)
114 | audio_text_embeds = encode_prompt(cfg.trainer.audio_prompt, audio_diffusion, device, time_repeat=1)
115 |
116 | image_learner.train()
117 | for step in tqdm(range(cfg.trainer.num_iteration), desc="Training"):
118 | # import pdb; pdb.set_trace()
119 | image = image_learner() # [C, H, W]
120 | images = image.unsqueeze(0) # (1, C, H, W)
121 |
122 | # perform image guidance
123 | rgb_images = image_transformation(images) # (B, C, h, w)
124 |
125 | # perform audio guidance
126 | spec_images = audio_transformation(images) # (1, 1, H, W)
127 |
128 | if step >= image_start_step:
129 | image_weight = cfg.trainer.image_weight
130 | image_loss = image_diffusion.train_step(image_text_embeds, rgb_images, guidance_scale=image_guidance_scale, grad_scale=1)
131 | else:
132 | image_weight = 0.0
133 | image_loss = torch.tensor(0.0).to(device)
134 |
135 | if step >= audio_start_step:
136 | audio_weight = cfg.trainer.audio_weight
137 | audio_loss = audio_diffusion.train_step(audio_text_embeds, spec_images, guidance_scale=audio_guidance_scale, grad_scale=1)
138 | else:
139 | audio_weight = 0.0
140 | audio_loss = torch.tensor(0.0).to(device)
141 |
142 | loss = image_weight * image_loss + audio_weight * audio_loss
143 | loss = loss / accumulate_grad_batches
144 | loss.backward()
145 |
146 | # apply gradient accumulation
147 | if (step + 1) % accumulate_grad_batches == 0 or (step + 1) == cfg.trainer.num_iteration:
148 | optimizer.step()
149 | optimizer.zero_grad()
150 |
151 | tqdm.write(f"Iteration: {step+1}/{cfg.trainer.num_iteration}, loss: {loss.item():.4f} | visual loss: {image_loss.item():.4f} | audio loss: {audio_loss.item():.4f}")
152 |
153 | if (step + 1) % cfg.trainer.save_step == 0:
154 | save_model(cfg.output_dir, step, image_learner)
155 |
156 | if (step + 1) % cfg.trainer.visualize_step == 0:
157 | img_save_dir = os.path.join(cfg.output_dir, 'image_results')
158 | os.makedirs(img_save_dir, exist_ok=True)
159 | img_save_path = os.path.join(img_save_dir, f'img_{str(step+1).zfill(6)}.png')
160 | save_image(image, img_save_path)
161 |
162 | spec_save_dir = os.path.join(cfg.output_dir, 'spec_results')
163 | os.makedirs(spec_save_dir, exist_ok=True)
164 | spec_save_path = os.path.join(spec_save_dir, f'spec_{str(step+1).zfill(6)}.png')
165 | save_image(spec_images.squeeze(0), spec_save_path)
166 |
167 | audio_save_dir = os.path.join(cfg.output_dir, 'audio_results')
168 | os.makedirs(audio_save_dir, exist_ok=True)
169 | audio_save_path = os.path.join(audio_save_dir, f'audio_{str(step+1).zfill(6)}.wav')
170 |
171 | audio = audio_diffusion.spec_to_audio(spec_images.squeeze(0))
172 | audio = np.ravel(audio)
173 | save_audio(audio, audio_save_path)
174 |
175 | # obtain final results
176 | img = image_learner() # [C, H, W]
177 | spec = audio_transformation(img.unsqueeze(0)).squeeze(0) # (1, H, W)
178 | audio = audio_diffusion.spec_to_audio(spec)
179 | audio = np.ravel(audio)
180 |
181 | if crop_image:
182 | pixel = 32
183 | audio_length = int(pixel / image_learner.width * audio.shape[0])
184 | img = img[..., :-pixel]
185 | spec = spec[..., :-pixel]
186 | audio = audio[:-audio_length]
187 |
188 | # save the final results
189 | sample_dir = os.path.join(cfg.output_dir, 'results', f'final')
190 | os.makedirs(sample_dir, exist_ok=True)
191 |
192 | # save config
193 | cfg_path = os.path.join(cfg.output_dir, '.hydra', 'config.yaml')
194 | cfg_save_path = os.path.join(sample_dir, 'config.yaml')
195 | shutil.copyfile(cfg_path, cfg_save_path)
196 |
197 | # save image
198 | img_save_path = os.path.join(sample_dir, f'img.png')
199 | save_image(img, img_save_path)
200 |
201 | # save audio
202 | audio_save_path = os.path.join(sample_dir, f'audio.wav')
203 | save_audio(audio, audio_save_path)
204 |
205 | # save spec
206 | spec_save_path = os.path.join(sample_dir, f'spec.png')
207 | save_image(spec.mean(dim=0, keepdim=True), spec_save_path)
208 |
209 | if use_colormap:
210 | spec_save_path = os.path.join(sample_dir, f'spec_colormap.png')
211 | spec_colormap = spec.mean(dim=0).detach().cpu().numpy()
212 | plt.imsave(spec_save_path, spec_colormap, cmap='gray')
213 |
214 | # log.info("Generating video ...")
215 | # save video
216 | video_output_path = os.path.join(cfg.output_dir, 'video.mp4')
217 |
218 | if image_learner.num_channels == 1:
219 | create_single_image_animation_with_text(spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt)
220 | else:
221 | create_animation_with_text(img_save_path, spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt)
222 | # log.info("Generated video.")
223 |
224 |
225 | if __name__ == "__main__":
226 | main()
227 |
--------------------------------------------------------------------------------
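
The core of the training loop above is a weighted sum of two score-distillation losses applied to one shared learnable image, with optional gradient accumulation around the optimizer step. The sketch below isolates that update schedule; the quadratic losses and the small parameter tensor are toy stand-ins for the guidance calls and the image learner.

import torch

param = torch.zeros(4, requires_grad=True)       # stands in for the learnable image
optimizer = torch.optim.Adam([param], lr=1e-2)

image_weight, audio_weight = 1.0, 0.5            # relative strength of each guidance
accumulate_grad_batches = 4
num_iteration = 8

for step in range(num_iteration):
    # dummy losses in place of image_diffusion.train_step / audio_diffusion.train_step
    image_loss = (param ** 2).sum()
    audio_loss = (param - 1.0).pow(2).sum()

    loss = image_weight * image_loss + audio_weight * audio_loss
    (loss / accumulate_grad_batches).backward()  # scale so accumulated gradients average out

    if (step + 1) % accumulate_grad_batches == 0 or (step + 1) == num_iteration:
        optimizer.step()
        optimizer.zero_grad()
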
/src/guidance/stable_diffusion.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torchvision.utils import save_image
7 |
8 |
9 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
10 | from diffusers.utils.import_utils import is_xformers_available
11 |
12 | from lightning import seed_everything
13 |
14 |
15 |
16 | class StableDiffusionGuidance(nn.Module):
17 | def __init__(
18 | self,
19 | repo_id='runwayml/stable-diffusion-v1-5',
20 | fp16=True,
21 | t_range=[0.02, 0.98],
22 | t_consistent=False,
23 | **kwargs
24 | ):
25 | super().__init__()
26 |
27 | self.repo_id = repo_id
28 |
29 | self.precision_t = torch.float16 if fp16 else torch.float32
30 |
31 | # Create model
32 | self.vae, self.tokenizer, self.text_encoder, self.unet = self.create_model_from_pipe(repo_id, self.precision_t)
33 |
34 | self.scheduler = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler", torch_dtype=self.precision_t)
35 |
36 | self.register_buffer('alphas_cumprod', self.scheduler.alphas_cumprod)
37 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps
38 | self.min_step = int(self.num_train_timesteps * t_range[0])
39 | self.max_step = int(self.num_train_timesteps * t_range[1])
40 | self.t_consistent = t_consistent
41 |
42 | def create_model_from_pipe(self, repo_id, dtype):
43 | pipe = StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
44 | vae = pipe.vae
45 | tokenizer = pipe.tokenizer
46 | text_encoder = pipe.text_encoder
47 | unet = pipe.unet
48 | return vae, tokenizer, text_encoder, unet
49 |
50 | @torch.no_grad()
51 | def get_text_embeds(self, prompt, device):
52 | # prompt: [str]
53 | # import pdb; pdb.set_trace()
54 | inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
55 | prompt_embeds = self.text_encoder(inputs.input_ids.to(device))[0]
56 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
57 | return prompt_embeds
58 |
59 | def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, t=None, grad_scale=1, save_guidance_path:Path=None):
60 | # import pdb; pdb.set_trace()
61 | pred_rgb = pred_rgb.to(self.vae.dtype)
62 | if as_latent:
63 | # latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
64 | latents = pred_rgb
65 | else:
66 | # interp to 512x512 to be fed into vae.
67 | # pred_rgb_512 = pred_rgb
68 | pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
69 | # encode image into latents with vae, requires grad!
70 | latents = self.encode_imgs(pred_rgb_512)
71 |
72 | if t is None:
73 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
74 | if self.t_consistent:
75 | t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=latents.device)
76 | t = t.repeat(latents.shape[0])
77 | else:
78 | t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=latents.device)
79 | else:
80 | t = t.to(dtype=torch.long, device=latents.device)
81 |
82 | # predict the noise residual with unet, NO grad!
83 | with torch.no_grad():
84 | # add noise
85 | noise = torch.randn_like(latents)
86 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
87 | # pred noise
88 | latent_model_input = torch.cat([latents_noisy] * 2)
89 | tt = torch.cat([t] * 2)
90 | noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample
91 |
92 | # perform guidance (high scale from paper!)
93 | noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
94 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
95 |
96 | # w(t), sigma_t^2
97 | w = (1 - self.alphas_cumprod[t])
98 | grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
99 | grad = torch.nan_to_num(grad)
100 |
101 | if save_guidance_path:
102 | with torch.no_grad():
103 | if as_latent:
104 | pred_rgb_512 = self.decode_latents(latents)
105 |
106 | # visualize predicted denoised image
107 | # The following block of code is equivalent to `predict_start_from_noise`...
108 | # see zero123_utils.py's version for a simpler implementation.
109 | alphas = self.scheduler.alphas.to(latents.device)
110 | total_timesteps = self.max_step - self.min_step + 1
111 | index = total_timesteps - t.to(latents.device) - 1
112 | b = len(noise_pred)
113 | a_t = alphas[index].reshape(b,1,1,1).to(latents.device)
114 | sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
115 | sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(latents.device)
116 | pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
117 | result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))
118 |
119 | # visualize noisier image
120 | result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))
121 |
122 | # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
123 | viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
124 | save_image(viz_images, save_guidance_path)
125 |
126 | targets = (latents - grad).detach()
127 | loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
128 |
129 | return loss
130 |
131 |
132 | @torch.no_grad()
133 | def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, generator=None):
134 |
135 | if latents is None:
136 | latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=self.unet.dtype).to(text_embeddings.device)
137 |
138 | self.scheduler.set_timesteps(num_inference_steps)
139 |
140 | for i, t in enumerate(self.scheduler.timesteps):
141 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
142 | latent_model_input = torch.cat([latents] * 2)
143 | # predict the noise residual
144 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
145 |
146 | # perform guidance
147 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
148 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
149 |
150 | # compute the previous noisy sample x_t -> x_t-1
151 | latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
152 |
153 | return latents
154 |
155 | def decode_latents(self, latents):
156 | latents = latents.to(self.vae.dtype)
157 | latents = 1 / self.vae.config.scaling_factor * latents
158 |
159 | imgs = self.vae.decode(latents).sample
160 | imgs = (imgs / 2 + 0.5).clamp(0, 1)
161 |
162 | return imgs
163 |
164 | def encode_imgs(self, imgs, generator=None):
165 | # imgs: [B, 3, H, W]
166 | imgs = 2 * imgs - 1
167 | posterior = self.vae.encode(imgs).latent_dist
168 | latents = posterior.sample(generator) * self.vae.config.scaling_factor
169 | return latents
170 |
171 | def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, device=None, generator=None):
172 | if isinstance(prompts, str):
173 | prompts = [prompts]
174 |
175 | if isinstance(negative_prompts, str):
176 | negative_prompts = [negative_prompts]
177 |
178 | # Prompts -> text embeds
179 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768]
180 | neg_embeds = self.get_text_embeds(negative_prompts, device)
181 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
182 |
183 | # Text embeds -> img latents
184 | latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator) # [1, 4, 64, 64]
185 |
186 | # Img latents -> imgs
187 | imgs = self.decode_latents(latents) # [1, 3, 512, 512]
188 |
189 | # # Img to Numpy
190 | # imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
191 | # imgs = (imgs * 255).round().astype('uint8')
192 |
193 | return imgs
194 |
195 |
196 | if __name__ == '__main__':
197 |
198 | import argparse
199 | from PIL import Image
200 | import os
201 |
202 | parser = argparse.ArgumentParser()
203 | parser.add_argument('--prompt', type=str, default='')
204 | parser.add_argument('--negative', default='', type=str)
205 | parser.add_argument('--repo_id', type=str, default='runwayml/stable-diffusion-v1-5', help="stable diffusion version")
206 | parser.add_argument('--fp16', action='store_true', help="use float16 for training")
207 | parser.add_argument('--H', type=int, default=512)
208 | parser.add_argument('--W', type=int, default=512)
209 | parser.add_argument('--seed', type=int, default=0)
210 | parser.add_argument('--steps', type=int, default=50)
211 | opt = parser.parse_args()
212 |
213 | seed_everything(opt.seed)
214 |
215 | device = torch.device('cuda')
216 |
217 | sd = StableDiffusionGuidance(repo_id=opt.repo_id, fp16=opt.fp16)
218 | sd = sd.to(device)
219 |
220 | imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, device=device)
221 |
222 | # visualize image
223 | save_path = 'logs/test'
224 | os.makedirs(save_path, exist_ok=True)
225 | imgs = (imgs.detach().cpu().permute(0, 2, 3, 1).numpy() * 255).round().astype('uint8')  # [1, 3, H, W] float in [0, 1] -> [1, H, W, 3] uint8
226 | image = Image.fromarray(imgs[0], mode='RGB')
227 | image.save(os.path.join(save_path, f'{opt.prompt}.png'))
--------------------------------------------------------------------------------
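
train_step above uses the standard score-distillation surrogate: rather than backpropagating through the UNet, it builds an MSE loss against (latents - grad).detach(), whose gradient with respect to the latents is exactly the injected grad divided by the batch size. A tiny self-contained check of that identity, with a random tensor standing in for w(t) * (noise_pred - noise):

import torch
import torch.nn.functional as F

latents = torch.randn(2, 4, 8, 8, requires_grad=True)
grad = torch.randn_like(latents)                 # stands in for w(t) * (noise_pred - noise)

targets = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
loss.backward()

# d(loss)/d(latents) equals grad / batch_size, i.e. the SDS gradient is injected directly
assert torch.allclose(latents.grad, grad / latents.shape[0], atol=1e-5)
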
/src/main_denoise.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | from tqdm import tqdm
4 | import warnings
5 | import soundfile as sf
6 | import numpy as np
7 | import shutil
8 | import glob
9 | import copy
10 | import matplotlib.pyplot as plt
11 |
12 | import hydra
13 | from omegaconf import OmegaConf, DictConfig, open_dict
14 | import torch
15 | import torch.nn as nn
16 | from transformers import logging
17 | from lightning import seed_everything
18 |
19 | from torchvision.utils import save_image
20 |
21 |
22 | import rootutils
23 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
24 |
25 | from src.utils.rich_utils import print_config_tree
26 | from src.utils.animation_with_text import create_animation_with_text, create_single_image_animation_with_text
27 | from src.utils.re_ranking import select_top_k_ranking, select_top_k_clip_ranking
28 | from src.utils.pylogger import RankedLogger
29 | log = RankedLogger(__name__, rank_zero_only=True)
30 |
31 |
32 | def save_audio(audio, save_path):
33 | sf.write(save_path, audio, samplerate=16000)
34 |
35 |
36 | def encode_prompt(prompt, diffusion_guidance, device, negative_prompt='', time_repeat=1):
37 | '''Encode text prompts into embeddings
38 | '''
39 | prompts = [prompt] * time_repeat
40 | negative_prompts = [negative_prompt] * time_repeat
41 |
42 | # Prompts -> text embeds
43 | cond_embeds = diffusion_guidance.get_text_embeds(prompts, device) # [B, 77, 768]
44 | uncond_embeds = diffusion_guidance.get_text_embeds(negative_prompts, device) # [B, 77, 768]
45 | text_embeds = torch.cat([uncond_embeds, cond_embeds], dim=0) # [2 * B, 77, 768]
46 | return text_embeds
47 |
48 | def estimate_noise(diffusion, latents, t, text_embeddings, guidance_scale):
49 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
50 | latent_model_input = torch.cat([latents] * 2)
51 | # predict the noise residual
52 | noise_pred = diffusion.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
53 |
54 | # perform guidance
55 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
56 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
57 | return noise_pred
58 |
59 |
60 | @hydra.main(version_base="1.3", config_path="../configs/main_denoise", config_name="main.yaml")
61 | def main(cfg: DictConfig) -> Optional[float]:
62 | """Main function for training
63 | """
64 |
65 | if cfg.extras.get("ignore_warnings"):
66 | log.info("Disabling python warnings! ")
67 | warnings.filterwarnings("ignore")
68 | logging.set_verbosity_error()
69 |
70 | if cfg.extras.get("print_config"):
71 | print_config_tree(cfg, resolve=True)
72 |
73 | # set seed for random number generators in pytorch, numpy and python.random
74 | if cfg.get("seed"):
75 | seed_everything(cfg.seed, workers=True)
76 |
77 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
78 |
79 | log.info(f"Instantiating Image Diffusion model <{cfg.image_diffusion_guidance._target_}>")
80 | image_diffusion_guidance = hydra.utils.instantiate(cfg.image_diffusion_guidance).to(device)
81 |
82 | log.info(f"Instantiating Audio Diffusion guidance model <{cfg.audio_diffusion_guidance._target_}>")
83 | audio_diffusion_guidance = hydra.utils.instantiate(cfg.audio_diffusion_guidance).to(device)
84 |
85 | # create the shared noise scheduler
86 | log.info(f"Instantiating joint diffusion scheduler <{cfg.diffusion_scheduler._target_}>")
87 | scheduler = hydra.utils.instantiate(cfg.diffusion_scheduler)
88 |
89 | # create transformation
90 | log.info(f"Instantiating latent transformation <{cfg.latent_transformation._target_}>")
91 | latent_transformation = hydra.utils.instantiate(cfg.latent_transformation).to(device)
92 |
93 | # create audio evaluator
94 | if cfg.audio_evaluator:
95 | log.info(f"Instantiating audio evaluator <{cfg.audio_evaluator._target_}>")
96 | audio_evaluator = hydra.utils.instantiate(cfg.audio_evaluator).to(device)
97 | else:
98 | audio_evaluator = None
99 |
100 | if cfg.visual_evaluator:
101 | log.info(f"Instantiating visual evaluator <{cfg.visual_evaluator._target_}>")
102 | visual_evaluator = hydra.utils.instantiate(cfg.visual_evaluator).to(device)
103 | else:
104 | visual_evaluator = None
105 |
106 | log.info(f"Starting sampling!")
107 | clip_scores = []
108 | clap_scores = []
109 |
110 | for idx in tqdm(range(cfg.trainer.num_samples), desc='Sampling'):
111 | clip_score, clap_score = denoise(cfg, image_diffusion_guidance, audio_diffusion_guidance, scheduler, latent_transformation, visual_evaluator, audio_evaluator, idx, device)
112 | clip_scores.append(clip_score)
113 | clap_scores.append(clap_score)
114 |
115 | # re-ranking by metrics
116 | enable_rank = cfg.trainer.get("enable_rank", False)
117 | if enable_rank:
118 | log.info(f"Starting re-ranking and selection!")
119 | select_top_k_ranking(cfg, clip_scores, clap_scores)
120 |
121 | enable_clip_rank = cfg.trainer.get("enable_clip_rank", False)
122 | if enable_clip_rank:
123 | log.info(f"Starting re-ranking and selection by CLIP score!")
124 | select_top_k_clip_ranking(cfg, clip_scores)
125 |
126 | log.info(f"Finished!")
127 |
128 |
129 | @torch.no_grad()
130 | def denoise(cfg, image_diffusion, audio_diffusion, scheduler, latent_transformation, visual_evaluator, audio_evaluator, idx, device):
131 | image_guidance_scale, audio_guidance_scale = cfg.trainer.image_guidance_scale, cfg.trainer.audio_guidance_scale
132 | height, width = cfg.trainer.img_height, cfg.trainer.img_width
133 | image_start_step = cfg.trainer.get("image_start_step", 0)
134 | audio_start_step = cfg.trainer.get("audio_start_step", 0)
135 | audio_weight = cfg.trainer.get("audio_weight", 0.5)
136 | use_colormap = cfg.trainer.get("use_colormap", False)
137 |
138 | cutoff_latent = cfg.trainer.get("cutoff_latent", False)
139 | crop_image = cfg.trainer.get("crop_image", False)
140 |
141 | generator = torch.manual_seed(cfg.seed + idx)
142 |
143 | # obtain the text embeddings for each modality's diffusion process
144 | image_text_embeds = encode_prompt(cfg.trainer.image_prompt, image_diffusion, device, negative_prompt=cfg.trainer.image_neg_prompt, time_repeat=1)
145 | audio_text_embeds = encode_prompt(cfg.trainer.audio_prompt, audio_diffusion, device, negative_prompt=cfg.trainer.audio_neg_prompt, time_repeat=1)
146 |
147 | scheduler.set_timesteps(cfg.trainer.num_inference_steps)
148 |
149 | # init random latents
150 | latents = torch.randn((image_text_embeds.shape[0] // 2, image_diffusion.unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=image_diffusion.precision_t).to(device)
151 |
152 | for i, t in enumerate(scheduler.timesteps):
153 | if i >= image_start_step:
154 | image_noise = estimate_noise(image_diffusion, latents, t, image_text_embeds, image_guidance_scale)
155 | else:
156 | image_noise = None
157 |
158 | if i >= audio_start_step:
159 | transform_latents = latent_transformation(latents, inverse=False)
160 | audio_noise = estimate_noise(audio_diffusion, transform_latents, t, audio_text_embeds, audio_guidance_scale)
161 | audio_noise = latent_transformation(audio_noise, inverse=True)
162 | else:
163 | audio_noise = None
164 |
165 | if image_noise is not None and audio_noise is not None:
166 | noise_pred = (1.0 - audio_weight) * image_noise + audio_weight * audio_noise
167 | elif image_noise is not None and audio_noise is None:
168 | noise_pred = image_noise
169 | elif image_noise is None and audio_noise is not None:
170 | noise_pred = audio_noise
171 | else:
172 | log.info("No estimated noise! Exit.")
173 | raise NotImplementedError
174 |
175 | # compute the previous noisy sample x_t -> x_t-1
176 | latents = scheduler.step(noise_pred, t, latents)['prev_sample']
177 |
178 | if cutoff_latent and not crop_image:
179 | latents = latents[..., :-4] # drop the last 4 latent columns (32 pixels after decoding) so the black region is removed directly in latent space
180 |
181 | # Img latents -> imgs
182 | img = image_diffusion.decode_latents(latents) # [1, 3, H, W]
183 |
184 | # Img latents -> audio
185 | audio_latents = latent_transformation(latents, inverse=False)
186 | spec = audio_diffusion.decode_latents(audio_latents).squeeze(0) # [3, 256, 1024]
187 | audio = audio_diffusion.spec_to_audio(spec)
188 | audio = np.ravel(audio)
189 |
190 | if crop_image and not cutoff_latent:
191 | pixel = 32
192 | audio_length = int(pixel / width * audio.shape[0])
193 | img = img[..., :-pixel]
194 | spec = spec[..., :-pixel]
195 | audio = audio[:-audio_length]
196 |
197 | # evaluate with CLIP
198 | if visual_evaluator is not None:
199 | clip_score = visual_evaluator(img, cfg.trainer.image_prompt)
200 | else:
201 | clip_score = None
202 |
203 | # evaluate with CLAP
204 | if audio_evaluator is not None:
205 | clap_score = audio_evaluator(cfg.trainer.audio_prompt, audio)
206 | else:
207 | clap_score = None
208 |
209 | sample_dir = os.path.join(cfg.output_dir, 'results', f'example_{str(idx+1).zfill(3)}')
210 | os.makedirs(sample_dir, exist_ok=True)
211 |
212 | # save config with example-specific information
213 | cfg_save_path = os.path.join(sample_dir, 'config.yaml')
214 | current_cfg = copy.deepcopy(cfg)
215 | current_cfg.seed = cfg.seed + idx
216 | with open_dict(current_cfg):
217 | current_cfg.clip_score = clip_score
218 | current_cfg.clap_score = clap_score
219 | OmegaConf.save(current_cfg, cfg_save_path)
220 |
221 | # save image
222 | img_save_path = os.path.join(sample_dir, f'img.png')
223 | save_image(img, img_save_path)
224 |
225 | # save audio
226 | audio_save_path = os.path.join(sample_dir, f'audio.wav')
227 | save_audio(audio, audio_save_path)
228 |
229 | # save spec
230 | spec_save_path = os.path.join(sample_dir, f'spec.png')
231 | save_image(spec.mean(dim=0, keepdim=True), spec_save_path)
232 |
233 | # save spec with colormap (renormalize the spectrogram range)
234 | if use_colormap:
235 | spec_save_path = os.path.join(sample_dir, f'spec_colormap.png')
236 | spec_colormap = spec.mean(dim=0).cpu().numpy()
237 | plt.imsave(spec_save_path, spec_colormap, cmap='gray')
238 |
239 | # save video
240 | video_output_path = os.path.join(sample_dir, f'video.mp4')
241 | if img.shape[-2:] == spec.shape[-2:]:
242 | create_single_image_animation_with_text(spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt)
243 | else:
244 | create_animation_with_text(img_save_path, spec_save_path, audio_save_path, video_output_path, cfg.trainer.image_prompt, cfg.trainer.audio_prompt)
245 |
246 | return clip_score, clap_score
247 |
248 |
249 | if __name__ == "__main__":
250 | main()
251 |
--------------------------------------------------------------------------------
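
main_denoise runs a single reverse diffusion on shared latents, blending the image model's classifier-free-guided noise estimate with the audio model's estimate mapped back through the latent transformation. A minimal sketch of that blend with stand-in noise predictions and a default DDIMScheduler; the real scheduler and transformation are instantiated from the Hydra config.

import torch
from diffusers import DDIMScheduler

audio_weight = 0.5
scheduler = DDIMScheduler()                      # stand-in; the script uses cfg.diffusion_scheduler
scheduler.set_timesteps(10)

latents = torch.randn(1, 4, 32, 64)
for t in scheduler.timesteps:
    # stand-ins for estimate_noise(...); the audio branch would additionally pass through
    # latent_transformation(latents) before the UNet and its inverse afterwards
    image_noise = torch.randn_like(latents)
    audio_noise = torch.randn_like(latents)

    noise_pred = (1.0 - audio_weight) * image_noise + audio_weight * audio_noise
    latents = scheduler.step(noise_pred, t, latents)['prev_sample']
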
/src/utils/animation_with_text.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import os
4 | import soundfile as sf
5 | import subprocess
6 | from omegaconf import OmegaConf, DictConfig
7 |
8 |
9 | from moviepy.editor import VideoClip, ImageClip, TextClip
10 |
11 | # import pdb; pdb.set_trace()
12 |
13 |
14 | # Run the ffmpeg command to get the path
15 | ffmpeg_path = subprocess.check_output(['which', 'ffmpeg']).decode().strip()
16 | # Set the environment variable
17 | os.environ["IMAGEIO_FFMPEG_EXE"] = ffmpeg_path
18 | # Optionally, print the path for verification
19 | print("IMAGEIO_FFMPEG_EXE set to:", os.environ["IMAGEIO_FFMPEG_EXE"])
20 |
21 |
22 | # import moviepy.config_defaults
23 | # magick_path = subprocess.check_output(['which', 'magick']).decode().strip()
24 | # moviepy.config_defaults.IMAGEMAGICK_BINARY = magick_path
25 | # print("IMAGEMAGICK_BINARY Path set to:", moviepy.config_defaults.IMAGEMAGICK_BINARY)
26 |
27 |
28 |
29 |
30 | def create_animation_with_text(image_path1, image_path2, audio_path, output_path, image_prompt, audio_prompt):
31 | # set up video hyperparameter
32 | space_between_images = 30
33 | padding = 40
34 | space_between_text_image = 8
35 | fontsize = 26
36 | font = 'Ubuntu-Mono'
37 | text_method = 'caption'
38 |
39 | # Load images
40 | image1 = np.array(Image.open(image_path1).convert('RGB'))
41 | image2 = np.array(Image.open(image_path2).convert('RGB'))
42 | image1_height, image1_width, _ = image1.shape
43 | image2_height, image2_width, _ = image2.shape
44 |
45 | max_image_width = max(image1_width, image2_width)
46 | # Add text prompts
47 | text_size = [max_image_width, None]
48 | image_prompt = TextClip('Image prompt: ' + image_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size)
49 | audio_prompt = TextClip('Audio prompt: ' + audio_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size)
50 | image_text_height = np.array(image_prompt.get_frame(0)).shape[0]
51 | audio_text_height = np.array(audio_prompt.get_frame(0)).shape[0]
52 | # Calculate total height for the video
53 | total_height = image1_height + image2_height + space_between_images + 2 * padding + image_text_height + audio_text_height + 2 * space_between_text_image
54 |
55 | # Create a light-gray slider bar that sweeps across both images
56 | slider_width = 5 # bar width in pixels
57 | slider_height = image1_height + image2_height + space_between_images # spans both images
58 | slider_color = (240, 240, 240) # light gray
59 |
60 | # Create slider without shadow
61 | slider = np.zeros((slider_height, slider_width, 3), dtype=np.uint8)
62 | slider[:, :] = slider_color # Main slider area
63 |
64 | # import pdb; pdb.set_trace()
65 | # Load audio using soundfile
66 | audio_data, sample_rate = sf.read(audio_path)
67 |
68 | # Calculate video duration based on audio length
69 | video_duration = len(audio_data) / sample_rate
70 |
71 | # Define video dimensions
72 | video_width = max_image_width + 2 * padding
73 |
74 | # Function to generate frame at time t
75 | def make_frame(t):
76 | # Calculate slider position
77 | # import pdb; pdb.set_trace()
78 | slider_position = int(t * (max_image_width - slider_width//2) / video_duration)
79 |
80 | # Create a white blank frame
81 | frame = np.ones((total_height, video_width, 3), dtype=np.uint8) * 255
82 |
83 | # Calculate positions for image prompt and add it
84 | image_text = np.array(image_prompt.get_frame(t))
85 | image_prompt_start_pos = (padding, (video_width - image_text.shape[1]) // 2 )
86 | image_prompt_end_pos = (padding + image_text.shape[0], image_prompt_start_pos[1] + image_text.shape[1])
87 | frame[image_prompt_start_pos[0]: image_prompt_end_pos[0], image_prompt_start_pos[1]: image_prompt_end_pos[1]] = image_text
88 |
89 | # Add images to the frame
90 | image1_start_pos = (image_prompt_end_pos[0] + space_between_text_image, padding)
91 | frame[image1_start_pos[0]: (image1_start_pos[0] + image1_height), image1_start_pos[1]:(image1_start_pos[1] + image1_width)] = image1
92 |
93 | image2_start_pos = (image1_start_pos[0] + image1_height + space_between_images, padding)
94 | frame[image2_start_pos[0]: (image2_start_pos[0] + image2_height), image2_start_pos[1]:(image2_start_pos[1] + image2_width)] = image2
95 |
96 | # Calculate positions for audio prompt and add it
97 | audio_text = np.array(audio_prompt.get_frame(t))
98 | audio_prompt_start_pos = (image2_start_pos[0] + image2_height + space_between_text_image, (video_width - audio_text.shape[1]) // 2)
99 | audio_prompt_end_pos = (audio_prompt_start_pos[0] + audio_text.shape[0], audio_prompt_start_pos[1] + audio_text.shape[1])
100 | frame[audio_prompt_start_pos[0]: audio_prompt_end_pos[0], audio_prompt_start_pos[1]: audio_prompt_end_pos[1]] = audio_text
101 |
102 | # Add slider to the frame
103 | frame[image1_start_pos[0]:(slider_height+image1_start_pos[0]), (padding+slider_position):(padding+slider_position+slider_width)] = slider
104 |
105 | return frame
106 |
107 | # Create a VideoClip
108 | video_clip = VideoClip(make_frame, duration=video_duration)
109 |
110 | # Write the final video
111 | temp_path = output_path[:-4] + '-temp.mp4'
112 | video_clip.write_videofile(temp_path, codec='libx264', fps=60, logger=None)
113 |
114 | # The video above was written without sound; ffmpeg now muxes the original WAV in,
115 | # copying the video stream and encoding the audio once (to AAC) so its quality is preserved.
116 | os.system(f"ffmpeg -v quiet -y -i \"{temp_path}\" -i {audio_path} -c:v copy -c:a aac {output_path}")
117 | os.system(f"rm {temp_path}")
118 |
119 |
120 | def create_single_image_animation_with_text(image_path, audio_path, output_path, image_prompt, audio_prompt):
121 | # set up video hyperparameter
122 | padding = 40
123 | space_between_text_image = 8
124 | fontsize = 26
125 | font = 'Ubuntu-Mono'
126 | text_method = 'caption'
127 |
128 | # Load images
129 | image = np.array(Image.open(image_path).convert('RGB'))
130 | image_height, image_width, _ = image.shape
131 |
132 | max_image_width = image_width
133 |
134 | # Add text prompts
135 | text_size = [max_image_width, None]
136 | image_prompt = TextClip('Image prompt: ' + image_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size)
137 | audio_prompt = TextClip('Audio prompt: ' + audio_prompt, font=font, fontsize=fontsize, bg_color='white', color='black', method=text_method, size=text_size)
138 | image_text_height = np.array(image_prompt.get_frame(0)).shape[0]
139 | audio_text_height = np.array(audio_prompt.get_frame(0)).shape[0]
140 |
141 | # Calculate total height for the video
142 | total_height = image_height + 2 * padding + image_text_height + audio_text_height + 2 * space_between_text_image
143 |
144 | # Create a light-gray slider bar that sweeps across the image
145 | slider_width = 5 # bar width in pixels
146 | slider_height = image_height # spans the full image height
147 | slider_color = (240, 240, 240) # light gray
148 | border_color = (200, 200, 200) # gray
149 |
150 |
151 | # Create slider without shadow
152 | slider = np.zeros((slider_height, slider_width, 3), dtype=np.uint8)
153 | slider[:, :] = slider_color # Main slider area
154 |
155 | # import pdb; pdb.set_trace()
156 | # Load audio using soundfile
157 | audio_data, sample_rate = sf.read(audio_path)
158 |
159 | # Calculate video duration based on audio length
160 | video_duration = len(audio_data) / sample_rate
161 |
162 | # Define video dimensions
163 | video_width = max_image_width + 2 * padding
164 |
165 | # Function to generate frame at time t
166 | def make_frame(t):
167 | # Calculate slider position
168 | # import pdb; pdb.set_trace()
169 | slider_position = int(t * (max_image_width - slider_width // 2) / video_duration)
170 |
171 | # Create a white blank frame
172 | frame = np.ones((total_height, video_width, 3), dtype=np.uint8) * 255
173 |
174 | # Calculate positions for image prompt and add it
175 | image_text = np.array(image_prompt.get_frame(t))
176 | image_prompt_start_pos = (padding, (video_width - image_text.shape[1]) // 2 )
177 | image_prompt_end_pos = (padding + image_text.shape[0], image_prompt_start_pos[1] + image_text.shape[1])
178 | frame[image_prompt_start_pos[0]: image_prompt_end_pos[0], image_prompt_start_pos[1]: image_prompt_end_pos[1]] = image_text
179 |
180 | # Add image to the frame
181 | image_start_pos = (image_prompt_end_pos[0] + space_between_text_image, padding)
182 | frame[image_start_pos[0]: (image_start_pos[0] + image_height), image_start_pos[1]:(image_start_pos[1] + image_width)] = image
183 |
184 | # Calculate positions for audio prompt and add it
185 | audio_text = np.array(audio_prompt.get_frame(t))
186 | audio_prompt_start_pos = (image_start_pos[0] + image_height + space_between_text_image, (video_width - audio_text.shape[1]) // 2)
187 | audio_prompt_end_pos = (audio_prompt_start_pos[0] + audio_text.shape[0], audio_prompt_start_pos[1] + audio_text.shape[1])
188 | frame[audio_prompt_start_pos[0]: audio_prompt_end_pos[0], audio_prompt_start_pos[1]: audio_prompt_end_pos[1]] = audio_text
189 |
190 | # Add slider to the frame
191 | frame[image_start_pos[0]:(slider_height+image_start_pos[0]), (padding+slider_position):(padding+slider_position+slider_width)] = slider
192 |
193 | return frame
194 |
195 | # Create a VideoClip
196 | video_clip = VideoClip(make_frame, duration=video_duration)
197 |
198 | # Write the final video
199 | temp_path = output_path[:-4] + '-temp.mp4'
200 | video_clip.write_videofile(temp_path, codec='libx264', fps=60, logger=None)
201 |
202 | # The video above was written without sound; ffmpeg now muxes the original WAV in,
203 | # copying the video stream and encoding the audio once (to AAC) so its quality is preserved.
204 | # os.system(f"ffmpeg -v quiet -y -i \"{temp_path}\" -i {audio_path} -c:v copy -c:a copy {output_path}")
205 | os.system(f"ffmpeg -v quiet -y -i \"{temp_path}\" -i {audio_path} -c:v copy -c:a aac {output_path}")
206 |
207 | os.system(f"rm {temp_path}")
208 |
209 |
210 | # Example usage:
211 | if __name__ == '__main__':
212 | # rgb = '/home/czyang/Workspace/images-that-sound/logs/soundify/kitten/image_results/img_030000.png'
213 | # spec = '/home/czyang/Workspace/images-that-sound/logs/soundify/kitten/spec_results/spec_030000.png'
214 | # audio = '/home/czyang/Workspace/images-that-sound/logs/soundify/kitten/audio_results/audio_030000.wav'
215 | # image_prompt = 'an oil paint of playground with cats chasing and playing'
216 | # audio_prompt = 'A kitten mewing for attention'
217 | # audio = 'audio_050000.wav'
218 |
219 | # example_path = '/home/czyang/Workspace/images-that-sound/logs/soundify-denoise/good-examples/example_02'
220 | example_path = '/home/czyang/Workspace/images-that-sound/logs/soundify-denoise/colorization/tiger_example_06'
221 |
222 | rgb = f'{example_path}/img_rgb.png'
223 | spec = f'{example_path}/spec.png'
224 | audio = f'{example_path}/audio.wav'
225 | config_path = f'{example_path}/config.yaml'
226 | cfg = OmegaConf.load(config_path)
227 | # image_prompt = cfg.trainer.colorization_prompt
228 | image_prompt = cfg.trainer.image_prompt
229 | audio_prompt = cfg.trainer.audio_prompt
230 | # output_path = f'{example_path}/video_rgb.mp4'
231 | output_path = f'test.mp4'
232 |
233 | # create_animation_with_text(rgb, spec, audio, output_path, image_prompt, audio_prompt)
234 | create_single_image_animation_with_text(spec, audio, output_path, image_prompt, audio_prompt)
235 |
--------------------------------------------------------------------------------
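
Both helpers above write a silent video with moviepy and then mux the original WAV in with ffmpeg, copying the video stream and encoding the audio only once so its quality is preserved. The same step expressed with subprocess.run, for reference; the paths are placeholders.

import subprocess

def mux_audio(silent_video, audio_wav, output_mp4):
    # copy the video stream untouched, encode the audio once from the original WAV
    subprocess.run(
        ["ffmpeg", "-v", "quiet", "-y", "-i", silent_video, "-i", audio_wav,
         "-c:v", "copy", "-c:a", "aac", output_mp4],
        check=True,
    )

# mux_audio("video-temp.mp4", "audio.wav", "video.mp4")
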
/src/guidance/auffusion.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torchvision.utils import save_image
7 |
8 |
9 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
10 | from diffusers.utils.import_utils import is_xformers_available
11 |
12 | # from .perpneg_utils import weighted_perpendicular_aggregator
13 |
14 | from lightning import seed_everything
15 |
16 | import rootutils
17 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
18 |
19 | from src.models.components.auffusion_converter import Generator, denormalize_spectrogram
20 |
21 |
22 |
23 | class AuffusionGuidance(nn.Module):
24 | def __init__(
25 | self,
26 | repo_id='auffusion/auffusion-full-no-adapter',
27 | fp16=True,
28 | t_range=[0.02, 0.98],
29 | **kwargs
30 | ):
31 | super().__init__()
32 |
33 | self.repo_id = repo_id
34 |
35 | self.precision_t = torch.float16 if fp16 else torch.float32
36 |
37 | # Create model
38 | self.vae, self.tokenizer, self.text_encoder, self.unet = self.create_model_from_pipe(repo_id, self.precision_t)
39 | self.scheduler = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler", torch_dtype=self.precision_t)
40 | self.vocoder = Generator.from_pretrained(repo_id, subfolder="vocoder").to(dtype=self.precision_t)
41 |
42 | self.register_buffer('alphas_cumprod', self.scheduler.alphas_cumprod)
43 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps
44 | self.min_step = int(self.num_train_timesteps * t_range[0])
45 | self.max_step = int(self.num_train_timesteps * t_range[1])
46 |
47 | def create_model_from_pipe(self, repo_id, dtype):
48 | pipe = StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
49 | vae = pipe.vae
50 | tokenizer = pipe.tokenizer
51 | text_encoder = pipe.text_encoder
52 | unet = pipe.unet
53 | return vae, tokenizer, text_encoder, unet
54 |
55 | @torch.no_grad()
56 | def get_text_embeds(self, prompt, device):
57 | # prompt: [str]
58 | # import pdb; pdb.set_trace()
59 | inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
60 | prompt_embeds = self.text_encoder(inputs.input_ids.to(device))[0]
61 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
62 | return prompt_embeds
63 |
64 | def train_step(self, text_embeddings, pred_spec, guidance_scale=100, as_latent=False, t=None, grad_scale=1, save_guidance_path:Path=None):
65 | # import pdb; pdb.set_trace()
66 | pred_spec = pred_spec.to(self.vae.dtype)
67 |
68 | if as_latent:
69 | latents = pred_spec
70 | else:
71 | if pred_spec.shape[1] != 3:
72 | pred_spec = pred_spec.repeat(1, 3, 1, 1)
73 |
74 | # encode image into latents with vae, requires grad!
75 | latents = self.encode_imgs(pred_spec)
76 |
77 | if t is None:
78 | # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
79 | t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=latents.device)
80 | else:
81 | t = t.to(dtype=torch.long, device=latents.device)
82 |
83 | # predict the noise residual with unet, NO grad!
84 | with torch.no_grad():
85 | # add noise
86 | noise = torch.randn_like(latents)
87 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
88 | # pred noise
89 | latent_model_input = torch.cat([latents_noisy] * 2)
90 | tt = torch.cat([t] * 2)
91 | noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample
92 |
93 | # perform guidance (high scale from paper!)
94 | noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
95 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
96 |
97 | # w(t), sigma_t^2
98 | w = (1 - self.alphas_cumprod[t])
99 | grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
100 | grad = torch.nan_to_num(grad)
101 |
102 | if save_guidance_path:
103 | with torch.no_grad():
104 | if as_latent:
105 | pred_rgb_512 = self.decode_latents(latents)
106 |
107 | # visualize predicted denoised image
108 | # The following block of code is equivalent to `predict_start_from_noise`...
109 | # see zero123_utils.py's version for a simpler implementation.
110 | alphas = self.scheduler.alphas.to(latents.device)
111 | total_timesteps = self.max_step - self.min_step + 1
112 | index = total_timesteps - t.to(latents.device) - 1
113 | b = len(noise_pred)
114 | a_t = alphas[index].reshape(b,1,1,1).to(latents.device)
115 | sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
116 | sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(latents.device)
117 | pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
118 | result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))
119 |
120 | # visualize noisier image
121 | result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))
122 |
123 | # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
124 | viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
125 | save_image(viz_images, save_guidance_path)
126 |
127 | targets = (latents - grad).detach()
128 | loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
129 |
130 | return loss
131 |
132 |
133 | @torch.no_grad()
134 | def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, generator=None):
135 |
136 | if latents is None:
137 | latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=self.unet.dtype).to(text_embeddings.device)
138 |
139 | self.scheduler.set_timesteps(num_inference_steps)
140 |
141 | for i, t in enumerate(self.scheduler.timesteps):
142 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
143 | latent_model_input = torch.cat([latents] * 2)
144 | # predict the noise residual
145 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
146 |
147 | # perform guidance
148 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
149 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
150 |
151 | # compute the previous noisy sample x_t -> x_t-1
152 | latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
153 |
154 | return latents
155 |
156 | def decode_latents(self, latents):
157 | latents = latents.to(self.vae.dtype)
158 | latents = 1 / self.vae.config.scaling_factor * latents
159 |
160 | imgs = self.vae.decode(latents).sample
161 | imgs = (imgs / 2 + 0.5).clamp(0, 1)
162 |
163 | return imgs
164 |
165 | def encode_imgs(self, imgs):
166 | # imgs: [B, 3, H, W]
167 | imgs = 2 * imgs - 1
168 | posterior = self.vae.encode(imgs).latent_dist
169 | latents = posterior.sample() * self.vae.config.scaling_factor
170 |
171 | return latents
172 |
173 | def spec_to_audio(self, spec):
174 | spec = spec.to(dtype=self.precision_t)
175 | denorm_spec = denormalize_spectrogram(spec)
176 | audio = self.vocoder.inference(denorm_spec)
177 | return audio
178 |
179 | def prompt_to_audio(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, device=None, generator=None):
180 | if isinstance(prompts, str):
181 | prompts = [prompts]
182 |
183 | if isinstance(negative_prompts, str):
184 | negative_prompts = [negative_prompts]
185 |
186 | # Prompts -> text embeds
187 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768]
188 | neg_embeds = self.get_text_embeds(negative_prompts, device)
189 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
190 |
191 | # Text embeds -> img latents
192 | latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator) # [1, 4, 64, 64]
193 |
194 | # Img latents -> imgs
195 | imgs = self.decode_latents(latents) # [1, 3, 512, 512]
196 | spec = imgs[0]
197 | denorm_spec = denormalize_spectrogram(spec)
198 | audio = self.vocoder.inference(denorm_spec)
199 | # # Img to Numpy
200 | # imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
201 | # imgs = (imgs * 255).round().astype('uint8')
202 |
203 | return audio
204 |
205 | def prompt_to_spec(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, device=None, generator=None):
206 | if isinstance(prompts, str):
207 | prompts = [prompts]
208 |
209 | if isinstance(negative_prompts, str):
210 | negative_prompts = [negative_prompts]
211 |
212 | # Prompts -> text embeds
213 | pos_embeds = self.get_text_embeds(prompts, device) # [1, 77, 768]
214 | neg_embeds = self.get_text_embeds(negative_prompts, device)
215 | text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
216 |
217 | # Text embeds -> img latents
218 | latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator) # [1, 4, 64, 64]
219 |
220 | # Img latents -> imgs
221 | imgs = self.decode_latents(latents) # [1, 3, 512, 512]
222 | return imgs
223 |
224 | if __name__ == '__main__':
225 | import numpy as np
226 | import argparse
227 | from PIL import Image
228 | import os
229 | import soundfile as sf
230 |
231 | parser = argparse.ArgumentParser()
232 | parser.add_argument('--prompt', type=str, default='A kitten mewing for attention')
233 | parser.add_argument('--negative', default='', type=str)
234 | parser.add_argument('--repo_id', type=str, default='auffusion/auffusion-full-no-adapter', help="Auffusion checkpoint repo id")
235 | parser.add_argument('--fp16', action='store_true', help="use float16 for inference")
236 | parser.add_argument('--H', type=int, default=256)
237 | parser.add_argument('--W', type=int, default=1024)
238 | parser.add_argument('--seed', type=int, default=0)
239 | parser.add_argument('--steps', type=int, default=100)
240 | parser.add_argument('--out_dir', type=str, default='logs/test')
241 |
242 | opt = parser.parse_args()
243 |
244 | seed_everything(opt.seed)
245 |
246 | device = torch.device('cuda')
247 |
248 | sd = AuffusionGuidance(repo_id=opt.repo_id, fp16=opt.fp16)
249 | sd = sd.to(device)
250 |
251 | audio = sd.prompt_to_audio(opt.prompt, opt.negative, opt.H, opt.W, opt.steps, device=device)
252 | # import pdb; pdb.set_trace()
253 | # visualize audio
254 | save_folder = opt.out_dir
255 | os.makedirs(save_folder, exist_ok=True)
256 | save_path = os.path.join(save_folder, f'{opt.prompt}.wav')
257 | sf.write(save_path, np.ravel(audio), samplerate=16000)
--------------------------------------------------------------------------------
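
AuffusionGuidance treats the spectrogram as an ordinary image throughout the diffusion and only converts it to a waveform at the end, via denormalize_spectrogram and the bundled vocoder. A minimal usage sketch under the assumptions that a CUDA device is available and the repository root is on PYTHONPATH; the output path is an arbitrary example and the prompt is the script's default.

import numpy as np
import soundfile as sf
import torch

from src.guidance.auffusion import AuffusionGuidance

device = torch.device('cuda')
guidance = AuffusionGuidance(fp16=True).to(device)

# prompt -> 256x1024 spectrogram image -> 16 kHz waveform (matching save_audio elsewhere)
spec = guidance.prompt_to_spec('A kitten mewing for attention', height=256, width=1024,
                               num_inference_steps=100, device=device)
audio = guidance.spec_to_audio(spec.squeeze(0))
sf.write('kitten.wav', np.ravel(audio), samplerate=16000)
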
/src/colorization/samplers.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | from PIL import Image
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision.transforms.functional as TF
8 |
9 | from diffusers.utils.torch_utils import randn_tensor
10 |
11 | @torch.no_grad()
12 | def sample_stage_1(
13 | model,
14 | prompt_embeds,
15 | negative_prompt_embeds,
16 | views,
17 | height=None,
18 | width=None,
19 | fixed_im=None,
20 | num_inference_steps=100,
21 | guidance_scale=7.0,
22 | reduction='mean',
23 | generator=None,
24 | num_recurse=1,
25 | start_diffusion_step=0
26 | ):
27 |
28 | # Params
29 | num_images_per_prompt = 1
30 | device = torch.device('cuda') # force CUDA; the pipeline's reported device is sometimes cpu
31 | height = model.unet.config.sample_size if height is None else height
32 | width = model.unet.config.sample_size if width is None else width
33 | batch_size = 1 # TODO: Support larger batch sizes, maybe
34 | num_prompts = prompt_embeds.shape[0]
35 | assert num_prompts == len(views), \
36 | "Number of prompts must match number of views!"
37 |
38 | # Resize image to correct size
39 | if fixed_im is not None:
40 | fixed_im = TF.resize(fixed_im, height, antialias=False)
41 |
42 | # For CFG
43 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
44 |
45 | # Setup timesteps
46 | model.scheduler.set_timesteps(num_inference_steps, device=device)
47 | timesteps = model.scheduler.timesteps
48 |
49 | # Make intermediate_images
50 | noisy_images = model.prepare_intermediate_images(
51 | batch_size * num_images_per_prompt,
52 | model.unet.config.in_channels,
53 | height,
54 | width,
55 | prompt_embeds.dtype,
56 | device,
57 | generator,
58 | )
59 |
60 | for i, t in enumerate(timesteps):
61 | for j in range(num_recurse):
62 | # Logic to keep a component fixed to reference image
63 | if fixed_im is not None:
64 | # Inject noise
65 | alpha_cumprod = model.scheduler.alphas_cumprod[t]
66 | im_noisy = torch.sqrt(alpha_cumprod) * fixed_im + torch.sqrt(1 - alpha_cumprod) * torch.randn_like(fixed_im)
67 |
68 | # Replace component in noisy images with component from fixed image
69 | im_noisy_component = views[0].inverse_view(im_noisy).to(noisy_images.device).to(noisy_images.dtype)
70 | noisy_images_component = views[1].inverse_view(noisy_images[0])
71 | noisy_images = im_noisy_component + noisy_images_component
72 |
73 | # Correct for factor of 2 from view TODO: Fix this....
74 | noisy_images = noisy_images[None] / 2.
75 |
76 | # "Reset" diffusion by replacing noisy images with noisy version
77 | # of grayscale image. All diffusion steps before this one are "thrown away"
78 | if i == start_diffusion_step:
79 | noisy_images = im_noisy.to(noisy_images.device).to(noisy_images.dtype)[None]
80 |
81 | # Apply views to noisy_image
82 | viewed_noisy_images = []
83 | for view_fn in views:
84 | viewed_noisy_images.append(view_fn.view(noisy_images[0]))
85 | viewed_noisy_images = torch.stack(viewed_noisy_images)
86 |
87 | # Duplicate inputs for CFG
88 | # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
89 | model_input = torch.cat([viewed_noisy_images] * 2)
90 | model_input = model.scheduler.scale_model_input(model_input, t)
91 |
92 | # Predict noise estimate
93 | noise_pred = model.unet(
94 | model_input,
95 | t,
96 | encoder_hidden_states=prompt_embeds,
97 | cross_attention_kwargs=None,
98 | return_dict=False,
99 | )[0]
100 |
101 | # Extract uncond (neg) and cond noise estimates
102 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
103 |
104 | # Invert the unconditional (negative) estimates
105 | inverted_preds = []
106 | for pred, view in zip(noise_pred_uncond, views):
107 | inverted_pred = view.inverse_view(pred)
108 | inverted_preds.append(inverted_pred)
109 | noise_pred_uncond = torch.stack(inverted_preds)
110 |
111 | # Invert the conditional estimates
112 | inverted_preds = []
113 | for pred, view in zip(noise_pred_text, views):
114 | inverted_pred = view.inverse_view(pred)
115 | inverted_preds.append(inverted_pred)
116 | noise_pred_text = torch.stack(inverted_preds)
117 |
118 | # Split into noise estimate and variance estimates
119 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
120 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
121 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
122 |
123 | # Reduce predicted noise and variances
124 | noise_pred = noise_pred.view(-1,num_prompts,3,height,width)
125 | predicted_variance = predicted_variance.view(-1,num_prompts,3,height,width)
126 | if reduction == 'mean':
127 | noise_pred = noise_pred.mean(1)
128 | predicted_variance = predicted_variance.mean(1)
129 | elif reduction == 'alternate':
130 | noise_pred = noise_pred[:,i%num_prompts]
131 | predicted_variance = predicted_variance[:,i%num_prompts]
132 | else:
133 | raise ValueError('Reduction must be either `mean` or `alternate`')
134 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
135 |
136 | # compute the previous noisy sample x_t -> x_t-1
137 | noisy_images = model.scheduler.step(
138 | noise_pred, t, noisy_images, generator=generator, return_dict=False
139 | )[0]
140 |
141 | if j != (num_recurse - 1):
142 | beta = model.scheduler.betas[t]
143 | noisy_images = (1 - beta).sqrt() * noisy_images + beta.sqrt() * torch.randn_like(noisy_images)
144 |
145 | # Return denoised images
146 | return noisy_images
147 |
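# A note on the `views` argument (illustrative; the repo's own view classes define the
# real transformations): each element only needs to expose a `view(x)` / `inverse_view(x)`
# pair mapping between the canonical image and that view's coordinate frame, e.g. a
# trivial identity view:
#
#   class IdentityView:
#       def view(self, im):
#           return im
#       def inverse_view(self, im):
#           return im
#
# sample_stage_1 applies `view` before the UNet call and `inverse_view` to each noise
# estimate, so the per-view estimates can be reduced (averaged or alternated) in the
# canonical frame.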
148 |
149 |
150 |
151 |
152 | @torch.no_grad()
153 | def sample_stage_2(
154 | model,
155 | image,
156 | prompt_embeds,
157 | negative_prompt_embeds,
158 | views,
159 | height=None,
160 | width=None,
161 | fixed_im=None,
162 | num_inference_steps=100,
163 | guidance_scale=7.0,
164 | reduction='mean',
165 | noise_level=50,
166 | generator=None
167 | ):
168 |
169 | # Params
170 | batch_size = 1 # TODO: Support larger batch sizes, maybe
171 | num_prompts = prompt_embeds.shape[0]
172 | height = model.unet.config.sample_size if height is None else height
173 | width = model.unet.config.sample_size if width is None else width
174 | device = image.device
175 | num_images_per_prompt = 1
176 |
177 | # Resize fixed image to correct size
178 | if fixed_im is not None:
179 | fixed_im = TF.resize(fixed_im, height, antialias=False)
180 |
181 | # For CFG
182 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
183 |
184 | # Get timesteps
185 | model.scheduler.set_timesteps(num_inference_steps, device=device)
186 | timesteps = model.scheduler.timesteps
187 |
188 | num_channels = model.unet.config.in_channels // 2
189 | noisy_images = model.prepare_intermediate_images(
190 | batch_size * num_images_per_prompt,
191 | num_channels,
192 | height,
193 | width,
194 | prompt_embeds.dtype,
195 | device,
196 | generator,
197 | )
198 |
199 | # Prepare upscaled image and noise level
200 | image = model.preprocess_image(image, num_images_per_prompt, device)
201 | upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
202 |
203 | noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
204 | noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
205 | upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
206 |
207 | # Condition on noise level, for each model input
208 | noise_level = torch.cat([noise_level] * num_prompts * 2)
209 |
210 | # Denoising Loop
211 | for i, t in enumerate(timesteps):
212 | # Logic to keep a component fixed to reference image
213 | if fixed_im is not None:
214 | # Inject noise
215 | alpha_cumprod = model.scheduler.alphas_cumprod[t]
216 | im_noisy = torch.sqrt(alpha_cumprod) * fixed_im + torch.sqrt(1 - alpha_cumprod) * torch.randn_like(fixed_im)
217 |
218 | # Replace component in noisy images with component from fixed image
219 | im_noisy_component = views[0].inverse_view(im_noisy).to(noisy_images.device).to(noisy_images.dtype)
220 | noisy_images_component = views[1].inverse_view(noisy_images[0])
221 | noisy_images = im_noisy_component + noisy_images_component
222 |
223 | # Correct for factor of 2 TODO: Fix this....
224 | noisy_images = noisy_images[None] / 2.
225 |
226 | # Cat noisy image with upscaled conditioning image
227 | model_input = torch.cat([noisy_images, upscaled], dim=1)
228 |
229 | # Apply views to noisy_image
230 | viewed_inputs = []
231 | for view_fn in views:
232 | viewed_inputs.append(view_fn.view(model_input[0]))
233 | viewed_inputs = torch.stack(viewed_inputs)
234 |
235 | # Duplicate inputs for CFG
236 | # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
237 | model_input = torch.cat([viewed_inputs] * 2)
238 | model_input = model.scheduler.scale_model_input(model_input, t)
239 |
240 | # predict the noise residual
241 | noise_pred = model.unet(
242 | model_input,
243 | t,
244 | encoder_hidden_states=prompt_embeds,
245 | class_labels=noise_level,
246 | cross_attention_kwargs=None,
247 | return_dict=False,
248 | )[0]
249 |
250 | # Extract uncond (neg) and cond noise estimates
251 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
252 |
253 | # Invert the unconditional (negative) estimates
254 | # TODO: pretty sure you can combine these into one loop
255 | inverted_preds = []
256 | for pred, view in zip(noise_pred_uncond, views):
257 | inverted_pred = view.inverse_view(pred)
258 | inverted_preds.append(inverted_pred)
259 | noise_pred_uncond = torch.stack(inverted_preds)
260 |
261 | # Invert the conditional estimates
262 | inverted_preds = []
263 | for pred, view in zip(noise_pred_text, views):
264 | inverted_pred = view.inverse_view(pred)
265 | inverted_preds.append(inverted_pred)
266 | noise_pred_text = torch.stack(inverted_preds)
267 |
268 | # Split predicted noise and predicted variances
269 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
270 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
271 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
272 |
273 | # Combine noise estimates (and variance estimates)
274 | noise_pred = noise_pred.view(-1,num_prompts,3,height,width)
275 | predicted_variance = predicted_variance.view(-1,num_prompts,3,height,width)
276 | if reduction == 'mean':
277 | noise_pred = noise_pred.mean(1)
278 | predicted_variance = predicted_variance.mean(1)
279 | elif reduction == 'alternate':
280 | noise_pred = noise_pred[:,i%num_prompts]
281 | predicted_variance = predicted_variance[:,i%num_prompts]
282 |
283 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
284 |
285 | # compute the previous noisy sample x_t -> x_t-1
286 | noisy_images = model.scheduler.step(
287 | noise_pred, t, noisy_images, generator=generator, return_dict=False
288 | )[0]
289 |
290 | # Return denoised images
291 | return noisy_images
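
# End-to-end usage sketch (illustrative; assumes the DeepFloyd IF pipelines from
# `diffusers` and a list `views` of two view objects as described above):
#
#   import torch
#   from diffusers import DiffusionPipeline
#
#   stage_1 = DiffusionPipeline.from_pretrained(
#       "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16).to("cuda")
#   stage_2 = DiffusionPipeline.from_pretrained(
#       "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16).to("cuda")
#
#   prompt_embeds, negative_embeds = stage_1.encode_prompt(["prompt for view 0", "prompt for view 1"])
#   im_64 = sample_stage_1(stage_1, prompt_embeds, negative_embeds, views)
#   im_256 = sample_stage_2(stage_2, im_64, prompt_embeds, negative_embeds, views,
#                           height=256, width=256, noise_level=50)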
--------------------------------------------------------------------------------
/src/models/components/auffusion_converter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | import math
5 | import os
6 | import random
7 | import torch
8 | import json
9 | import torch.utils.data
10 |
11 | import librosa
12 | from librosa.util import normalize
13 | from scipy.io.wavfile import read
14 | from librosa.filters import mel as librosa_mel_fn
15 |
16 | import torch.nn.functional as F
17 | import torch.nn as nn
18 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
19 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
20 | from huggingface_hub import snapshot_download
21 |
22 | MAX_WAV_VALUE = 32768.0
23 |
24 |
25 | def load_wav(full_path):
26 | sampling_rate, data = read(full_path)
27 | return data, sampling_rate
28 |
29 |
30 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
31 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
32 |
33 |
34 | def dynamic_range_decompression(x, C=1):
35 | return np.exp(x) / C
36 |
37 |
38 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
39 | return torch.log(torch.clamp(x, min=clip_val) * C)
40 |
41 |
42 | def dynamic_range_decompression_torch(x, C=1):
43 | return torch.exp(x) / C
44 |
45 |
46 | def spectral_normalize_torch(magnitudes):
47 | output = dynamic_range_compression_torch(magnitudes)
48 | return output
49 |
50 |
51 | def spectral_de_normalize_torch(magnitudes):
52 | output = dynamic_range_decompression_torch(magnitudes)
53 | return output
54 |
55 |
56 | mel_basis = {}
57 | hann_window = {}
58 |
59 |
60 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
61 | if torch.min(y) < -1.:
62 | print('min value is ', torch.min(y))
63 | if torch.max(y) > 1.:
64 | print('max value is ', torch.max(y))
65 |
66 | global mel_basis, hann_window
67 | if str(fmax)+'_'+str(y.device) not in mel_basis:  # key must match how the basis is stored below
68 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
69 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
70 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
71 |
72 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
73 | y = y.squeeze(1)
74 |
75 | # complex tensor as default, then use view_as_real for future pytorch compatibility
76 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
77 | spec = torch.view_as_real(spec)
78 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
79 |
80 | spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
81 | spec = spectral_normalize_torch(spec)
82 |
83 | return spec
84 |
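def _example_mel_shapes():
    """Illustrative sketch (not used elsewhere): shape arithmetic for the mel settings
    used by get_mel_spectrogram_from_audio below (2048-point FFT, 256 mel bins,
    160-sample hop at 16 kHz, i.e. ~100 frames per second)."""
    y = torch.rand(1, 16000 * 10) * 2 - 1   # 10 s of placeholder audio in [-1, 1)
    spec = mel_spectrogram(y, n_fft=2048, num_mels=256, sampling_rate=16000,
                           hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
    # 10 s * ~100 frames/s -> 1000 frames
    assert spec.shape == (1, 256, 1000)
    return spec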
85 |
86 | def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
87 | if torch.min(y) < -1.:
88 | print('min value is ', torch.min(y))
89 | if torch.max(y) > 1.:
90 | print('max value is ', torch.max(y))
91 |
92 | global hann_window
93 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
94 |
95 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
96 | y = y.squeeze(1)
97 |
98 | # complex tensor as default, then use view_as_real for future pytorch compatibility
99 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
100 | spec = torch.view_as_real(spec)
101 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
102 |
103 | return spec
104 |
105 |
106 | def normalize_spectrogram(
107 | spectrogram: torch.Tensor,
108 | max_value: float = 200,
109 | min_value: float = 1e-5,
110 | power: float = 1.,
111 | inverse: bool = False,
112 | flip: bool = True
113 | ) -> torch.Tensor:
114 |
115 | # Rescale to 0-1
116 | max_value = np.log(max_value) # 5.298317366548036
117 | min_value = np.log(min_value) # -11.512925464970229
118 |
119 | assert spectrogram.max() <= max_value and spectrogram.min() >= min_value
120 |
121 | data = (spectrogram - min_value) / (max_value - min_value)
122 |
123 | # Invert
124 | if inverse:
125 | data = 1 - data
126 |
127 | # Apply the power curve
128 | data = torch.pow(data, power)
129 |
130 | # 1D -> 3D
131 | data = data.repeat(3, 1, 1)
132 |
133 | # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
134 | if flip:
135 | data = torch.flip(data, [1])
136 |
137 | return data
138 |
139 |
140 |
141 | def denormalize_spectrogram(
142 | data: torch.Tensor,
143 | max_value: float = 200,
144 | min_value: float = 1e-5,
145 | power: float = 1,
146 | inverse: bool = False,
147 | ) -> torch.Tensor:
148 |
149 | max_value = np.log(max_value)
150 | min_value = np.log(min_value)
151 |
152 | # Flip Y axis: image origin at the top-left corner, spectrogram origin at the bottom-left corner
153 | data = torch.flip(data, [1])
154 |
155 | assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape))
156 |
157 | # if data.shape[0] == 1:
158 | # data = data.repeat(3, 1, 1)
159 |
160 | # assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0])
161 |
162 | # data = data[0]
163 | data = data.mean(dim=0)
164 |
165 | # Reverse the power curve
166 | data = torch.pow(data, 1 / power)
167 |
168 | # Invert
169 | if inverse:
170 | data = 1 - data
171 |
172 | # Rescale to max value
173 | spectrogram = data * (max_value - min_value) + min_value
174 |
175 | return spectrogram
176 |
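def _example_spectrogram_roundtrip():
    """Illustrative sketch (not used elsewhere): normalize_spectrogram maps a log-mel
    spectrogram into a 3-channel image with values in [0, 1] (flipped so low frequencies
    sit at the bottom of the image); denormalize_spectrogram inverts it up to the
    channel averaging."""
    log_min, log_max = np.log(1e-5), np.log(200)
    log_mel = torch.rand(256, 1024) * (log_max - log_min) + log_min   # placeholder log-mel
    img = normalize_spectrogram(log_mel)        # [3, 256, 1024], values in [0, 1]
    rec = denormalize_spectrogram(img)          # [256, 1024]
    assert torch.allclose(rec, log_mel, atol=1e-4)
    return img, rec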
177 |
178 | def get_mel_spectrogram_from_audio(audio, device="cuda"):
179 | audio = audio / MAX_WAV_VALUE
180 | audio = librosa.util.normalize(audio) * 0.95
181 |
182 | audio = torch.FloatTensor(audio)
183 | audio = audio.unsqueeze(0)
184 |
185 | waveform = audio.to(device)
186 | spec = mel_spectrogram(waveform, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160, win_size=1024, fmin=0, fmax=8000, center=False)
187 | return audio, spec
188 |
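def _example_audio_to_spec_image(wav_path):
    """Illustrative sketch (not used elsewhere): turn a waveform file into the 3-channel
    spectrogram-image representation. ``wav_path`` is a placeholder argument, and the
    clamp is a defensive step so the range assert in normalize_spectrogram holds."""
    audio, _ = librosa.load(wav_path, sr=16000)
    _, spec = get_mel_spectrogram_from_audio(audio, device="cpu")   # spec: [1, 256, T]
    log_mel = spec[0].clamp(np.log(1e-5), np.log(200))
    return normalize_spectrogram(log_mel)                           # [3, 256, T] in [0, 1]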
189 |
190 |
191 | LRELU_SLOPE = 0.1
192 | MAX_WAV_VALUE = 32768.0
193 |
194 |
195 | class AttrDict(dict):
196 | def __init__(self, *args, **kwargs):
197 | super(AttrDict, self).__init__(*args, **kwargs)
198 | self.__dict__ = self
199 |
200 |
201 | def get_config(config_path):
202 | config = json.loads(open(config_path).read())
203 | config = AttrDict(config)
204 | return config
205 |
206 | def init_weights(m, mean=0.0, std=0.01):
207 | classname = m.__class__.__name__
208 | if classname.find("Conv") != -1:
209 | m.weight.data.normal_(mean, std)
210 |
211 |
212 | def apply_weight_norm(m):
213 | classname = m.__class__.__name__
214 | if classname.find("Conv") != -1:
215 | weight_norm(m)
216 |
217 |
218 | def get_padding(kernel_size, dilation=1):
219 | return int((kernel_size*dilation - dilation)/2)
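
# e.g. get_padding(3, 1) == 1 and get_padding(3, 2) == 2: the "same" padding
# d*(k-1)/2 that keeps the length of a stride-1 (dilated) Conv1d unchanged.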
220 |
221 |
222 | class ResBlock1(torch.nn.Module):
223 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
224 | super(ResBlock1, self).__init__()
225 | self.h = h
226 | self.convs1 = nn.ModuleList([
227 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))),
228 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))),
229 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2])))
230 | ])
231 | self.convs1.apply(init_weights)
232 |
233 | self.convs2 = nn.ModuleList([
234 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
235 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)))
237 | ])
238 | self.convs2.apply(init_weights)
239 |
240 | def forward(self, x):
241 | for c1, c2 in zip(self.convs1, self.convs2):
242 | xt = F.leaky_relu(x, LRELU_SLOPE)
243 | xt = c1(xt)
244 | xt = F.leaky_relu(xt, LRELU_SLOPE)
245 | xt = c2(xt)
246 | x = xt + x
247 | return x
248 |
249 | def remove_weight_norm(self):
250 | for l in self.convs1:
251 | remove_weight_norm(l)
252 | for l in self.convs2:
253 | remove_weight_norm(l)
254 |
255 |
256 | class ResBlock2(torch.nn.Module):
257 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
258 | super(ResBlock2, self).__init__()
259 | self.h = h
260 | self.convs = nn.ModuleList([
261 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))),
262 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1])))
263 | ])
264 | self.convs.apply(init_weights)
265 |
266 | def forward(self, x):
267 | for c in self.convs:
268 | xt = F.leaky_relu(x, LRELU_SLOPE)
269 | xt = c(xt)
270 | x = xt + x
271 | return x
272 |
273 | def remove_weight_norm(self):
274 | for l in self.convs:
275 | remove_weight_norm(l)
276 |
277 |
278 |
279 | class Generator(torch.nn.Module):
280 | def __init__(self, h):
281 | super(Generator, self).__init__()
282 | self.h = h
283 | self.num_kernels = len(h.resblock_kernel_sizes)
284 | self.num_upsamples = len(h.upsample_rates)
285 | self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) # changed from HiFi-GAN: input channels follow h.num_mels rather than the original 80
286 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2
287 |
288 | self.ups = nn.ModuleList()
289 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
290 | if (k-u) % 2 == 0:
291 | self.ups.append(weight_norm(
292 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
293 | k, u, padding=(k-u)//2)))
294 | else:
295 | self.ups.append(weight_norm(
296 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
297 | k, u, padding=(k-u)//2+1, output_padding=1)))
298 |
299 | # self.ups.append(weight_norm(
300 | # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
301 | # k, u, padding=(k-u)//2)))
302 |
303 |
304 | self.resblocks = nn.ModuleList()
305 | for i in range(len(self.ups)):
306 | ch = h.upsample_initial_channel//(2**(i+1))
307 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
308 | self.resblocks.append(resblock(h, ch, k, d))
309 |
310 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
311 | self.ups.apply(init_weights)
312 | self.conv_post.apply(init_weights)
313 |
314 | def forward(self, x):
315 | x = self.conv_pre(x)
316 | for i in range(self.num_upsamples):
317 | x = F.leaky_relu(x, LRELU_SLOPE)
318 | x = self.ups[i](x)
319 | xs = None
320 | for j in range(self.num_kernels):
321 | if xs is None:
322 | xs = self.resblocks[i*self.num_kernels+j](x)
323 | else:
324 | xs += self.resblocks[i*self.num_kernels+j](x)
325 | x = xs / self.num_kernels
326 | x = F.leaky_relu(x)
327 | x = self.conv_post(x)
328 | x = torch.tanh(x)
329 |
330 | return x
331 |
332 | def remove_weight_norm(self):
333 | for l in self.ups:
334 | remove_weight_norm(l)
335 | for l in self.resblocks:
336 | l.remove_weight_norm()
337 | remove_weight_norm(self.conv_pre)
338 | remove_weight_norm(self.conv_post)
339 |
340 | @classmethod
341 | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
342 | if not os.path.isdir(pretrained_model_name_or_path):
343 | pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
344 |
345 | if subfolder is not None:
346 | pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)
347 | config_path = os.path.join(pretrained_model_name_or_path, "config.json")
348 | ckpt_path = os.path.join(pretrained_model_name_or_path, "vocoder.pt")
349 |
350 | config = get_config(config_path)
351 | vocoder = cls(config)
352 |
353 | state_dict_g = torch.load(ckpt_path)
354 | vocoder.load_state_dict(state_dict_g["generator"])
355 | vocoder.eval()
356 | vocoder.remove_weight_norm()
357 | return vocoder
358 |
359 |
360 | @torch.no_grad()
361 | def inference(self, mels, lengths=None):
362 | self.eval()
363 | with torch.no_grad():
364 | wavs = self(mels).squeeze(1)
365 |
366 | # wavs = (wavs.cpu().numpy() * MAX_WAV_VALUE).astype("int16")
367 | wavs = (wavs.cpu().numpy()).astype("float32")  # return float32 waveforms in [-1, 1] instead of int16 PCM (cf. the commented-out line above)
368 |
369 | if lengths is not None:
370 | wavs = wavs[:, :lengths]
371 |
372 | return wavs
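
# Vocoder usage sketch (illustrative; assumes the Auffusion checkpoint stores these
# HiFi-GAN-style weights under a "vocoder" subfolder, and that `spec_img` is a generated
# [3, 256, T] spectrogram image):
#
#   import soundfile as sf
#
#   vocoder = Generator.from_pretrained("auffusion/auffusion-full-no-adapter",
#                                       subfolder="vocoder").to("cuda")
#   log_mel = denormalize_spectrogram(spec_img)                    # [256, T]
#   wav = vocoder.inference(log_mel.unsqueeze(0).to("cuda"))       # float32, roughly T * 160 samples
#   sf.write("out.wav", wav[0], samplerate=16000)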
--------------------------------------------------------------------------------