├── .github
│   └── workflows
│       ├── black.yml
│       ├── ci.yml
│       ├── mypy.yml
│       ├── pip.yml
│       └── ruff.yml
├── .gitignore
├── CITATION
├── LICENSE
├── README.md
├── VERSION
├── cog.yaml
├── integrations
│   ├── README.md
│   ├── __init__.py
│   ├── baseten.py
│   └── cog_riffusion.py
├── pyproject.toml
├── requirements.txt
├── requirements_all.txt
├── requirements_dev.txt
├── riffusion
│   ├── __init__.py
│   ├── audio_splitter.py
│   ├── cli.py
│   ├── datatypes.py
│   ├── external
│   │   ├── README.md
│   │   ├── __init__.py
│   │   └── prompt_weighting.py
│   ├── py.typed
│   ├── riffusion_pipeline.py
│   ├── server.py
│   ├── spectrogram_converter.py
│   ├── spectrogram_image_converter.py
│   ├── spectrogram_params.py
│   ├── streamlit
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── playground.py
│   │   ├── tasks
│   │   │   ├── __init__.py
│   │   │   ├── audio_to_audio.py
│   │   │   ├── home.py
│   │   │   ├── image_to_audio.py
│   │   │   ├── interpolation.py
│   │   │   ├── sample_clips.py
│   │   │   ├── split_audio.py
│   │   │   ├── text_to_audio.py
│   │   │   └── text_to_audio_batch.py
│   │   └── util.py
│   └── util
│       ├── __init__.py
│       ├── audio_util.py
│       ├── base64_util.py
│       ├── fft_util.py
│       ├── image_util.py
│       └── torch_util.py
├── seed_images
│   ├── agile.png
│   ├── marim.png
│   ├── mask_beat_lines_80.png
│   ├── mask_gradient_dark.png
│   ├── mask_gradient_top_70.png
│   ├── mask_gradient_top_fifth_75.png
│   ├── mask_top_third_75.png
│   ├── mask_top_third_95.png
│   ├── motorway.png
│   ├── og_beat.png
│   └── vibes.png
├── setup.py
└── test
    ├── __init__.py
    ├── audio_to_image_test.py
    ├── image_to_audio_test.py
    ├── image_util_test.py
    ├── print_exif_test.py
    ├── sample_clips_test.py
    ├── spectrogram_converter_test.py
    ├── spectrogram_image_converter_test.py
    ├── test_case.py
    └── test_data
        ├── README.md
        └── tired_traveler
            ├── clips
            │   ├── clip_0_start_15795_ms_duration_5678_ms.wav
            │   ├── clip_1_start_860_ms_duration_5678_ms.wav
            │   └── clip_2_start_103694_ms_duration_5678_ms.wav
            ├── images
            │   ├── clip_2_start_103694_ms_duration_5678_ms.png
            │   └── clip_2_start_103694_ms_duration_5678_ms_stereo.png
            └── tired_traveler.mp3
/.github/workflows/black.yml:
--------------------------------------------------------------------------------
1 | name: black format
2 |
3 | on:
4 | push:
5 | branches:
6 | - 'main'
7 | pull_request:
8 | types: [opened, synchronize, reopened, ready_for_review]
9 |
10 | jobs:
11 | run:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v2
15 | - uses: psf/black@stable
16 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: python test
2 |
3 | on:
4 | push:
5 | branches:
6 | - 'main'
7 | pull_request:
8 | types: [opened, synchronize, reopened, ready_for_review]
9 |
10 | jobs:
11 | run:
12 | runs-on: ubuntu-latest
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.9", "3.10"]
17 |
18 | steps:
19 | - name: Checkout
20 | uses: actions/checkout@v3
21 |
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v4
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 |
27 | - name: Install system packages
28 | run: |
29 | sudo apt-get update
30 | sudo apt-get install -y ffmpeg libsndfile1
31 |
32 | - name: Upgrade pip
33 | run: |
34 | python -m pip install --upgrade pip
35 |
36 | - name: Install pip packages
37 | run: |
38 | pip install -r requirements_all.txt
39 |
40 | - name: Test with unittest
41 | run: |
42 | RIFFUSION_TEST_DEVICE=cpu python -m unittest test/*_test.py
43 |
--------------------------------------------------------------------------------
/.github/workflows/mypy.yml:
--------------------------------------------------------------------------------
1 | name: mypy type check
2 |
3 | on:
4 | push:
5 | branches:
6 | - 'main'
7 | pull_request:
8 | types: [opened, synchronize, reopened, ready_for_review]
9 |
10 | jobs:
11 | run:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v3
15 | with:
16 | submodules: recursive
17 |
18 | - uses: jpetrucciani/mypy-check@master
19 | with:
20 | mypy_flags: '--config-file pyproject.toml'
21 | requirements_file: 'requirements_all.txt'
22 |
--------------------------------------------------------------------------------
/.github/workflows/pip.yml:
--------------------------------------------------------------------------------
1 | name: pip install
2 |
3 | on:
4 | push:
5 | branches:
6 | - 'main'
7 | pull_request:
8 | types: [opened, synchronize, reopened, ready_for_review]
9 |
10 | jobs:
11 | run:
12 | runs-on: ubuntu-latest
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.9", "3.10"]
17 |
18 | steps:
19 | - name: Checkout
20 | uses: actions/checkout@v3
21 |
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v4
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 |
27 | - name: Editable Pip Install
28 | run: python -m pip install -e .
29 |
--------------------------------------------------------------------------------
/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
1 | name: ruff lint
2 |
3 | on:
4 | push:
5 | branches:
6 | - 'main'
7 | pull_request:
8 | types: [opened, synchronize, reopened, ready_for_review]
9 |
10 | jobs:
11 | run:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v2
15 | - uses: jpetrucciani/ruff-check@main
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # VSCode
10 | .vscode
11 |
12 | # Cog
13 | .cog/
14 |
15 | # Random stuff I don't care about
16 | .graveyard/
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # OSX cruft
40 | .DS_Store
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 |
96 | # pyenv
97 | .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
--------------------------------------------------------------------------------
/CITATION:
--------------------------------------------------------------------------------
1 | @article{Forsgren_Martiros_2022,
2 | author = {Forsgren, Seth* and Martiros, Hayk*},
3 | title = {{Riffusion - Stable diffusion for real-time music generation}},
4 | url = {https://riffusion.com/about},
5 | year = {2022}
6 | }
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2022 Hayk Martiros and Seth Forsgren
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
4 | associated documentation files (the "Software"), to deal in the Software without restriction,
5 | including without limitation the rights to use, copy, modify, merge, publish, distribute,
6 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
7 | furnished to do so, subject to the following conditions:
8 |
9 | The above copyright notice and this permission notice shall be included in all copies or substantial
10 | portions of the Software.
11 |
12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
13 | NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
14 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
15 | OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :guitar: Riffusion (hobby)
2 |
3 | :no_entry: This project is no longer actively maintained.
4 |
5 |
6 |
7 |
8 |
9 | Riffusion is a library for real-time music and audio generation with stable diffusion.
10 |
11 | Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.
12 |
13 | This is the core repository for riffusion image and audio processing code.
14 |
15 | * Diffusion pipeline that performs prompt interpolation combined with image conditioning
16 | * Conversions between spectrogram images and audio clips
17 | * Command-line interface for common tasks
18 | * Interactive app using streamlit
19 | * Flask server to provide model inference via API
20 | * Various third party integrations
21 |
22 | Related repositories:
23 | * Web app: https://github.com/riffusion/riffusion-app
24 | * Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1
25 |
26 | ## Citation
27 |
28 | If you build on this work, please cite it as follows:
29 |
30 | ```
31 | @article{Forsgren_Martiros_2022,
32 | author = {Forsgren, Seth* and Martiros, Hayk*},
33 | title = {{Riffusion - Stable diffusion for real-time music generation}},
34 | url = {https://riffusion.com/about},
35 | year = {2022}
36 | }
37 | ```
38 |
39 | ## Install
40 |
41 | Tested in CI with Python 3.9 and 3.10.
42 |
43 | It's highly recommended to set up a virtual Python environment with `conda` or `virtualenv`:
44 | ```
45 | conda create --name riffusion python=3.9
46 | conda activate riffusion
47 | ```
48 |
49 | Install Python dependencies:
50 | ```
51 | python -m pip install -r requirements.txt
52 | ```
53 |
54 | In order to use audio formats other than WAV, [ffmpeg](https://ffmpeg.org/download.html) is required.
55 | ```
56 | sudo apt-get install ffmpeg # linux
57 | brew install ffmpeg # mac
58 | conda install -c conda-forge ffmpeg # conda
59 | ```
60 |
61 | If torchaudio has no backend, you may need to install `libsndfile`. See [this issue](https://github.com/riffusion/riffusion/issues/12).
62 |
63 | If you have an issue, try upgrading [diffusers](https://github.com/huggingface/diffusers). Tested with 0.9 - 0.11.
64 |
65 | Guides:
66 | * [Simple Install Guide for Windows](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
67 |
68 | ## Backends
69 |
70 | ### CPU
71 | `cpu` is supported but is quite slow.
72 |
73 | ### CUDA
74 | `cuda` is the recommended and most performant backend.
75 |
76 | To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
77 | [install guide](https://pytorch.org/get-started/locally/) or
78 | [stable wheels](https://download.pytorch.org/whl/torch_stable.html).
79 |
80 | To generate audio in real-time, you need a GPU that can run stable diffusion with approximately 50
81 | steps in under five seconds, such as a 3090 or A10G.
82 |
83 | Test availability with:
84 |
85 | ```python3
86 | import torch
87 | torch.cuda.is_available()
88 | ```
89 |
90 | ### MPS
91 | The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU,
92 | particularly for audio processing. You may need to set
93 | `PYTORCH_ENABLE_MPS_FALLBACK=1`.
94 |
95 | In addition, this backend is not deterministic.
96 |
97 | Test availability with:
98 |
99 | ```python3
100 | import torch
101 | torch.backends.mps.is_available()
102 | ```
103 |
104 | ## Command-line interface
105 |
106 | Riffusion comes with a command line interface for performing common tasks.
107 |
108 | See available commands:
109 | ```
110 | python -m riffusion.cli -h
111 | ```
112 |
113 | Get help for a specific command:
114 | ```
115 | python -m riffusion.cli image-to-audio -h
116 | ```
117 |
118 | Execute:
119 | ```
120 | python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
121 | ```
122 |
123 | ## Riffusion Playground
124 |
125 | Riffusion contains a [streamlit](https://streamlit.io/) app for interactive use and exploration.
126 |
127 | Run with:
128 | ```
129 | python -m riffusion.streamlit.playground
130 | ```
131 |
132 | And access at http://127.0.0.1:8501/
133 |
134 |
135 |
136 | ## Run the model server
137 |
138 | Riffusion can be run as a flask server that provides inference via API. This server enables the [web app](https://github.com/riffusion/riffusion-app) to run locally.
139 |
140 | Run with:
141 |
142 | ```
143 | python -m riffusion.server --host 127.0.0.1 --port 3013
144 | ```
145 |
146 | You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
147 |
148 | Use the `--device` argument to specify the torch device to use.
149 |
150 | The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
151 |
152 | Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
153 | ```
154 | {
155 | "alpha": 0.75,
156 | "num_inference_steps": 50,
157 | "seed_image_id": "og_beat",
158 |
159 | "start": {
160 | "prompt": "church bells on sunday",
161 | "seed": 42,
162 | "denoising": 0.75,
163 | "guidance": 7.0
164 | },
165 |
166 | "end": {
167 | "prompt": "jazz with piano",
168 | "seed": 123,
169 | "denoising": 0.75,
170 | "guidance": 7.0
171 | }
172 | }
173 | ```
174 |
175 | Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L54) for the API):
176 | ```
177 | {
178 | "image": "< base64 encoded JPEG image >",
179 | "audio": "< base64 encoded MP3 clip >"
180 | }
181 | ```
182 |
183 | ## Tests
184 | Tests live in the `test/` directory and are implemented with `unittest`.
185 |
186 | To run all tests:
187 | ```
188 | python -m unittest test/*_test.py
189 | ```
190 |
191 | To run a single test:
192 | ```
193 | python -m unittest test.audio_to_image_test
194 | ```
195 |
196 | To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:
197 | ```
198 | RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
199 | ```
200 |
201 | To run a single test case within a test:
202 | ```
203 | python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
204 | ```
205 |
206 | To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`. Tests should pass with
207 | `cpu`, `cuda`, and `mps` backends.
208 |
209 | ## Development Guide
210 | Install additional packages for dev with `python -m pip install -r requirements_dev.txt`.
211 |
212 | * Linter: `ruff`
213 | * Formatter: `black`
214 | * Type checker: `mypy`
215 |
216 | These are configured in `pyproject.toml`.
217 |
218 | The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
219 |
220 | CI is run through GitHub Actions from `.github/workflows/ci.yml`.
221 |
222 | Contributions are welcome through pull requests.
223 |
--------------------------------------------------------------------------------
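
As a usage note for the server API documented in the README above, here is a minimal client sketch (not part of the repository). It assumes the server is running locally as shown (`python -m riffusion.server --host 127.0.0.1 --port 3013`), reuses the README's example payload, and uses only the Python standard library; whether the base64 fields carry a `data:` URI prefix is an assumption that the code handles defensively.

```python
import base64
import json
import urllib.request

# Example payload copied from the README; see InferenceInput in riffusion/datatypes.py.
payload = {
    "alpha": 0.75,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

request = urllib.request.Request(
    "http://127.0.0.1:3013/run_inference",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(request) as response:
    output = json.loads(response.read())


def save_base64(field: str, path: str) -> None:
    # Strip a possible "data:...;base64," prefix before decoding.
    with open(path, "wb") as f:
        f.write(base64.b64decode(field.split(",", 1)[-1]))


save_base64(output["image"], "output.jpg")  # base64-encoded JPEG spectrogram
save_base64(output["audio"], "output.mp3")  # base64-encoded MP3 clip
```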
/VERSION:
--------------------------------------------------------------------------------
1 | 0.3.1
2 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | # set to true if your model requires a GPU
6 | gpu: true
7 |
8 | # a list of ubuntu apt packages to install
9 | system_packages:
10 | - "ffmpeg"
11 | - "libsndfile1"
12 |
13 | # python version in the form '3.8' or '3.8.12'
14 | python_version: "3.9"
15 |
16 | # a list of packages in the format <package-name>==<version>
17 | python_packages:
18 | - "accelerate==0.15.0"
19 | - "argh==0.26.2"
20 | - "dacite==1.6.0"
21 | - "diffusers==0.10.2"
22 | - "flask_cors==3.0.10"
23 | - "flask==1.1.2"
24 | - "numpy==1.19.4"
25 | - "pillow==9.1.0"
26 | - "pydub==0.25.1"
27 | - "scipy==1.6.3"
28 | - "torch==1.13.0"
29 | - "torchaudio==0.13.0"
30 | - "transformers==4.25.1"
31 |
32 | # commands run after the environment is setup
33 | # run:
34 | # - "echo env is ready!"
35 | # - "echo another command if needed"
36 |
37 | # predict.py defines how predictions are run on your model
38 | predict: "integrations/cog_riffusion.py:RiffusionPredictor"
39 |
--------------------------------------------------------------------------------
/integrations/README.md:
--------------------------------------------------------------------------------
1 | # Integrations
2 |
3 | This package contains integrations of Riffusion into third party apps and deployments.
4 |
5 | ## Baseten
6 |
7 | [Baseten](https://baseten.com) is a platform for building and deploying machine learning models.
8 |
9 | ## Replicate
10 |
11 | To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
12 | download the model weights:
13 |
14 | cog run python -m integrations.cog_riffusion --download_weights
15 |
16 | Then you can run predictions:
17 |
18 | cog predict -i prompt_a="funky synth solo"
19 |
20 | You can also view the model on Replicate [here](https://replicate.com/riffusion/riffusion). Owners
21 | can push an updated version of the model like so:
22 |
23 | # download weights locally if you haven't already
24 | cog run python -m integrations.cog_riffusion --download_weights
25 |
26 | cog login
27 | cog push r8.im/riffusion/riffusion
28 |
--------------------------------------------------------------------------------
/integrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/integrations/__init__.py
--------------------------------------------------------------------------------
/integrations/baseten.py:
--------------------------------------------------------------------------------
1 | """
2 | This file can be used to build a Truss for deployment with Baseten.
3 | If used, it should be renamed to model.py and placed alongside the other
4 | files from /riffusion in the standard /model directory of the Truss.
5 |
6 | For more on the Truss file format, see https://truss.baseten.co/
7 | """
8 |
9 | import typing as T
10 |
11 | import dacite
12 | import torch
13 | from huggingface_hub import snapshot_download
14 |
15 | from riffusion.datatypes import InferenceInput
16 | from riffusion.riffusion_pipeline import RiffusionPipeline
17 | from riffusion.server import compute_request
18 |
19 |
20 | class Model:
21 | """
22 | Baseten Truss model class for riffusion.
23 |
24 | See: https://truss.baseten.co/reference/structure#model.py
25 | """
26 |
27 | def __init__(self, **kwargs) -> None:
28 | self._data_dir = kwargs["data_dir"]
29 | self._config = kwargs["config"]
30 | self._pipeline = None
31 | self._vae = None
32 |
33 | self.checkpoint_name = "riffusion/riffusion-model-v1"
34 |
35 | # Download entire seed image folder from huggingface hub
36 | self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
37 |
38 | def load(self):
39 | """
40 | Load the model. Guaranteed to be called before `predict`.
41 | """
42 | self._pipeline = RiffusionPipeline.load_checkpoint(
43 | checkpoint=self.checkpoint_name,
44 | use_traced_unet=True,
45 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
46 | )
47 |
48 | def preprocess(self, request: T.Dict) -> T.Dict:
49 | """
50 | Incorporate pre-processing required by the model if desired here.
51 |
52 | These might be feature transformations that are tightly coupled to the model.
53 | """
54 | return request
55 |
56 | def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
57 | """
58 | This is the main function that is called.
59 | """
60 | assert self._pipeline is not None, "Model pipeline not loaded"
61 |
62 | try:
63 | inputs = dacite.from_dict(InferenceInput, request)
64 | except dacite.exceptions.WrongTypeError as exception:
65 | return str(exception), 400
66 | except dacite.exceptions.MissingValueError as exception:
67 | return str(exception), 400
68 |
69 | # NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
70 |         with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
71 | response = compute_request(
72 | inputs=inputs,
73 | pipeline=self._pipeline,
74 | seed_images_dir=self._seed_images_dir,
75 | )
76 |
77 | return response
78 |
79 | def postprocess(self, request: T.Dict) -> T.Dict:
80 | """
81 | Incorporate post-processing required by the model if desired here.
82 | """
83 | return request
84 |
--------------------------------------------------------------------------------
/integrations/cog_riffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | Prediction interface for Cog ⚙️
3 | https://github.com/replicate/cog/blob/main/docs/python.md
4 | """
5 |
6 | import argparse
7 | import os
8 | import shutil
9 | import typing as T
10 |
11 | import numpy as np
12 | import PIL
13 | import torch
14 | from cog import BaseModel, BasePredictor, Input, Path
15 |
16 | from riffusion.datatypes import InferenceInput, PromptInput
17 | from riffusion.riffusion_pipeline import RiffusionPipeline
18 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter
19 | from riffusion.spectrogram_params import SpectrogramParams
20 |
21 | MODEL_ID = "riffusion/riffusion-model-v1"
22 | MODEL_CACHE = "riffusion-cache"
23 |
24 | # Where built-in seed images are stored
25 | SEED_IMAGES_DIR = Path("./seed_images")
26 | SEED_IMAGES = [val.split(".")[0] for val in os.listdir(SEED_IMAGES_DIR) if "png" in val]
27 | SEED_IMAGES.sort()
28 |
29 |
30 | class Output(BaseModel):
31 | """
32 | Output class for riffusion predictions
33 | """
34 |
35 | audio: Path
36 | spectrogram: Path
37 | error: T.Optional[str] = None
38 |
39 |
40 | class RiffusionPredictor(BasePredictor):
41 | """
42 | Implementation of cog predictor object s.t. we can run riffusion predictions w/cog.
43 |
44 | See README & https://github.com/replicate/cog for details
45 | """
46 |
47 | def setup(self, local_files_only=True):
48 | """
49 | Loads the model onto GPU from local cache.
50 | """
51 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
52 |
53 | self.model = RiffusionPipeline.load_checkpoint(
54 | checkpoint=MODEL_ID,
55 | use_traced_unet=True,
56 | device=self.device,
57 | local_files_only=local_files_only,
58 | cache_dir=MODEL_CACHE,
59 | )
60 |
61 | def predict(
62 | self,
63 | prompt_a: str = Input(description="The prompt for your audio", default="funky synth solo"),
64 | denoising: float = Input(
65 | description="How much to transform input spectrogram",
66 | default=0.75,
67 | ge=0,
68 | le=1,
69 | ),
70 | prompt_b: str = Input(
71 |             description="The second prompt to interpolate with the first, "
72 |             "leave blank if no interpolation",
73 | default=None,
74 | ),
75 | alpha: float = Input(
76 |             description="Interpolation alpha if using two prompts. "
77 |             "A value of 0 uses prompt_a fully, a value of 1 uses prompt_b fully",
78 | default=0.5,
79 | ge=0,
80 | le=1,
81 | ),
82 | num_inference_steps: int = Input(
83 | description="Number of steps to run the diffusion model", default=50, ge=1
84 | ),
85 | seed_image_id: str = Input(
86 | description="Seed spectrogram to use", default="vibes", choices=SEED_IMAGES
87 | ),
88 | ) -> Output:
89 | """
90 | Runs riffusion inference.
91 | """
92 | # Load the seed image by ID
93 | init_image_path = Path(SEED_IMAGES_DIR, f"{seed_image_id}.png")
94 | if not init_image_path.is_file():
95 | return Output(error=f"Invalid seed image: {seed_image_id}")
96 | init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
97 |
98 |         # Random seeds drawn up to the int32 max
99 | seed_a = np.random.randint(0, 2147483647)
100 | seed_b = np.random.randint(0, 2147483647)
101 |
102 | start = PromptInput(prompt=prompt_a, seed=seed_a, denoising=denoising)
103 | if not prompt_b: # no transition
104 | prompt_b = prompt_a
105 | alpha = 0
106 | end = PromptInput(prompt=prompt_b, seed=seed_b, denoising=denoising)
107 | riffusion_input = InferenceInput(
108 | start=start,
109 | end=end,
110 | alpha=alpha,
111 | num_inference_steps=num_inference_steps,
112 | seed_image_id=seed_image_id,
113 | )
114 |
115 | # Execute the model to get the spectrogram image
116 | image = self.model.riffuse(riffusion_input, init_image=init_image, mask_image=None)
117 |
118 | # Reconstruct audio from the image
119 | params = SpectrogramParams()
120 | converter = SpectrogramImageConverter(params=params, device=self.device)
121 | segment = converter.audio_from_spectrogram_image(image)
122 |
123 | if not os.path.exists("out/"):
124 | os.mkdir("out")
125 |
126 | out_img_path = "out/spectrogram.jpg"
127 | image.save("out/spectrogram.jpg", exif=image.getexif())
128 |
129 | out_wav_path = "out/gen_sound.wav"
130 | segment.export(out_wav_path, format="wav")
131 |
132 | return Output(audio=Path(out_wav_path), spectrogram=Path(out_img_path))
133 |
134 |
135 | # TODO(hayk): Can we get rid of the below functions and incorporate into
136 | # RiffusionPipeline.load_checkpoint?
137 |
138 |
139 | def download_weights():
140 | """
141 | Clears local cache & downloads riffusion weights
142 | """
143 | if os.path.exists(MODEL_CACHE):
144 | shutil.rmtree(MODEL_CACHE)
145 | os.makedirs(MODEL_CACHE)
146 |
147 | pred = RiffusionPredictor()
148 | pred.setup(local_files_only=False)
149 |
150 |
151 | if __name__ == "__main__":
152 | parser = argparse.ArgumentParser()
153 | parser.add_argument(
154 | "--download_weights", action="store_true", help="Download and cache weights"
155 | )
156 | args = parser.parse_args()
157 | if args.download_weights:
158 | download_weights()
159 |
--------------------------------------------------------------------------------
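
The supported entry point is the `cog` CLI shown in integrations/README.md; purely as an illustration (not part of the repository), the predictor can also be driven directly from Python. All arguments are passed explicitly because the `Input(...)` defaults only apply when running through Cog, and weights are downloaded on the first `setup()` call.

```python
from integrations.cog_riffusion import RiffusionPredictor

# Run from the repo root so ./seed_images is found.
predictor = RiffusionPredictor()
predictor.setup(local_files_only=False)  # downloads and caches weights if needed

output = predictor.predict(
    prompt_a="funky synth solo",
    denoising=0.75,
    prompt_b=None,
    alpha=0.5,
    num_inference_steps=50,
    seed_image_id="vibes",
)
print(output.audio, output.spectrogram)
```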
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 100
3 |
4 | [tool.ruff]
5 | line-length = 100
6 |
7 | # Which rules to run
8 | select = [
9 | # Pyflakes
10 | "F",
11 | # Pycodestyle
12 | "E",
13 | "W",
14 | # isort
15 | "I001"
16 | ]
17 |
18 | ignore = []
19 |
20 | # Exclude a variety of commonly ignored directories.
21 | exclude = [
22 | ".bzr",
23 | ".direnv",
24 | ".eggs",
25 | ".git",
26 | ".hg",
27 | ".mypy_cache",
28 | ".nox",
29 | ".pants.d",
30 | ".ruff_cache",
31 | ".svn",
32 | ".tox",
33 | ".venv",
34 | "__pypackages__",
35 | "_build",
36 | "buck-out",
37 | "build",
38 | "dist",
39 | "node_modules",
40 | "venv",
41 | ]
42 | per-file-ignores = {}
43 |
44 | # Allow unused variables when underscore-prefixed.
45 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
46 |
47 | # Assume Python 3.10.
48 | target-version = "py310"
49 |
50 | [tool.ruff.mccabe]
51 | # Unlike Flake8, default to a complexity level of 10.
52 | max-complexity = 10
53 |
54 | [tool.mypy]
55 | python_version = "3.10"
56 |
57 | [[tool.mypy.overrides]]
58 | module = "argh.*"
59 | ignore_missing_imports = true
60 |
61 | [[tool.mypy.overrides]]
62 | module = "cog.*"
63 | ignore_missing_imports = true
64 |
65 | [[tool.mypy.overrides]]
66 | module = "diffusers.*"
67 | ignore_missing_imports = true
68 |
69 | [[tool.mypy.overrides]]
70 | module = "huggingface_hub.*"
71 | ignore_missing_imports = true
72 |
73 | [[tool.mypy.overrides]]
74 | module = "numpy.*"
75 | ignore_missing_imports = true
76 |
77 | [[tool.mypy.overrides]]
78 | module = "plotly.*"
79 | ignore_missing_imports = true
80 |
81 | [[tool.mypy.overrides]]
82 | module = "pydub.*"
83 | ignore_missing_imports = true
84 |
85 | [[tool.mypy.overrides]]
86 | module = "scipy.fft.*"
87 | ignore_missing_imports = true
88 |
89 | [[tool.mypy.overrides]]
90 | module = "scipy.io.*"
91 | ignore_missing_imports = true
92 |
93 | [[tool.mypy.overrides]]
94 | module = "torchaudio.*"
95 | ignore_missing_imports = true
96 |
97 | [[tool.mypy.overrides]]
98 | module = "transformers.*"
99 | ignore_missing_imports = true
100 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | argh
3 | dacite
4 | demucs
5 | diffusers==0.9.0
6 | flask
7 | flask_cors
8 | numpy
9 | pillow>=9.1.0
10 | plotly
11 | pydub
12 | pysoundfile
13 | scipy
14 | soundfile
15 | sox
16 | streamlit>=1.18.0
17 | torch
18 | torchaudio
19 | torchvision
20 | transformers
21 |
--------------------------------------------------------------------------------
/requirements_all.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | -r requirements_dev.txt
3 |
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | black
2 | ipdb
3 | mypy
4 | ruff
5 | types-Flask-Cors
6 | types-Pillow
7 | types-requests
8 | types-setuptools
9 | types-tqdm
10 |
--------------------------------------------------------------------------------
/riffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/__init__.py
--------------------------------------------------------------------------------
/riffusion/audio_splitter.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import subprocess
3 | import tempfile
4 | import typing as T
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import pydub
9 | import torch
10 | import torchaudio
11 | from torchaudio.transforms import Fade
12 |
13 | from riffusion.util import audio_util
14 |
15 |
16 | def split_audio(
17 | segment: pydub.AudioSegment,
18 | model_name: str = "htdemucs_6s",
19 | extension: str = "wav",
20 | jobs: int = 4,
21 | device: str = "cuda",
22 | ) -> T.Dict[str, pydub.AudioSegment]:
23 | """
24 | Split audio into stems using demucs.
25 | """
26 | tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_"))
27 |
28 | # Save the audio to a temporary file
29 | audio_path = tmp_dir / "audio.mp3"
30 | segment.export(audio_path, format="mp3")
31 |
32 | # Assemble command
33 | command = [
34 | "demucs",
35 | str(audio_path),
36 | "--name",
37 | model_name,
38 | "--out",
39 | str(tmp_dir),
40 | "--jobs",
41 | str(jobs),
42 | "--device",
43 | device if device != "mps" else "cpu",
44 | ]
45 | print(" ".join(command))
46 |
47 | if extension == "mp3":
48 | command.append("--mp3")
49 |
50 | # Run demucs
51 | subprocess.run(
52 | command,
53 | check=True,
54 | )
55 |
56 | # Load the stems
57 | stems = {}
58 | for stem_path in tmp_dir.glob(f"{model_name}/audio/*.{extension}"):
59 | stem = pydub.AudioSegment.from_file(stem_path)
60 | stems[stem_path.stem] = stem
61 |
62 | # Delete tmp dir
63 | shutil.rmtree(tmp_dir)
64 |
65 | return stems
66 |
67 |
68 | class AudioSplitter:
69 | """
70 | Split audio into instrument stems like {drums, bass, vocals, etc.}
71 |
72 | NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
73 | model in the demucs repo. See the function above. Probably just delete this.
74 |
75 | See:
76 | https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
77 | """
78 |
79 | def __init__(
80 | self,
81 | segment_length_s: float = 10.0,
82 | overlap_s: float = 0.1,
83 | device: str = "cuda",
84 | ):
85 | self.segment_length_s = segment_length_s
86 | self.overlap_s = overlap_s
87 | self.device = device
88 |
89 | self.model = self.load_model().to(device)
90 |
91 | @staticmethod
92 | def load_model(model_path: str = "models/hdemucs_high_trained.pt") -> torchaudio.models.HDemucs:
93 | """
94 | Load the trained HDEMUCS pytorch model.
95 | """
96 | # NOTE(hayk): The sources are baked into the pretrained model and can't be changed
97 | model = torchaudio.models.hdemucs_high(sources=["drums", "bass", "other", "vocals"])
98 |
99 | path = torchaudio.utils.download_asset(model_path)
100 | state_dict = torch.load(path)
101 | model.load_state_dict(state_dict)
102 | model.eval()
103 |
104 | return model
105 |
106 | def split(self, audio: pydub.AudioSegment) -> T.Dict[str, pydub.AudioSegment]:
107 | """
108 | Split the given audio segment into instrument stems.
109 | """
110 | if audio.channels == 1:
111 | audio_stereo = audio.set_channels(2)
112 | elif audio.channels == 2:
113 | audio_stereo = audio
114 | else:
115 |             raise ValueError(f"Audio must be mono or stereo, but got {audio.channels} channels")
116 |
117 | # Get as (samples, channels) float numpy array
118 | waveform_np = np.array(audio_stereo.get_array_of_samples())
119 | waveform_np = waveform_np.reshape(-1, audio_stereo.channels)
120 | waveform_np_float = waveform_np.astype(np.float32)
121 |
122 | # To torch and channels-first
123 | waveform = torch.from_numpy(waveform_np_float).to(self.device)
124 | waveform = waveform.transpose(1, 0)
125 |
126 | # Normalize
127 | ref = waveform.mean(0)
128 | waveform = (waveform - ref.mean()) / ref.std()
129 |
130 | # Split
131 | sources = self.separate_sources(
132 | waveform[None],
133 | sample_rate=audio.frame_rate,
134 | )[0]
135 |
136 | # De-normalize
137 | sources = sources * ref.std() + ref.mean()
138 |
139 | # To numpy
140 | sources_np = sources.cpu().numpy().astype(waveform_np.dtype)
141 |
142 | # Convert to pydub
143 | stem_segments = [
144 | audio_util.audio_from_waveform(waveform, audio.frame_rate) for waveform in sources_np
145 | ]
146 |
147 | # Convert back to mono if necessary
148 | if audio.channels == 1:
149 | stem_segments = [stem.set_channels(1) for stem in stem_segments]
150 |
151 | return dict(zip(self.model.sources, stem_segments))
152 |
153 | def separate_sources(
154 | self,
155 | waveform: torch.Tensor,
156 | sample_rate: int = 44100,
157 | ):
158 | """
159 | Apply model to a given waveform in chunks. Use fade and overlap to smooth the edges.
160 | """
161 | batch, channels, length = waveform.shape
162 |
163 | chunk_len = int(sample_rate * self.segment_length_s * (1 + self.overlap_s))
164 | start = 0
165 | end = chunk_len
166 | overlap_frames = self.overlap_s * sample_rate
167 | fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
168 |
169 | final = torch.zeros(batch, len(self.model.sources), channels, length, device=self.device)
170 |
171 | # TODO(hayk): Improve this code, which came from the torchaudio docs
172 | while start < length - overlap_frames:
173 | chunk = waveform[:, :, start:end]
174 | with torch.no_grad():
175 | out = self.model.forward(chunk)
176 | out = fade(out)
177 | final[:, :, :, start:end] += out
178 | if start == 0:
179 | fade.fade_in_len = int(overlap_frames)
180 | start += int(chunk_len - overlap_frames)
181 | else:
182 | start += chunk_len
183 | end += chunk_len
184 | if end >= length:
185 | fade.fade_out_len = 0
186 |
187 | return final
188 |
--------------------------------------------------------------------------------
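
A brief usage sketch for `split_audio()` above (not part of the repository): it assumes the `demucs` package from requirements.txt is installed, and `song.mp3` is a placeholder input file.

```python
from pathlib import Path

import pydub

from riffusion.audio_splitter import split_audio

segment = pydub.AudioSegment.from_file("song.mp3")  # placeholder input

# Returns a dict of stem name -> pydub.AudioSegment, as produced by the demucs model.
stems = split_audio(segment, model_name="htdemucs_6s", device="cpu")

out_dir = Path("stems")
out_dir.mkdir(exist_ok=True)
for name, stem in stems.items():
    stem.export(out_dir / f"{name}.wav", format="wav")
```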
/riffusion/cli.py:
--------------------------------------------------------------------------------
1 | """
2 | Command line tools for riffusion.
3 | """
4 |
5 | import random
6 | import typing as T
7 | from multiprocessing.pool import ThreadPool
8 | from pathlib import Path
9 |
10 | import argh
11 | import numpy as np
12 | import pydub
13 | import tqdm
14 | from PIL import Image
15 |
16 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter
17 | from riffusion.spectrogram_params import SpectrogramParams
18 | from riffusion.util import image_util
19 |
20 |
21 | @argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
22 | @argh.arg("--num-frequencies", help="Number of frequency bins (Y axis) in the spectrogram image")
23 | def audio_to_image(
24 | *,
25 | audio: str,
26 | image: str,
27 | step_size_ms: int = 10,
28 | num_frequencies: int = 512,
29 | min_frequency: int = 0,
30 | max_frequency: int = 10000,
31 | window_duration_ms: int = 100,
32 | padded_duration_ms: int = 400,
33 | power_for_image: float = 0.25,
34 | stereo: bool = False,
35 | device: str = "cuda",
36 | ):
37 | """
38 | Compute a spectrogram image from a waveform.
39 | """
40 | segment = pydub.AudioSegment.from_file(audio)
41 |
42 | params = SpectrogramParams(
43 | sample_rate=segment.frame_rate,
44 | stereo=stereo,
45 | window_duration_ms=window_duration_ms,
46 | padded_duration_ms=padded_duration_ms,
47 | step_size_ms=step_size_ms,
48 | min_frequency=min_frequency,
49 | max_frequency=max_frequency,
50 | num_frequencies=num_frequencies,
51 | power_for_image=power_for_image,
52 | )
53 |
54 | converter = SpectrogramImageConverter(params=params, device=device)
55 |
56 | pil_image = converter.spectrogram_image_from_audio(segment)
57 |
58 | pil_image.save(image, exif=pil_image.getexif(), format="PNG")
59 | print(f"Wrote {image}")
60 |
61 |
62 | def print_exif(*, image: str) -> None:
63 | """
64 | Print the params of a spectrogram image as saved in the exif data.
65 | """
66 | pil_image = Image.open(image)
67 | exif_data = image_util.exif_from_image(pil_image)
68 |
69 | for name, value in exif_data.items():
70 | print(f"{name:<20} = {value:>15}")
71 |
72 |
73 | def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
74 | """
75 | Reconstruct an audio clip from a spectrogram image.
76 | """
77 | pil_image = Image.open(image)
78 |
79 | # Get parameters from image exif
80 | img_exif = pil_image.getexif()
81 | assert img_exif is not None
82 |
83 | try:
84 | params = SpectrogramParams.from_exif(exif=img_exif)
85 | except (KeyError, AttributeError):
86 | print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
87 | params = SpectrogramParams()
88 |
89 | converter = SpectrogramImageConverter(params=params, device=device)
90 | segment = converter.audio_from_spectrogram_image(pil_image)
91 |
92 | extension = Path(audio).suffix[1:]
93 | segment.export(audio, format=extension)
94 |
95 | print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
96 |
97 |
98 | def sample_clips(
99 | *,
100 | audio: str,
101 | output_dir: str,
102 | num_clips: int = 1,
103 | duration_ms: int = 5120,
104 | mono: bool = False,
105 | extension: str = "wav",
106 | seed: int = -1,
107 | ):
108 | """
109 | Slice an audio file into clips of the given duration.
110 | """
111 | if seed >= 0:
112 | np.random.seed(seed)
113 |
114 | segment = pydub.AudioSegment.from_file(audio)
115 |
116 | if mono:
117 | segment = segment.set_channels(1)
118 |
119 | output_dir_path = Path(output_dir)
120 | if not output_dir_path.exists():
121 | output_dir_path.mkdir(parents=True)
122 |
123 | segment_duration_ms = int(segment.duration_seconds * 1000)
124 | for i in range(num_clips):
125 | clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
126 | clip = segment[clip_start_ms : clip_start_ms + duration_ms]
127 |
128 | clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
129 | clip_path = output_dir_path / clip_name
130 | clip.export(clip_path, format=extension)
131 | print(f"Wrote {clip_path}")
132 |
133 |
134 | def audio_to_images_batch(
135 | *,
136 | audio_dir: str,
137 | output_dir: str,
138 | image_extension: str = "jpg",
139 | step_size_ms: int = 10,
140 | num_frequencies: int = 512,
141 | min_frequency: int = 0,
142 | max_frequency: int = 10000,
143 | power_for_image: float = 0.25,
144 | mono: bool = False,
145 | sample_rate: int = 44100,
146 | device: str = "cuda",
147 | num_threads: T.Optional[int] = None,
148 | limit: int = -1,
149 | ):
150 | """
151 | Process audio clips into spectrograms in batch, multi-threaded.
152 | """
153 | audio_paths = list(Path(audio_dir).glob("*"))
154 | audio_paths.sort()
155 |
156 | if limit > 0:
157 | audio_paths = audio_paths[:limit]
158 |
159 | output_path = Path(output_dir)
160 | output_path.mkdir(parents=True, exist_ok=True)
161 |
162 | params = SpectrogramParams(
163 | step_size_ms=step_size_ms,
164 | num_frequencies=num_frequencies,
165 | min_frequency=min_frequency,
166 | max_frequency=max_frequency,
167 | power_for_image=power_for_image,
168 | stereo=not mono,
169 | sample_rate=sample_rate,
170 | )
171 |
172 | converter = SpectrogramImageConverter(params=params, device=device)
173 |
174 | def process_one(audio_path: Path) -> None:
175 | # Load
176 | try:
177 | segment = pydub.AudioSegment.from_file(str(audio_path))
178 | except Exception:
179 | return
180 |
181 | # TODO(hayk): Sanity checks on clip
182 |
183 | if mono and segment.channels != 1:
184 | segment = segment.set_channels(1)
185 | elif not mono and segment.channels != 2:
186 | segment = segment.set_channels(2)
187 |
188 | # Frame rate
189 | if segment.frame_rate != params.sample_rate:
190 | segment = segment.set_frame_rate(params.sample_rate)
191 |
192 | # Convert
193 | image = converter.spectrogram_image_from_audio(segment)
194 |
195 | # Save
196 | image_path = output_path / f"{audio_path.stem}.{image_extension}"
197 | image_format = {"jpg": "JPEG", "jpeg": "JPEG", "png": "PNG"}[image_extension]
198 | image.save(image_path, exif=image.getexif(), format=image_format)
199 |
200 | # Create thread pool
201 | pool = ThreadPool(processes=num_threads)
202 | with tqdm.tqdm(total=len(audio_paths)) as pbar:
203 | for i, _ in enumerate(pool.imap_unordered(process_one, audio_paths)):
204 | pbar.update()
205 |
206 |
207 | def sample_clips_batch(
208 | *,
209 | audio_dir: str,
210 | output_dir: str,
211 | num_clips_per_file: int = 1,
212 | duration_ms: int = 5120,
213 | mono: bool = False,
214 | extension: str = "mp3",
215 | num_threads: T.Optional[int] = None,
216 | glob: str = "*",
217 | limit: int = -1,
218 | seed: int = -1,
219 | ):
220 | """
221 | Sample short clips from a directory of audio files, multi-threaded.
222 | """
223 | audio_paths = list(Path(audio_dir).glob(glob))
224 | audio_paths.sort()
225 |
226 | # Exclude json
227 | audio_paths = [p for p in audio_paths if p.suffix != ".json"]
228 |
229 | if limit > 0:
230 | audio_paths = audio_paths[:limit]
231 |
232 | output_path = Path(output_dir)
233 | output_path.mkdir(parents=True, exist_ok=True)
234 |
235 | if seed >= 0:
236 | random.seed(seed)
237 |
238 | def process_one(audio_path: Path) -> None:
239 | try:
240 | segment = pydub.AudioSegment.from_file(str(audio_path))
241 | except Exception:
242 | return
243 |
244 | if mono:
245 | segment = segment.set_channels(1)
246 |
247 | segment_duration_ms = int(segment.duration_seconds * 1000)
248 | for i in range(num_clips_per_file):
249 | try:
250 |                 clip_start_ms = random.randint(0, segment_duration_ms - duration_ms)
251 | except ValueError:
252 | continue
253 |
254 | clip = segment[clip_start_ms : clip_start_ms + duration_ms]
255 |
256 | clip_name = (
257 | f"{audio_path.stem}_{i}_"
258 | f"start_{clip_start_ms}_ms_dur_{duration_ms}_ms.{extension}"
259 | )
260 | clip.export(output_path / clip_name, format=extension)
261 |
262 | pool = ThreadPool(processes=num_threads)
263 | with tqdm.tqdm(total=len(audio_paths)) as pbar:
264 | for result in pool.imap_unordered(process_one, audio_paths):
265 | pbar.update()
266 |
267 |
268 | if __name__ == "__main__":
269 | argh.dispatch_commands(
270 | [
271 | audio_to_image,
272 | image_to_audio,
273 | sample_clips,
274 | print_exif,
275 | audio_to_images_batch,
276 | sample_clips_batch,
277 | ]
278 | )
279 |
--------------------------------------------------------------------------------
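
The commands above are normally invoked via `python -m riffusion.cli`, as shown in the README; the sketch below (not part of the repository) calls the same keyword-only functions directly from Python. File names are placeholders and `device="cpu"` avoids requiring CUDA.

```python
from riffusion.cli import audio_to_image, image_to_audio, print_exif

# Audio clip -> spectrogram image; the params are stored in the PNG's EXIF data.
audio_to_image(audio="clip.wav", image="spectrogram.png", device="cpu")

# Print the spectrogram params stored in the image.
print_exif(image="spectrogram.png")

# Spectrogram image -> reconstructed audio clip.
image_to_audio(image="spectrogram.png", audio="reconstructed.wav", device="cpu")
```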
/riffusion/datatypes.py:
--------------------------------------------------------------------------------
1 | """
2 | Data model for the riffusion API.
3 | """
4 | from __future__ import annotations
5 |
6 | import typing as T
7 | from dataclasses import dataclass
8 |
9 |
10 | @dataclass(frozen=True)
11 | class PromptInput:
12 | """
13 | Parameters for one end of interpolation.
14 | """
15 |
16 | # Text prompt fed into a CLIP model
17 | prompt: str
18 |
19 | # Random seed for denoising
20 | seed: int
21 |
22 | # Negative prompt to avoid (optional)
23 | negative_prompt: T.Optional[str] = None
24 |
25 | # Denoising strength
26 | denoising: float = 0.75
27 |
28 | # Classifier-free guidance strength
29 | guidance: float = 7.0
30 |
31 |
32 | @dataclass(frozen=True)
33 | class InferenceInput:
34 | """
35 | Parameters for a single run of the riffusion model, interpolating between
36 | a start and end set of PromptInputs. This is the API required for a request
37 | to the model server.
38 | """
39 |
40 | # Start point of interpolation
41 | start: PromptInput
42 |
43 | # End point of interpolation
44 | end: PromptInput
45 |
46 | # Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1
47 | # uses end fully.
48 | alpha: float
49 |
50 | # Number of inner loops of the diffusion model
51 | num_inference_steps: int = 50
52 |
53 | # Which seed image to use
54 | seed_image_id: str = "og_beat"
55 |
56 | # ID of mask image to use
57 | mask_image_id: T.Optional[str] = None
58 |
59 |
60 | @dataclass(frozen=True)
61 | class InferenceOutput:
62 | """
63 | Response from the model inference server.
64 | """
65 |
66 | # base64 encoded spectrogram image as a JPEG
67 | image: str
68 |
69 | # base64 encoded audio clip as an MP3
70 | audio: str
71 |
72 | # The duration of the audio clip
73 | duration_s: float
74 |
--------------------------------------------------------------------------------
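
To illustrate how these dataclasses map to the JSON payloads shown in the README, here is a small sketch (not part of the repository) using `dacite`, the same library used for deserialization in integrations/baseten.py.

```python
import dataclasses

import dacite

from riffusion.datatypes import InferenceInput

# Mirrors the README's example request; omitted fields fall back to the defaults above.
payload = {
    "alpha": 0.75,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells on sunday", "seed": 42},
    "end": {"prompt": "jazz with piano", "seed": 123},
}

# Nested dicts are converted into PromptInput instances automatically.
inputs = dacite.from_dict(InferenceInput, payload)
assert inputs.start.denoising == 0.75  # PromptInput default

# And back to a JSON-serializable dict.
as_dict = dataclasses.asdict(inputs)
```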
/riffusion/external/README.md:
--------------------------------------------------------------------------------
1 | # external
2 |
3 | This package contains scripts and tools from external sources.
4 |
--------------------------------------------------------------------------------
/riffusion/external/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/external/__init__.py
--------------------------------------------------------------------------------
/riffusion/external/prompt_weighting.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is taken from the diffusers community pipeline:
3 |
4 | https://github.com/huggingface/diffusers/blob/f242eba4fdc5b76dc40d3a9c01ba49b2c74b9796/examples/community/lpw_stable_diffusion.py
5 |
6 | License: Apache 2.0
7 | """
8 | # ruff: noqa
9 | # mypy: ignore-errors
10 |
11 | import logging
12 | import re
13 | import typing as T
14 |
15 | import torch
16 |
17 | from diffusers import StableDiffusionPipeline
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | re_attention = re.compile(
24 | r"""
25 | \\\(|
26 | \\\)|
27 | \\\[|
28 | \\]|
29 | \\\\|
30 | \\|
31 | \(|
32 | \[|
33 | :([+-]?[.\d]+)\)|
34 | \)|
35 | ]|
36 | [^\\()\[\]:]+|
37 | :
38 | """,
39 | re.X,
40 | )
41 |
42 |
43 | def parse_prompt_attention(text):
44 | """
45 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
46 | Accepted tokens are:
47 | (abc) - increases attention to abc by a multiplier of 1.1
48 | (abc:3.12) - increases attention to abc by a multiplier of 3.12
49 | [abc] - decreases attention to abc by a multiplier of 1.1
50 | \( - literal character '('
51 | \[ - literal character '['
52 | \) - literal character ')'
53 | \] - literal character ']'
54 | \\ - literal character '\'
55 | anything else - just text
56 | >>> parse_prompt_attention('normal text')
57 | [['normal text', 1.0]]
58 | >>> parse_prompt_attention('an (important) word')
59 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
60 | >>> parse_prompt_attention('(unbalanced')
61 | [['unbalanced', 1.1]]
62 | >>> parse_prompt_attention('\(literal\]')
63 | [['(literal]', 1.0]]
64 | >>> parse_prompt_attention('(unnecessary)(parens)')
65 | [['unnecessaryparens', 1.1]]
66 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
67 | [['a ', 1.0],
68 | ['house', 1.5730000000000004],
69 | [' ', 1.1],
70 | ['on', 1.0],
71 | [' a ', 1.1],
72 | ['hill', 0.55],
73 | [', sun, ', 1.1],
74 | ['sky', 1.4641000000000006],
75 | ['.', 1.1]]
76 | """
77 |
78 | res = []
79 | round_brackets = []
80 | square_brackets = []
81 |
82 | round_bracket_multiplier = 1.1
83 | square_bracket_multiplier = 1 / 1.1
84 |
85 | def multiply_range(start_position, multiplier):
86 | for p in range(start_position, len(res)):
87 | res[p][1] *= multiplier
88 |
89 | for m in re_attention.finditer(text):
90 | text = m.group(0)
91 | weight = m.group(1)
92 |
93 | if text.startswith("\\"):
94 | res.append([text[1:], 1.0])
95 | elif text == "(":
96 | round_brackets.append(len(res))
97 | elif text == "[":
98 | square_brackets.append(len(res))
99 | elif weight is not None and len(round_brackets) > 0:
100 | multiply_range(round_brackets.pop(), float(weight))
101 | elif text == ")" and len(round_brackets) > 0:
102 | multiply_range(round_brackets.pop(), round_bracket_multiplier)
103 | elif text == "]" and len(square_brackets) > 0:
104 | multiply_range(square_brackets.pop(), square_bracket_multiplier)
105 | else:
106 | res.append([text, 1.0])
107 |
108 | for pos in round_brackets:
109 | multiply_range(pos, round_bracket_multiplier)
110 |
111 | for pos in square_brackets:
112 | multiply_range(pos, square_bracket_multiplier)
113 |
114 | if len(res) == 0:
115 | res = [["", 1.0]]
116 |
117 | # merge runs of identical weights
118 | i = 0
119 | while i + 1 < len(res):
120 | if res[i][1] == res[i + 1][1]:
121 | res[i][0] += res[i + 1][0]
122 | res.pop(i + 1)
123 | else:
124 | i += 1
125 |
126 | return res
127 |
128 |
129 | def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
130 | r"""
131 | Tokenize a list of prompts and return its tokens with weights of each token.
132 | No padding, starting or ending token is included.
133 | """
134 | tokens = []
135 | weights = []
136 | truncated = False
137 | for text in prompt:
138 | texts_and_weights = parse_prompt_attention(text)
139 | text_token = []
140 | text_weight = []
141 | for word, weight in texts_and_weights:
142 | # tokenize and discard the starting and the ending token
143 | token = pipe.tokenizer(word).input_ids[1:-1]
144 | text_token += token
145 | # copy the weight by length of token
146 | text_weight += [weight] * len(token)
147 | # stop if the text is too long (longer than truncation limit)
148 | if len(text_token) > max_length:
149 | truncated = True
150 | break
151 | # truncate
152 | if len(text_token) > max_length:
153 | truncated = True
154 | text_token = text_token[:max_length]
155 | text_weight = text_weight[:max_length]
156 | tokens.append(text_token)
157 | weights.append(text_weight)
158 | if truncated:
159 | logger.warning(
160 | "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
161 | )
162 | return tokens, weights
163 |
164 |
165 | def pad_tokens_and_weights(
166 | tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77
167 | ):
168 | r"""
169 | Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
170 | """
171 | max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
172 | weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
173 | for i in range(len(tokens)):
174 | tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
175 | if no_boseos_middle:
176 | weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
177 | else:
178 | w = []
179 | if len(weights[i]) == 0:
180 | w = [1.0] * weights_length
181 | else:
182 | for j in range(max_embeddings_multiples):
183 | w.append(1.0) # weight for starting token in this chunk
184 | w += weights[i][
185 | j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))
186 | ]
187 | w.append(1.0) # weight for ending token in this chunk
188 | w += [1.0] * (weights_length - len(w))
189 | weights[i] = w[:]
190 |
191 | return tokens, weights
192 |
193 |
194 | def get_unweighted_text_embeddings(
195 | pipe: StableDiffusionPipeline,
196 | text_input: torch.Tensor,
197 | chunk_length: int,
198 | no_boseos_middle: T.Optional[bool] = True,
199 | ) -> torch.FloatTensor:
200 | """
201 | When the length of tokens is a multiple of the capacity of the text encoder,
202 | it should be split into chunks and sent to the text encoder individually.
203 | """
204 | max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
205 | if max_embeddings_multiples > 1:
206 | text_embeddings = []
207 | for i in range(max_embeddings_multiples):
208 | # extract the i-th chunk
209 | text_input_chunk = text_input[
210 | :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
211 | ].clone()
212 |
213 | # cover the head and the tail by the starting and the ending tokens
214 | text_input_chunk[:, 0] = text_input[0, 0]
215 | text_input_chunk[:, -1] = text_input[0, -1]
216 | text_embedding = pipe.text_encoder(text_input_chunk)[0]
217 |
218 | if no_boseos_middle:
219 | if i == 0:
220 | # discard the ending token
221 | text_embedding = text_embedding[:, :-1]
222 | elif i == max_embeddings_multiples - 1:
223 | # discard the starting token
224 | text_embedding = text_embedding[:, 1:]
225 | else:
226 | # discard both starting and ending tokens
227 | text_embedding = text_embedding[:, 1:-1]
228 |
229 | text_embeddings.append(text_embedding)
230 | text_embeddings = torch.concat(text_embeddings, axis=1)
231 | else:
232 | text_embeddings = pipe.text_encoder(text_input)[0]
233 | return text_embeddings
234 |
235 |
236 | def get_weighted_text_embeddings(
237 | pipe: StableDiffusionPipeline,
238 | prompt: T.Union[str, T.List[str]],
239 | uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
240 | max_embeddings_multiples: T.Optional[int] = 3,
241 | no_boseos_middle: T.Optional[bool] = False,
242 | skip_parsing: T.Optional[bool] = False,
243 | skip_weighting: T.Optional[bool] = False,
244 | **kwargs,
245 | ) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
246 | r"""
247 | Prompts can be assigned with local weights using brackets. For example,
248 | prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
249 | and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
250 |     Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
251 | Args:
252 | pipe (`StableDiffusionPipeline`):
253 | Pipe to provide access to the tokenizer and the text encoder.
254 | prompt (`str` or `T.List[str]`):
255 | The prompt or prompts to guide the image generation.
256 | uncond_prompt (`str` or `T.List[str]`):
257 |             The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
258 |             is provided, the embeddings of prompt and uncond_prompt are concatenated.
259 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
260 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
261 | no_boseos_middle (`bool`, *optional*, defaults to `False`):
262 |             If the length of the text tokens is a multiple of the text encoder capacity, whether to keep
263 |             the starting and ending tokens in each of the middle chunks.
264 | skip_parsing (`bool`, *optional*, defaults to `False`):
265 | Skip the parsing of brackets.
266 | skip_weighting (`bool`, *optional*, defaults to `False`):
267 |             Skip the weighting. When parsing is skipped, this is forced to True.
268 | """
269 | max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
270 | if isinstance(prompt, str):
271 | prompt = [prompt]
272 |
273 | if not skip_parsing:
274 | prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
275 |
276 | if uncond_prompt is not None:
277 | if isinstance(uncond_prompt, str):
278 | uncond_prompt = [uncond_prompt]
279 | uncond_tokens, uncond_weights = get_prompts_with_weights(
280 | pipe, uncond_prompt, max_length - 2
281 | )
282 | else:
283 | prompt_tokens = [
284 | token[1:-1]
285 | for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
286 | ]
287 | prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
288 | if uncond_prompt is not None:
289 | if isinstance(uncond_prompt, str):
290 | uncond_prompt = [uncond_prompt]
291 | uncond_tokens = [
292 | token[1:-1]
293 | for token in pipe.tokenizer(
294 | uncond_prompt, max_length=max_length, truncation=True
295 | ).input_ids
296 | ]
297 | uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
298 |
299 | # round up the longest length of tokens to a multiple of (model_max_length - 2)
300 | max_length = max([len(token) for token in prompt_tokens])
301 | if uncond_prompt is not None:
302 | max_length = max(max_length, max([len(token) for token in uncond_tokens]))
303 |
304 | max_embeddings_multiples = min(
305 | max_embeddings_multiples,
306 | (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
307 | )
308 | max_embeddings_multiples = max(1, max_embeddings_multiples)
309 | max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
310 |
311 | # pad the length of tokens and weights
312 | bos = pipe.tokenizer.bos_token_id
313 | eos = pipe.tokenizer.eos_token_id
314 | prompt_tokens, prompt_weights = pad_tokens_and_weights(
315 | prompt_tokens,
316 | prompt_weights,
317 | max_length,
318 | bos,
319 | eos,
320 | no_boseos_middle=no_boseos_middle,
321 | chunk_length=pipe.tokenizer.model_max_length,
322 | )
323 | prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
324 | if uncond_prompt is not None:
325 | uncond_tokens, uncond_weights = pad_tokens_and_weights(
326 | uncond_tokens,
327 | uncond_weights,
328 | max_length,
329 | bos,
330 | eos,
331 | no_boseos_middle=no_boseos_middle,
332 | chunk_length=pipe.tokenizer.model_max_length,
333 | )
334 | uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
335 |
336 | # get the embeddings
337 | text_embeddings = get_unweighted_text_embeddings(
338 | pipe,
339 | prompt_tokens,
340 | pipe.tokenizer.model_max_length,
341 | no_boseos_middle=no_boseos_middle,
342 | )
343 | prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
344 | if uncond_prompt is not None:
345 | uncond_embeddings = get_unweighted_text_embeddings(
346 | pipe,
347 | uncond_tokens,
348 | pipe.tokenizer.model_max_length,
349 | no_boseos_middle=no_boseos_middle,
350 | )
351 | uncond_weights = torch.tensor(
352 | uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device
353 | )
354 |
355 | # assign weights to the prompts and normalize in the sense of mean
356 | # TODO: should we normalize by chunk or in a whole (current implementation)?
357 | if (not skip_parsing) and (not skip_weighting):
358 | previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
359 | text_embeddings *= prompt_weights.unsqueeze(-1)
360 | current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
361 | text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
362 | if uncond_prompt is not None:
363 | previous_mean = (
364 | uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
365 | )
366 | uncond_embeddings *= uncond_weights.unsqueeze(-1)
367 | current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
368 | uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
369 |
370 | if uncond_prompt is not None:
371 | return text_embeddings, uncond_embeddings
372 | return text_embeddings, None
373 |
--------------------------------------------------------------------------------
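A minimal sketch of driving the prompt-weighting helper above on its own, assuming a diffusers
StableDiffusionPipeline can be loaded from the riffusion checkpoint used elsewhere in this repo;
the prompt text and parameter values below are illustrative, not canonical.

    import torch
    from diffusers import StableDiffusionPipeline

    from riffusion.external.prompt_weighting import get_weighted_text_embeddings

    # Illustrative checkpoint; any diffusers-format Stable Diffusion checkpoint should work.
    pipe = StableDiffusionPipeline.from_pretrained(
        "riffusion/riffusion-model-v1",
        torch_dtype=torch.float32,
    ).to("cpu")

    # Parentheses boost the enclosed words' token weights by the constant 1.1.
    cond, uncond = get_weighted_text_embeddings(
        pipe=pipe,
        prompt="acoustic folk with (fingerpicked guitar)",
        uncond_prompt="",
        max_embeddings_multiples=3,
    )
    print(cond.shape)  # (1, sequence_length, embedding_dim)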
/riffusion/py.typed:
--------------------------------------------------------------------------------
1 | # https://peps.python.org/pep-0561/
2 |
--------------------------------------------------------------------------------
/riffusion/riffusion_pipeline.py:
--------------------------------------------------------------------------------
1 | """
2 | Riffusion inference pipeline.
3 | """
4 | from __future__ import annotations
5 |
6 | import dataclasses
7 | import functools
8 | import inspect
9 | import typing as T
10 |
11 | import numpy as np
12 | import torch
13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
14 | from diffusers.pipeline_utils import DiffusionPipeline
15 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
16 | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
17 | from diffusers.utils import logging
18 | from huggingface_hub import hf_hub_download
19 | from PIL import Image
20 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
21 |
22 | from riffusion.datatypes import InferenceInput
23 | from riffusion.external.prompt_weighting import get_weighted_text_embeddings
24 | from riffusion.util import torch_util
25 |
26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27 |
28 |
29 | class RiffusionPipeline(DiffusionPipeline):
30 | """
31 | Diffusers pipeline for doing a controlled img2img interpolation for audio generation.
32 |
33 | # TODO(hayk): Document more
34 |
35 | Part of this code was adapted from the non-img2img interpolation pipeline at:
36 |
37 | https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py
38 |
39 | Check the documentation for DiffusionPipeline for full information.
40 | """
41 |
42 | def __init__(
43 | self,
44 | vae: AutoencoderKL,
45 | text_encoder: CLIPTextModel,
46 | tokenizer: CLIPTokenizer,
47 | unet: UNet2DConditionModel,
48 | scheduler: T.Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
49 | safety_checker: StableDiffusionSafetyChecker,
50 | feature_extractor: CLIPFeatureExtractor,
51 | ):
52 | super().__init__()
53 | self.register_modules(
54 | vae=vae,
55 | text_encoder=text_encoder,
56 | tokenizer=tokenizer,
57 | unet=unet,
58 | scheduler=scheduler,
59 | safety_checker=safety_checker,
60 | feature_extractor=feature_extractor,
61 | )
62 |
63 | @classmethod
64 | def load_checkpoint(
65 | cls,
66 | checkpoint: str,
67 | use_traced_unet: bool = True,
68 | channels_last: bool = False,
69 | dtype: torch.dtype = torch.float16,
70 | device: str = "cuda",
71 | local_files_only: bool = False,
72 | low_cpu_mem_usage: bool = False,
73 | cache_dir: T.Optional[str] = None,
74 | ) -> RiffusionPipeline:
75 | """
76 | Load the riffusion model pipeline.
77 |
78 | Args:
79 | checkpoint: Model checkpoint on disk in diffusers format
80 | use_traced_unet: Whether to use the traced unet for speedups
81 | device: Device to load the model on
82 | channels_last: Whether to use channels_last memory format
83 | local_files_only: Don't download, only use local files
84 | low_cpu_mem_usage: Attempt to use less memory on CPU
85 | """
86 | device = torch_util.check_device(device)
87 |
88 | if device == "cpu" or device.lower().startswith("mps"):
89 | print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
90 | dtype = torch.float32
91 |
92 | pipeline = RiffusionPipeline.from_pretrained(
93 | checkpoint,
94 | revision="main",
95 | torch_dtype=dtype,
97 |             # Disable the NSFW filter, which causes false positives
97 | # TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
98 | safety_checker=lambda images, **kwargs: (images, False),
99 | low_cpu_mem_usage=low_cpu_mem_usage,
100 | local_files_only=local_files_only,
101 | cache_dir=cache_dir,
102 | ).to(device)
103 |
104 | if channels_last:
105 | pipeline.unet.to(memory_format=torch.channels_last)
106 |
107 | # Optionally load a traced unet
108 | if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
109 | traced_unet = cls.load_traced_unet(
110 | checkpoint=checkpoint,
111 | subfolder="unet_traced",
112 | filename="unet_traced.pt",
113 | in_channels=pipeline.unet.in_channels,
114 | dtype=dtype,
115 | device=device,
116 | local_files_only=local_files_only,
117 | cache_dir=cache_dir,
118 | )
119 |
120 | if traced_unet is not None:
121 | pipeline.unet = traced_unet
122 |
123 | model = pipeline.to(device)
124 |
125 | return model
126 |
127 | @staticmethod
128 | def load_traced_unet(
129 | checkpoint: str,
130 | subfolder: str,
131 | filename: str,
132 | in_channels: int,
133 | dtype: torch.dtype,
134 | device: str = "cuda",
135 | local_files_only=False,
136 | cache_dir: T.Optional[str] = None,
137 | ) -> T.Optional[torch.nn.Module]:
138 | """
139 | Load a traced unet from the huggingface hub. This can improve performance.
140 | """
141 | if device == "cpu" or device.lower().startswith("mps"):
142 | print("WARNING: Traced UNet only available for CUDA, skipping")
143 | return None
144 |
145 | # Download and load the traced unet
146 | unet_file = hf_hub_download(
147 | checkpoint,
148 | subfolder=subfolder,
149 | filename=filename,
150 | local_files_only=local_files_only,
151 | cache_dir=cache_dir,
152 | )
153 | unet_traced = torch.jit.load(unet_file)
154 |
155 | # Wrap it in a torch module
156 | class TracedUNet(torch.nn.Module):
157 | @dataclasses.dataclass
158 | class UNet2DConditionOutput:
159 | sample: torch.FloatTensor
160 |
161 | def __init__(self):
162 | super().__init__()
163 |                 self.in_channels = in_channels
164 | self.device = device
165 | self.dtype = dtype
166 |
167 | def forward(self, latent_model_input, t, encoder_hidden_states):
168 | sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
169 | return self.UNet2DConditionOutput(sample=sample)
170 |
171 | return TracedUNet()
172 |
173 | @property
174 | def device(self) -> str:
175 | return str(self.vae.device)
176 |
177 | @functools.lru_cache()
178 | def embed_text(self, text) -> torch.FloatTensor:
179 | """
180 | Takes in text and turns it into text embeddings.
181 | """
182 | text_input = self.tokenizer(
183 | text,
184 | padding="max_length",
185 | max_length=self.tokenizer.model_max_length,
186 | truncation=True,
187 | return_tensors="pt",
188 | )
189 | with torch.no_grad():
190 | embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
191 | return embed
192 |
193 | @functools.lru_cache()
194 | def embed_text_weighted(self, text) -> torch.FloatTensor:
195 | """
196 | Get text embedding with weights.
197 | """
198 | return get_weighted_text_embeddings(
199 | pipe=self,
200 | prompt=text,
201 | uncond_prompt=None,
202 | max_embeddings_multiples=3,
203 | no_boseos_middle=False,
204 | skip_parsing=False,
205 | skip_weighting=False,
206 | )[0]
207 |
208 | @torch.no_grad()
209 | def riffuse(
210 | self,
211 | inputs: InferenceInput,
212 | init_image: Image.Image,
213 | mask_image: T.Optional[Image.Image] = None,
214 | use_reweighting: bool = True,
215 | ) -> Image.Image:
216 | """
217 | Runs inference using interpolation with both img2img and text conditioning.
218 |
219 | Args:
220 | inputs: Parameter dataclass
221 | init_image: Image used for conditioning
222 | mask_image: White pixels in the mask will be replaced by noise and therefore repainted,
223 | while black pixels will be preserved. It will be converted to a single
224 | channel (luminance) before use.
225 | use_reweighting: Use prompt reweighting
226 | """
227 | alpha = inputs.alpha
228 | start = inputs.start
229 | end = inputs.end
230 |
231 | guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
232 |
233 | # TODO(hayk): Always generate the seed on CPU?
234 | if self.device.lower().startswith("mps"):
235 | generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
236 | generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
237 | else:
238 | generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
239 | generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
240 |
241 | # Text encodings
242 | if use_reweighting:
243 | embed_start = self.embed_text_weighted(start.prompt)
244 | embed_end = self.embed_text_weighted(end.prompt)
245 | else:
246 | embed_start = self.embed_text(start.prompt)
247 | embed_end = self.embed_text(end.prompt)
248 |
249 | text_embedding = embed_start + alpha * (embed_end - embed_start)
250 |
251 | # Image latents
252 | init_image_torch = preprocess_image(init_image).to(
253 | device=self.device, dtype=embed_start.dtype
254 | )
255 | init_latent_dist = self.vae.encode(init_image_torch).latent_dist
256 | # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
257 | # result is so close no matter the seed that it doesn't really add variety.
258 | if self.device.lower().startswith("mps"):
259 | generator = torch.Generator(device="cpu").manual_seed(start.seed)
260 | else:
261 | generator = torch.Generator(device=self.device).manual_seed(start.seed)
262 |
263 | init_latents = init_latent_dist.sample(generator=generator)
264 | init_latents = 0.18215 * init_latents
265 |
266 | # Prepare mask latent
267 | mask: T.Optional[torch.Tensor] = None
268 | if mask_image:
269 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
270 | mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
271 | device=self.device, dtype=embed_start.dtype
272 | )
273 |
274 | outputs = self.interpolate_img2img(
275 | text_embeddings=text_embedding,
276 | init_latents=init_latents,
277 | mask=mask,
278 | generator_a=generator_start,
279 | generator_b=generator_end,
280 | interpolate_alpha=alpha,
281 | strength_a=start.denoising,
282 | strength_b=end.denoising,
283 | num_inference_steps=inputs.num_inference_steps,
284 | guidance_scale=guidance_scale,
285 | )
286 |
287 | return outputs["images"][0]
288 |
289 | @torch.no_grad()
290 | def interpolate_img2img(
291 | self,
292 | text_embeddings: torch.Tensor,
293 | init_latents: torch.Tensor,
294 | generator_a: torch.Generator,
295 | generator_b: torch.Generator,
296 | interpolate_alpha: float,
297 | mask: T.Optional[torch.Tensor] = None,
298 | strength_a: float = 0.8,
299 | strength_b: float = 0.8,
300 | num_inference_steps: int = 50,
301 | guidance_scale: float = 7.5,
302 | negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
303 | num_images_per_prompt: int = 1,
304 | eta: T.Optional[float] = 0.0,
305 | output_type: T.Optional[str] = "pil",
306 | **kwargs,
307 | ):
308 | """
309 |         Run img2img denoising on init_latents, interpolating the noise (via slerp) and the denoising strength between the A and B endpoints by interpolate_alpha.
310 | """
311 | batch_size = text_embeddings.shape[0]
312 |
313 | # set timesteps
314 | self.scheduler.set_timesteps(num_inference_steps)
315 |
316 | # duplicate text embeddings for each generation per prompt, using mps friendly method
317 | bs_embed, seq_len, _ = text_embeddings.shape
318 | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
319 | text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
320 |
321 |         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
322 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
323 | # corresponds to doing no classifier free guidance.
324 | do_classifier_free_guidance = guidance_scale > 1.0
325 | # get unconditional embeddings for classifier free guidance
326 | if do_classifier_free_guidance:
327 | if negative_prompt is None:
328 | uncond_tokens = [""]
329 | elif isinstance(negative_prompt, str):
330 | uncond_tokens = [negative_prompt]
331 | elif batch_size != len(negative_prompt):
332 | raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
333 | else:
334 | uncond_tokens = negative_prompt
335 |
336 | # max_length = text_input_ids.shape[-1]
337 | uncond_input = self.tokenizer(
338 | uncond_tokens,
339 | padding="max_length",
340 | max_length=self.tokenizer.model_max_length,
341 | truncation=True,
342 | return_tensors="pt",
343 | )
344 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
345 |
346 | # duplicate unconditional embeddings for each generation per prompt
347 | uncond_embeddings = uncond_embeddings.repeat_interleave(
348 | batch_size * num_images_per_prompt, dim=0
349 | )
350 |
351 | # For classifier free guidance, we need to do two forward passes.
352 | # Here we concatenate the unconditional and text embeddings into a single batch
353 | # to avoid doing two forward passes
354 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
355 |
356 | latents_dtype = text_embeddings.dtype
357 |
358 | strength = (1 - interpolate_alpha) * strength_a + interpolate_alpha * strength_b
359 |
360 | # get the original timestep using init_timestep
361 | offset = self.scheduler.config.get("steps_offset", 0)
362 | init_timestep = int(num_inference_steps * strength) + offset
363 | init_timestep = min(init_timestep, num_inference_steps)
364 |
365 | timesteps = self.scheduler.timesteps[-init_timestep]
366 | timesteps = torch.tensor(
367 | [timesteps] * batch_size * num_images_per_prompt, device=self.device
368 | )
369 |
370 | # add noise to latents using the timesteps
371 | noise_a = torch.randn(
372 | init_latents.shape, generator=generator_a, device=self.device, dtype=latents_dtype
373 | )
374 | noise_b = torch.randn(
375 | init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
376 | )
377 | noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
378 | init_latents_orig = init_latents
379 | init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
380 |
381 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same args
382 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
383 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
384 | # and should be between [0, 1]
385 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
386 | extra_step_kwargs = {}
387 | if accepts_eta:
388 | extra_step_kwargs["eta"] = eta
389 |
390 | latents = init_latents.clone()
391 |
392 | t_start = max(num_inference_steps - init_timestep + offset, 0)
393 |
394 | # Some schedulers like PNDM have timesteps as arrays
395 |         # It's more efficient to move all timesteps to the correct device beforehand
396 | timesteps = self.scheduler.timesteps[t_start:].to(self.device)
397 |
398 | for i, t in enumerate(self.progress_bar(timesteps)):
399 | # expand the latents if we are doing classifier free guidance
400 | latent_model_input = (
401 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents
402 | )
403 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
404 |
405 | # predict the noise residual
406 | noise_pred = self.unet(
407 | latent_model_input, t, encoder_hidden_states=text_embeddings
408 | ).sample
409 |
410 | # perform guidance
411 | if do_classifier_free_guidance:
412 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
413 | noise_pred = noise_pred_uncond + guidance_scale * (
414 | noise_pred_text - noise_pred_uncond
415 | )
416 |
417 | # compute the previous noisy sample x_t -> x_t-1
418 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
419 |
420 | if mask is not None:
421 | init_latents_proper = self.scheduler.add_noise(
422 | init_latents_orig, noise, torch.tensor([t])
423 | )
424 | # import ipdb; ipdb.set_trace()
425 | latents = (init_latents_proper * mask) + (latents * (1 - mask))
426 |
427 | latents = 1.0 / 0.18215 * latents
428 | image = self.vae.decode(latents).sample
429 |
430 | image = (image / 2 + 0.5).clamp(0, 1)
431 | image = image.cpu().permute(0, 2, 3, 1).numpy()
432 |
433 | if output_type == "pil":
434 | image = self.numpy_to_pil(image)
435 |
436 | return dict(images=image, latents=latents, nsfw_content_detected=False)
437 |
438 |
439 | def preprocess_image(image: Image.Image) -> torch.Tensor:
440 | """
441 | Preprocess an image for the model.
442 | """
443 | w, h = image.size
444 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
445 | image = image.resize((w, h), resample=Image.LANCZOS)
446 |
447 | image_np = np.array(image).astype(np.float32) / 255.0
448 | image_np = image_np[None].transpose(0, 3, 1, 2)
449 |
450 | image_torch = torch.from_numpy(image_np)
451 |
452 | return 2.0 * image_torch - 1.0
453 |
454 |
455 | def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
456 | """
457 | Preprocess a mask for the model.
458 | """
459 | # Convert to grayscale
460 | mask = mask.convert("L")
461 |
462 | # Resize to integer multiple of 32
463 | w, h = mask.size
464 | w, h = map(lambda x: x - x % 32, (w, h))
465 | mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
466 |
467 | # Convert to numpy array and rescale
468 | mask_np = np.array(mask).astype(np.float32) / 255.0
469 |
470 |     # Tile across the 4 latent channels
471 |     mask_np = np.tile(mask_np, (4, 1, 1))
472 |     mask_np = mask_np[None]  # add a batch dimension: (1, 4, h // scale_factor, w // scale_factor)
473 |
474 | # Invert to repaint white and keep black
475 | mask_np = 1 - mask_np # repaint white, keep black
476 |
477 | return torch.from_numpy(mask_np)
478 |
--------------------------------------------------------------------------------
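A hedged sketch of using RiffusionPipeline directly, outside the server. The field names mirror
how InferenceInput and PromptInput are used above; check riffusion/datatypes.py for the
authoritative definitions and defaults. The prompts, seeds, and CUDA device are illustrative.

    from PIL import Image

    from riffusion.datatypes import InferenceInput, PromptInput
    from riffusion.riffusion_pipeline import RiffusionPipeline

    pipeline = RiffusionPipeline.load_checkpoint(
        checkpoint="riffusion/riffusion-model-v1",
        device="cuda",
    )

    start = PromptInput(prompt="church bells on sunday", seed=42, denoising=0.75, guidance=7.0)
    end = PromptInput(prompt="jazz with piano", seed=123, denoising=0.75, guidance=7.0)

    inputs = InferenceInput(
        alpha=0.5,                # blend halfway between the start and end prompts
        num_inference_steps=50,
        seed_image_id="og_beat",  # one of the images in seed_images/
        start=start,
        end=end,
    )

    init_image = Image.open("seed_images/og_beat.png").convert("RGB")
    spectrogram_image = pipeline.riffuse(inputs, init_image=init_image)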
/riffusion/server.py:
--------------------------------------------------------------------------------
1 | """
2 | Flask server that serves the riffusion model as an API.
3 | """
4 |
5 | import dataclasses
6 | import io
7 | import json
8 | import logging
9 | import time
10 | import typing as T
11 | from pathlib import Path
12 |
13 | import dacite
14 | import flask
15 | import PIL
16 | from flask_cors import CORS
17 |
18 | from riffusion.datatypes import InferenceInput, InferenceOutput
19 | from riffusion.riffusion_pipeline import RiffusionPipeline
20 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter
21 | from riffusion.spectrogram_params import SpectrogramParams
22 | from riffusion.util import base64_util
23 |
24 | # Flask app with CORS
25 | app = flask.Flask(__name__)
26 | CORS(app)
27 |
28 | # Log at the INFO level to both stdout and disk
29 | logging.basicConfig(level=logging.INFO)
30 | logging.getLogger().addHandler(logging.FileHandler("server.log"))
31 |
32 | # Global variable for the model pipeline
33 | PIPELINE: T.Optional[RiffusionPipeline] = None
34 |
35 | # Where built-in seed images are stored
36 | SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
37 |
38 |
39 | def run_app(
40 | *,
41 | checkpoint: str = "riffusion/riffusion-model-v1",
42 | no_traced_unet: bool = False,
43 | device: str = "cuda",
44 | host: str = "127.0.0.1",
45 | port: int = 3013,
46 | debug: bool = False,
47 | ssl_certificate: T.Optional[str] = None,
48 | ssl_key: T.Optional[str] = None,
49 | ):
50 | """
51 | Run a flask API that serves the given riffusion model checkpoint.
52 | """
53 | # Initialize the model
54 | global PIPELINE
55 | PIPELINE = RiffusionPipeline.load_checkpoint(
56 | checkpoint=checkpoint,
57 | use_traced_unet=not no_traced_unet,
58 | device=device,
59 | )
60 |
61 | args = dict(
62 | debug=debug,
63 | threaded=False,
64 | host=host,
65 | port=port,
66 | )
67 |
68 | if ssl_certificate:
69 | assert ssl_key is not None
70 | args["ssl_context"] = (ssl_certificate, ssl_key)
71 |
72 | app.run(**args) # type: ignore
73 |
74 |
75 | @app.route("/run_inference/", methods=["POST"])
76 | def run_inference():
77 | """
78 | Execute the riffusion model as an API.
79 |
80 | Inputs:
81 | Serialized JSON of the InferenceInput dataclass
82 |
83 | Returns:
84 | Serialized JSON of the InferenceOutput dataclass
85 | """
86 | start_time = time.time()
87 |
88 | # Parse the payload as JSON
89 | json_data = json.loads(flask.request.data)
90 |
91 | # Log the request
92 | logging.info(json_data)
93 |
94 | # Parse an InferenceInput dataclass from the payload
95 | try:
96 | inputs = dacite.from_dict(InferenceInput, json_data)
97 | except dacite.exceptions.WrongTypeError as exception:
98 | logging.info(json_data)
99 | return str(exception), 400
100 | except dacite.exceptions.MissingValueError as exception:
101 | logging.info(json_data)
102 | return str(exception), 400
103 |
104 | response = compute_request(
105 | inputs=inputs,
106 | seed_images_dir=SEED_IMAGES_DIR,
107 | pipeline=PIPELINE,
108 | )
109 |
110 | # Log the total time
111 | logging.info(f"Request took {time.time() - start_time:.2f} s")
112 |
113 | return response
114 |
115 |
116 | def compute_request(
117 | inputs: InferenceInput,
118 | pipeline: RiffusionPipeline,
119 | seed_images_dir: str,
120 | ) -> T.Union[str, T.Tuple[str, int]]:
121 | """
122 | Does all the heavy lifting of the request.
123 |
124 | Args:
125 | inputs: The input dataclass
126 | pipeline: The riffusion model pipeline
127 | seed_images_dir: The directory where seed images are stored
128 | """
129 | # Load the seed image by ID
130 | init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")
131 |
132 | if not init_image_path.is_file():
133 | return f"Invalid seed image: {inputs.seed_image_id}", 400
134 | init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
135 |
136 | # Load the mask image by ID
137 | mask_image: T.Optional[PIL.Image.Image] = None
138 | if inputs.mask_image_id:
139 | mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
140 | if not mask_image_path.is_file():
141 | return f"Invalid mask image: {inputs.mask_image_id}", 400
142 | mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
143 |
144 | # Execute the model to get the spectrogram image
145 | image = pipeline.riffuse(
146 | inputs,
147 | init_image=init_image,
148 | mask_image=mask_image,
149 | )
150 |
151 | # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
152 | params = SpectrogramParams(
153 | min_frequency=0,
154 | max_frequency=10000,
155 | )
156 |
157 | # Reconstruct audio from the image
158 | # TODO(hayk): It may help performance a bit to cache this object
159 | converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
160 |
161 | segment = converter.audio_from_spectrogram_image(
162 | image,
163 | apply_filters=True,
164 | )
165 |
166 | # Export audio to MP3 bytes
167 | mp3_bytes = io.BytesIO()
168 | segment.export(mp3_bytes, format="mp3")
169 | mp3_bytes.seek(0)
170 |
171 | # Export image to JPEG bytes
172 | image_bytes = io.BytesIO()
173 | image.save(image_bytes, exif=image.getexif(), format="JPEG")
174 | image_bytes.seek(0)
175 |
176 | # Assemble the output dataclass
177 | output = InferenceOutput(
178 | image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
179 | audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
180 | duration_s=segment.duration_seconds,
181 | )
182 |
183 | return json.dumps(dataclasses.asdict(output))
184 |
185 |
186 | if __name__ == "__main__":
187 | import argh
188 |
189 | argh.dispatch_command(run_app)
190 |
--------------------------------------------------------------------------------
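A hedged sketch of calling the API above from Python, assuming the server is running locally on
its default host and port (e.g. started via `python -m riffusion.server`). The JSON field names
mirror the InferenceInput and PromptInput dataclasses; the prompt values are illustrative.

    import base64

    import requests

    payload = {
        "alpha": 0.5,
        "num_inference_steps": 50,
        "seed_image_id": "og_beat",
        "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
        "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
    }

    response = requests.post("http://127.0.0.1:3013/run_inference/", json=payload)
    response.raise_for_status()
    output = response.json()  # keys: image, audio, duration_s

    # The audio is returned as a base64 data URI; strip the prefix and decode to MP3 bytes.
    audio_b64 = output["audio"].split(",", 1)[1]
    with open("riff.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))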
/riffusion/spectrogram_converter.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import pydub
5 | import torch
6 | import torchaudio
7 |
8 | from riffusion.spectrogram_params import SpectrogramParams
9 | from riffusion.util import audio_util, torch_util
10 |
11 |
12 | class SpectrogramConverter:
13 | """
14 | Convert between audio segments and spectrogram tensors using torchaudio.
15 |
16 |     In this class a "spectrogram" is defined as a (batch, frequency, time) tensor with float values
17 |     that represent the amplitude of the frequency at that time bucket (in the frequency domain).
18 |     Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
19 | used in some functions is "mel amplitudes".
20 |
21 | The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
22 | returns the amplitude, because the phase is chaotic and hard to learn. The function
23 | `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
24 | approximates the phase information using the Griffin-Lim algorithm.
25 |
26 | Each channel in the audio is treated independently, and the spectrogram has a batch dimension
27 | equal to the number of channels in the input audio segment.
28 |
29 | Both the Griffin Lim algorithm and the Mel scaling process are lossy.
30 |
31 | For more information, see https://pytorch.org/audio/stable/transforms.html
32 | """
33 |
34 | def __init__(self, params: SpectrogramParams, device: str = "cuda"):
35 | self.p = params
36 |
37 | self.device = torch_util.check_device(device)
38 |
39 | if device.lower().startswith("mps"):
40 | warnings.warn(
41 | "WARNING: MPS does not support audio operations, falling back to CPU for them",
42 | stacklevel=2,
43 | )
44 | self.device = "cpu"
45 |
46 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
47 | self.spectrogram_func = torchaudio.transforms.Spectrogram(
48 | n_fft=params.n_fft,
49 | hop_length=params.hop_length,
50 | win_length=params.win_length,
51 | pad=0,
52 | window_fn=torch.hann_window,
53 | power=None,
54 | normalized=False,
55 | wkwargs=None,
56 | center=True,
57 | pad_mode="reflect",
58 | onesided=True,
59 | ).to(self.device)
60 |
61 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
62 | self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
63 | n_fft=params.n_fft,
64 | n_iter=params.num_griffin_lim_iters,
65 | win_length=params.win_length,
66 | hop_length=params.hop_length,
67 | window_fn=torch.hann_window,
68 | power=1.0,
69 | wkwargs=None,
70 | momentum=0.99,
71 | length=None,
72 | rand_init=True,
73 | ).to(self.device)
74 |
75 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
76 | self.mel_scaler = torchaudio.transforms.MelScale(
77 | n_mels=params.num_frequencies,
78 | sample_rate=params.sample_rate,
79 | f_min=params.min_frequency,
80 | f_max=params.max_frequency,
81 | n_stft=params.n_fft // 2 + 1,
82 | norm=params.mel_scale_norm,
83 | mel_scale=params.mel_scale_type,
84 | ).to(self.device)
85 |
86 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
87 | self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
88 | n_stft=params.n_fft // 2 + 1,
89 | n_mels=params.num_frequencies,
90 | sample_rate=params.sample_rate,
91 | f_min=params.min_frequency,
92 | f_max=params.max_frequency,
93 | max_iter=params.max_mel_iters,
94 | tolerance_loss=1e-5,
95 | tolerance_change=1e-8,
96 | sgdargs=None,
97 | norm=params.mel_scale_norm,
98 | mel_scale=params.mel_scale_type,
99 | ).to(self.device)
100 |
101 | def spectrogram_from_audio(
102 | self,
103 | audio: pydub.AudioSegment,
104 | ) -> np.ndarray:
105 | """
106 | Compute a spectrogram from an audio segment.
107 |
108 | Args:
109 | audio: Audio segment which must match the sample rate of the params
110 |
111 | Returns:
112 | spectrogram: (channel, frequency, time)
113 | """
114 | assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"
115 |
116 | # Get the samples as a numpy array in (batch, samples) shape
117 | waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])
118 |
119 | # Convert to floats if necessary
120 | if waveform.dtype != np.float32:
121 | waveform = waveform.astype(np.float32)
122 |
123 | waveform_tensor = torch.from_numpy(waveform).to(self.device)
124 | amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
125 | return amplitudes_mel.cpu().numpy()
126 |
127 | def audio_from_spectrogram(
128 | self,
129 | spectrogram: np.ndarray,
130 | apply_filters: bool = True,
131 | ) -> pydub.AudioSegment:
132 | """
133 | Reconstruct an audio segment from a spectrogram.
134 |
135 | Args:
136 | spectrogram: (batch, frequency, time)
137 | apply_filters: Post-process with normalization and compression
138 |
139 | Returns:
140 | audio: Audio segment with channels equal to the batch dimension
141 | """
142 | # Move to device
143 | amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)
144 |
145 | # Reconstruct the waveform
146 | waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)
147 |
148 | # Convert to audio segment
149 | segment = audio_util.audio_from_waveform(
150 | samples=waveform.cpu().numpy(),
151 | sample_rate=self.p.sample_rate,
152 | # Normalize the waveform to the range [-1, 1]
153 | normalize=True,
154 | )
155 |
156 | # Optionally apply post-processing filters
157 | if apply_filters:
158 | segment = audio_util.apply_filters(
159 | segment,
160 | compression=False,
161 | )
162 |
163 | return segment
164 |
165 | def mel_amplitudes_from_waveform(
166 | self,
167 | waveform: torch.Tensor,
168 | ) -> torch.Tensor:
169 | """
170 | Torch-only function to compute Mel-scale amplitudes from a waveform.
171 |
172 | Args:
173 | waveform: (batch, samples)
174 |
175 | Returns:
176 | amplitudes_mel: (batch, frequency, time)
177 | """
178 | # Compute the complex-valued spectrogram
179 | spectrogram_complex = self.spectrogram_func(waveform)
180 |
181 | # Take the magnitude
182 | amplitudes = torch.abs(spectrogram_complex)
183 |
184 | # Convert to mel scale
185 | return self.mel_scaler(amplitudes)
186 |
187 | def waveform_from_mel_amplitudes(
188 | self,
189 | amplitudes_mel: torch.Tensor,
190 | ) -> torch.Tensor:
191 | """
192 | Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.
193 |
194 | Args:
195 | amplitudes_mel: (batch, frequency, time)
196 |
197 | Returns:
198 | waveform: (batch, samples)
199 | """
200 | # Convert from mel scale to linear
201 | amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)
202 |
203 | # Run the approximate algorithm to compute the phase and recover the waveform
204 | return self.inverse_spectrogram_func(amplitudes_linear)
205 |
--------------------------------------------------------------------------------
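A small sketch of round-tripping audio through SpectrogramConverter on CPU. The test file path
comes from this repo's test data; any audio pydub can open works once resampled to the params'
sample rate.

    import pydub

    from riffusion.spectrogram_converter import SpectrogramConverter
    from riffusion.spectrogram_params import SpectrogramParams

    params = SpectrogramParams()  # defaults: 44100 Hz, 512 mel bins, 0-10 kHz
    converter = SpectrogramConverter(params=params, device="cpu")

    segment = pydub.AudioSegment.from_file("test/test_data/tired_traveler/tired_traveler.mp3")
    segment = segment.set_frame_rate(params.sample_rate)  # must match params.sample_rate

    spectrogram = converter.spectrogram_from_audio(segment)  # (channels, frequency, time)
    reconstructed = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
    reconstructed.export("reconstructed.wav", format="wav")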
/riffusion/spectrogram_image_converter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pydub
3 | from PIL import Image
4 |
5 | from riffusion.spectrogram_converter import SpectrogramConverter
6 | from riffusion.spectrogram_params import SpectrogramParams
7 | from riffusion.util import image_util
8 |
9 |
10 | class SpectrogramImageConverter:
11 | """
12 | Convert between spectrogram images and audio segments.
13 |
14 | This is a wrapper around SpectrogramConverter that additionally converts from spectrograms
15 | to images and back. The real audio processing lives in SpectrogramConverter.
16 | """
17 |
18 | def __init__(self, params: SpectrogramParams, device: str = "cuda"):
19 | self.p = params
20 | self.device = device
21 | self.converter = SpectrogramConverter(params=params, device=device)
22 |
23 | def spectrogram_image_from_audio(
24 | self,
25 | segment: pydub.AudioSegment,
26 | ) -> Image.Image:
27 | """
28 | Compute a spectrogram image from an audio segment.
29 |
30 | Args:
31 | segment: Audio segment to convert
32 |
33 | Returns:
34 | Spectrogram image (in pillow format)
35 | """
36 | assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"
37 |
38 | if self.p.stereo:
39 | if segment.channels == 1:
40 | print("WARNING: Mono audio but stereo=True, cloning channel")
41 | segment = segment.set_channels(2)
42 | elif segment.channels > 2:
43 | print("WARNING: Multi channel audio, reducing to stereo")
44 | segment = segment.set_channels(2)
45 | else:
46 | if segment.channels > 1:
47 | print("WARNING: Stereo audio but stereo=False, setting to mono")
48 | segment = segment.set_channels(1)
49 |
50 | spectrogram = self.converter.spectrogram_from_audio(segment)
51 |
52 | image = image_util.image_from_spectrogram(
53 | spectrogram,
54 | power=self.p.power_for_image,
55 | )
56 |
57 | # Store conversion params in exif metadata of the image
58 | exif_data = self.p.to_exif()
59 | exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
60 | exif = image.getexif()
61 | exif.update(exif_data.items())
62 |
63 | return image
64 |
65 | def audio_from_spectrogram_image(
66 | self,
67 | image: Image.Image,
68 | apply_filters: bool = True,
69 | max_value: float = 30e6,
70 | ) -> pydub.AudioSegment:
71 | """
72 | Reconstruct an audio segment from a spectrogram image.
73 |
74 | Args:
75 | image: Spectrogram image (in pillow format)
76 | apply_filters: Apply post-processing to improve the reconstructed audio
77 |             max_value: Scaled max amplitude of the spectrogram. The reconstruction is largely insensitive to this value.
78 | """
79 | spectrogram = image_util.spectrogram_from_image(
80 | image,
81 | max_value=max_value,
82 | power=self.p.power_for_image,
83 | stereo=self.p.stereo,
84 | )
85 |
86 | segment = self.converter.audio_from_spectrogram(
87 | spectrogram,
88 | apply_filters=apply_filters,
89 | )
90 |
91 | return segment
92 |
--------------------------------------------------------------------------------
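A sketch of the image round trip with SpectrogramImageConverter, carrying the conversion
parameters in EXIF the same way the server does. The file paths and 5-second slice are
illustrative; whether EXIF survives a PNG save depends on the Pillow version.

    import pydub
    from PIL import Image

    from riffusion.spectrogram_image_converter import SpectrogramImageConverter
    from riffusion.spectrogram_params import SpectrogramParams

    params = SpectrogramParams(stereo=False)
    converter = SpectrogramImageConverter(params=params, device="cpu")

    segment = pydub.AudioSegment.from_file("test/test_data/tired_traveler/tired_traveler.mp3")
    segment = segment.set_frame_rate(params.sample_rate)

    image = converter.spectrogram_image_from_audio(segment[:5000])  # first 5 seconds
    image.save("clip.png", exif=image.getexif(), format="PNG")

    # Later, reconstruct audio from the saved image.
    reloaded = Image.open("clip.png")
    audio = converter.audio_from_spectrogram_image(reloaded, apply_filters=True)
    audio.export("clip_reconstructed.wav", format="wav")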
/riffusion/spectrogram_params.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import typing as T
4 | from dataclasses import dataclass
5 | from enum import Enum
6 |
7 |
8 | @dataclass(frozen=True)
9 | class SpectrogramParams:
10 | """
11 | Parameters for the conversion from audio to spectrograms to images and back.
12 |
13 | Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored
14 | within spectrogram images.
15 |
16 | To understand what these parameters do and to customize them, read `spectrogram_converter.py`
17 | and the linked torchaudio documentation.
18 | """
19 |
20 | # Whether the audio is stereo or mono
21 | stereo: bool = False
22 |
23 | # FFT parameters
24 | sample_rate: int = 44100
25 | step_size_ms: int = 10
26 | window_duration_ms: int = 100
27 | padded_duration_ms: int = 400
28 |
29 | # Mel scale parameters
30 | num_frequencies: int = 512
31 | # TODO(hayk): Set these to [20, 20000] for newer models
32 | min_frequency: int = 0
33 | max_frequency: int = 10000
34 | mel_scale_norm: T.Optional[str] = None
35 | mel_scale_type: str = "htk"
36 | max_mel_iters: int = 200
37 |
38 | # Griffin Lim parameters
39 | num_griffin_lim_iters: int = 32
40 |
41 | # Image parameterization
42 | power_for_image: float = 0.25
43 |
44 | class ExifTags(Enum):
45 | """
46 | Custom EXIF tags for the spectrogram image.
47 | """
48 |
49 | SAMPLE_RATE = 11000
50 | STEREO = 11005
51 | STEP_SIZE_MS = 11010
52 | WINDOW_DURATION_MS = 11020
53 | PADDED_DURATION_MS = 11030
54 |
55 | NUM_FREQUENCIES = 11040
56 | MIN_FREQUENCY = 11050
57 | MAX_FREQUENCY = 11060
58 |
59 | POWER_FOR_IMAGE = 11070
60 | MAX_VALUE = 11080
61 |
62 | @property
63 | def n_fft(self) -> int:
64 | """
65 | The number of samples in each STFT window, with padding.
66 | """
67 | return int(self.padded_duration_ms / 1000.0 * self.sample_rate)
68 |
69 | @property
70 | def win_length(self) -> int:
71 | """
72 | The number of samples in each STFT window.
73 | """
74 | return int(self.window_duration_ms / 1000.0 * self.sample_rate)
75 |
76 | @property
77 | def hop_length(self) -> int:
78 | """
79 | The number of samples between each STFT window.
80 | """
81 | return int(self.step_size_ms / 1000.0 * self.sample_rate)
82 |
83 | def to_exif(self) -> T.Dict[int, T.Any]:
84 | """
85 | Return a dictionary of EXIF tags for the current values.
86 | """
87 | return {
88 | self.ExifTags.SAMPLE_RATE.value: self.sample_rate,
89 | self.ExifTags.STEREO.value: self.stereo,
90 | self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,
91 | self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,
92 | self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms,
93 | self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,
94 | self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,
95 | self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,
96 | self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),
97 | }
98 |
99 | @classmethod
100 | def from_exif(cls, exif: T.Mapping[int, T.Any]) -> SpectrogramParams:
101 | """
102 | Create a SpectrogramParams object from the EXIF tags of the given image.
103 | """
104 | # TODO(hayk): Handle missing tags
105 | return cls(
106 | sample_rate=exif[cls.ExifTags.SAMPLE_RATE.value],
107 | stereo=bool(exif[cls.ExifTags.STEREO.value]),
108 | step_size_ms=exif[cls.ExifTags.STEP_SIZE_MS.value],
109 | window_duration_ms=exif[cls.ExifTags.WINDOW_DURATION_MS.value],
110 | padded_duration_ms=exif[cls.ExifTags.PADDED_DURATION_MS.value],
111 | num_frequencies=exif[cls.ExifTags.NUM_FREQUENCIES.value],
112 | min_frequency=exif[cls.ExifTags.MIN_FREQUENCY.value],
113 | max_frequency=exif[cls.ExifTags.MAX_FREQUENCY.value],
114 | power_for_image=exif[cls.ExifTags.POWER_FOR_IMAGE.value],
115 | )
116 |
--------------------------------------------------------------------------------
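To make the derived STFT properties concrete with the defaults above (sample_rate=44100,
padded_duration_ms=400, window_duration_ms=100, step_size_ms=10): n_fft = 0.4 * 44100 = 17640,
win_length = 0.1 * 44100 = 4410, and hop_length = 0.01 * 44100 = 441 samples. A small sketch of
those properties and the EXIF round trip (the asserts hold for default params, where fields not
stored in EXIF keep their default values):

    from riffusion.spectrogram_params import SpectrogramParams

    params = SpectrogramParams()
    assert params.n_fft == 17640      # 400 ms of padded window at 44100 Hz
    assert params.win_length == 4410  # 100 ms window
    assert params.hop_length == 441   # 10 ms hop

    # Store and recover the params via EXIF tags, as done with spectrogram images above.
    exif = params.to_exif()
    recovered = SpectrogramParams.from_exif(exif)
    assert recovered == params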
/riffusion/streamlit/README.md:
--------------------------------------------------------------------------------
1 | # streamlit
2 |
3 | This package is an interactive streamlit app for riffusion.
4 |
--------------------------------------------------------------------------------
/riffusion/streamlit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/streamlit/__init__.py
--------------------------------------------------------------------------------
/riffusion/streamlit/playground.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import streamlit as st
4 | import streamlit.web.cli as stcli
5 | from streamlit import runtime
6 |
7 | PAGES = {
8 | "🎛️ Home": "tasks.home",
9 | "🌊 Text to Audio": "tasks.text_to_audio",
10 | "✨ Audio to Audio": "tasks.audio_to_audio",
11 | "🎭 Interpolation": "tasks.interpolation",
12 | "✂️ Audio Splitter": "tasks.split_audio",
13 | "📜 Text to Audio Batch": "tasks.text_to_audio_batch",
14 | "📎 Sample Clips": "tasks.sample_clips",
15 | "⏈ Spectrogram to Audio": "tasks.image_to_audio",
16 | }
17 |
18 |
19 | def render() -> None:
20 | st.set_page_config(
21 | page_title="Riffusion Playground",
22 | page_icon="🎸",
23 | layout="wide",
24 | )
25 |
26 | page = st.sidebar.selectbox("Page", list(PAGES.keys()))
27 | assert page is not None
28 | module = __import__(PAGES[page], fromlist=["render"])
29 | module.render()
30 |
31 |
32 | if __name__ == "__main__":
33 | if runtime.exists():
34 | render()
35 | else:
36 | sys.argv = ["streamlit", "run"] + sys.argv
37 | sys.exit(stcli.main())
38 |
--------------------------------------------------------------------------------
/riffusion/streamlit/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/streamlit/tasks/__init__.py
--------------------------------------------------------------------------------
/riffusion/streamlit/tasks/audio_to_audio.py:
--------------------------------------------------------------------------------
1 | import io
2 | import typing as T
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pydub
7 | import streamlit as st
8 | from PIL import Image
9 |
10 | from riffusion.datatypes import InferenceInput, PromptInput
11 | from riffusion.spectrogram_params import SpectrogramParams
12 | from riffusion.streamlit import util as streamlit_util
13 | from riffusion.streamlit.tasks.interpolation import get_prompt_inputs, run_interpolation
14 | from riffusion.util import audio_util
15 |
16 |
17 | def render() -> None:
18 | st.subheader("✨ Audio to Audio")
19 | st.write(
20 | """
21 | Modify existing audio from a text prompt or interpolate between two.
22 | """
23 | )
24 |
25 | with st.expander("Help", False):
26 | st.write(
27 | """
28 | This tool allows you to upload an audio file of arbitrary length and modify it with
29 | a text prompt. It does this by sweeping over the audio in overlapping clips, doing
30 | img2img style transfer with riffusion, then stitching the clips back together with
31 | cross fading to eliminate seams.
32 |
33 |             Try a denoising strength of 0.4 for light modification and 0.55 for heavier
34 |             modification. The best denoising strength depends on how different the prompt is
35 |             from the source audio. You can play with the seed to get infinite variations.
36 |             Currently the same seed is used for all clips along the track.
37 |
38 |             If the Interpolation checkbox is enabled, you can enter two sets of prompt,
39 |             seed, and denoising values, and the tool smoothly blends between them along the
40 |             selected duration of the audio. This is a great way to create a transition.
41 | """
42 | )
43 |
44 | device = streamlit_util.select_device(st.sidebar)
45 | extension = streamlit_util.select_audio_extension(st.sidebar)
46 | checkpoint = streamlit_util.select_checkpoint(st.sidebar)
47 |
48 | use_20k = st.sidebar.checkbox("Use 20kHz", value=False)
49 | use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
50 |
51 | with st.sidebar:
52 | num_inference_steps = T.cast(
53 | int,
54 | st.number_input(
55 | "Steps per sample", value=25, help="Number of denoising steps per model run"
56 | ),
57 | )
58 |
59 | guidance = st.number_input(
60 | "Guidance",
61 | value=7.0,
62 | help="How much the model listens to the text prompt",
63 | )
64 |
65 | scheduler = st.selectbox(
66 | "Scheduler",
67 | options=streamlit_util.SCHEDULER_OPTIONS,
68 | index=0,
69 | help="Which diffusion scheduler to use",
70 | )
71 | assert scheduler is not None
72 |
73 | audio_file = st.file_uploader(
74 | "Upload audio",
75 | type=streamlit_util.AUDIO_EXTENSIONS,
76 | label_visibility="collapsed",
77 | )
78 |
79 | if not audio_file:
80 | st.info("Upload audio to get started")
81 | return
82 |
83 | st.write("#### Original")
84 | st.audio(audio_file)
85 |
86 | segment = streamlit_util.load_audio_file(audio_file)
87 |
88 | # TODO(hayk): Fix
89 | if segment.frame_rate != 44100:
90 | st.warning("Audio must be 44100Hz. Converting")
91 | segment = segment.set_frame_rate(44100)
92 | st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
93 |
94 | clip_p = get_clip_params()
95 | start_time_s = clip_p["start_time_s"]
96 | clip_duration_s = clip_p["clip_duration_s"]
97 | overlap_duration_s = clip_p["overlap_duration_s"]
98 |
99 | duration_s = min(clip_p["duration_s"], segment.duration_seconds - start_time_s)
100 | increment_s = clip_duration_s - overlap_duration_s
101 | clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
102 |
103 | write_clip_details(
104 | clip_start_times=clip_start_times,
105 | clip_duration_s=clip_duration_s,
106 | overlap_duration_s=overlap_duration_s,
107 | )
108 |
109 | interpolate = st.checkbox(
110 | "Interpolate between two endpoints",
111 | value=False,
112 | help="Interpolate between two prompts, seeds, or denoising values along the"
113 | "duration of the segment",
114 | )
115 |
116 | counter = streamlit_util.StreamlitCounter()
117 |
118 | denoising_default = 0.55
119 | with st.form("audio to audio form"):
120 | if interpolate:
121 | left, right = st.columns(2)
122 |
123 | with left:
124 | st.write("##### Prompt A")
125 | prompt_input_a = PromptInput(
126 | guidance=guidance,
127 | **get_prompt_inputs(key="a", denoising_default=denoising_default),
128 | )
129 |
130 | with right:
131 | st.write("##### Prompt B")
132 | prompt_input_b = PromptInput(
133 | guidance=guidance,
134 | **get_prompt_inputs(key="b", denoising_default=denoising_default),
135 | )
136 | elif use_magic_mix:
137 | prompt = st.text_input("Prompt", key="prompt_a")
138 |
139 | row = st.columns(4)
140 |
141 | seed = T.cast(
142 | int,
143 | row[0].number_input(
144 | "Seed",
145 | value=42,
146 | key="seed_a",
147 | ),
148 | )
149 | prompt_input_a = PromptInput(
150 | prompt=prompt,
151 | seed=seed,
152 | guidance=guidance,
153 | )
154 | magic_mix_kmin = row[1].number_input("Kmin", value=0.3)
155 | magic_mix_kmax = row[2].number_input("Kmax", value=0.5)
156 | magic_mix_mix_factor = row[3].number_input("Mix Factor", value=0.5)
157 | else:
158 | prompt_input_a = PromptInput(
159 | guidance=guidance,
160 | **get_prompt_inputs(
161 | key="a",
162 | include_negative_prompt=True,
163 | cols=True,
164 | denoising_default=denoising_default,
165 | ),
166 | )
167 |
168 | st.form_submit_button("Riff", type="primary", on_click=counter.increment)
169 |
170 | show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
171 | show_difference = st.sidebar.checkbox("Show Difference", False)
172 |
173 | clip_segments = slice_audio_into_clips(
174 | segment=segment,
175 | clip_start_times=clip_start_times,
176 | clip_duration_s=clip_duration_s,
177 | )
178 |
179 | if not prompt_input_a.prompt:
180 | st.info("Enter a prompt")
181 | return
182 |
183 | if counter.value == 0:
184 | return
185 |
186 | st.write(f"## Counter: {counter.value}")
187 |
188 | if use_20k:
189 | params = SpectrogramParams(
190 | min_frequency=10,
191 | max_frequency=20000,
192 | sample_rate=44100,
193 | stereo=True,
194 | )
195 | else:
196 | params = SpectrogramParams(
197 | min_frequency=0,
198 | max_frequency=10000,
199 | stereo=False,
200 | )
201 |
202 | if interpolate:
203 | # TODO(hayk): Make not linspace
204 | alphas = list(np.linspace(0, 1, len(clip_segments)))
205 | alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
206 | st.write(f"**Alphas** : [{alphas_str}]")
207 |
208 | result_images: T.List[Image.Image] = []
209 | result_segments: T.List[pydub.AudioSegment] = []
210 | for i, clip_segment in enumerate(clip_segments):
211 | st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s")
212 |
213 | audio_bytes = io.BytesIO()
214 | clip_segment.export(audio_bytes, format="wav")
215 |
216 | init_image = streamlit_util.spectrogram_image_from_audio(
217 | clip_segment,
218 | params=params,
219 | device=device,
220 | )
221 |
222 | # TODO(hayk): Roll this into spectrogram_image_from_audio?
223 | init_image_resized = scale_image_to_32_stride(init_image)
224 |
225 | progress_callback = None
226 | if show_clip_details:
227 | left, right = st.columns(2)
228 |
229 | left.write("##### Source Clip")
230 | left.image(init_image, use_column_width=False)
231 | left.audio(audio_bytes)
232 |
233 | right.write("##### Riffed Clip")
234 | empty_bin = right.empty()
235 | with empty_bin.container():
236 | st.info("Riffing...")
237 | progress = st.progress(0.0)
238 | progress_callback = progress.progress
239 |
240 | if interpolate:
241 | assert use_magic_mix is False, "Cannot use magic mix and interpolate together"
242 | inputs = InferenceInput(
243 | alpha=float(alphas[i]),
244 | num_inference_steps=num_inference_steps,
245 | seed_image_id="og_beat",
246 | start=prompt_input_a,
247 | end=prompt_input_b,
248 | )
249 |
250 | image, audio_bytes = run_interpolation(
251 | inputs=inputs,
252 | init_image=init_image_resized,
253 | device=device,
254 | checkpoint=checkpoint,
255 | )
256 | elif use_magic_mix:
257 | assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
258 | image = streamlit_util.run_img2img_magic_mix(
259 | prompt=prompt_input_a.prompt,
260 | init_image=init_image_resized,
261 | num_inference_steps=num_inference_steps,
262 | guidance_scale=guidance,
263 | seed=prompt_input_a.seed,
264 | kmin=magic_mix_kmin,
265 | kmax=magic_mix_kmax,
266 | mix_factor=magic_mix_mix_factor,
267 | device=device,
268 | scheduler=scheduler,
269 | checkpoint=checkpoint,
270 | )
271 | else:
272 | image = streamlit_util.run_img2img(
273 | prompt=prompt_input_a.prompt,
274 | init_image=init_image_resized,
275 | denoising_strength=prompt_input_a.denoising,
276 | num_inference_steps=num_inference_steps,
277 | guidance_scale=guidance,
278 | negative_prompt=prompt_input_a.negative_prompt,
279 | seed=prompt_input_a.seed,
280 | progress_callback=progress_callback,
281 | device=device,
282 | scheduler=scheduler,
283 | checkpoint=checkpoint,
284 | )
285 |
286 | # Resize back to original size
287 | image = image.resize(init_image.size, Image.BICUBIC)
288 |
289 | result_images.append(image)
290 |
291 | if show_clip_details:
292 | empty_bin.empty()
293 | right.image(image, use_column_width=False)
294 |
295 | riffed_segment = streamlit_util.audio_segment_from_spectrogram_image(
296 | image=image,
297 | params=params,
298 | device=device,
299 | )
300 | result_segments.append(riffed_segment)
301 |
302 | audio_bytes = io.BytesIO()
303 | riffed_segment.export(audio_bytes, format="wav")
304 |
305 | if show_clip_details:
306 | right.audio(audio_bytes)
307 |
308 | if show_clip_details and show_difference:
309 | diff_np = np.maximum(
310 | 0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32)
311 | )
312 | diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
313 | diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
314 | image=diff_image,
315 | params=params,
316 | device=device,
317 | )
318 |
319 | audio_bytes = io.BytesIO()
320 | diff_segment.export(audio_bytes, format=extension)
321 | st.audio(audio_bytes)
322 |
323 | # Combine clips with a crossfade based on overlap
324 | combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)
325 |
326 | st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
327 |
328 | input_name = Path(audio_file.name).stem
329 | output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}"
330 | streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension)
331 |
332 |
333 | def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
334 | """
335 | Render the parameters of slicing audio into clips.
336 | """
337 | p: T.Dict[str, T.Any] = {}
338 |
339 | cols = st.columns(4)
340 |
341 | p["start_time_s"] = cols[0].number_input(
342 | "Start Time [s]",
343 | min_value=0.0,
344 | value=0.0,
345 | )
346 | p["duration_s"] = cols[1].number_input(
347 | "Duration [s]",
348 | min_value=0.0,
349 | value=20.0,
350 | )
351 |
352 | if advanced:
353 | p["clip_duration_s"] = cols[2].number_input(
354 | "Clip Duration [s]",
355 | min_value=3.0,
356 | max_value=10.0,
357 | value=5.0,
358 | )
359 | else:
360 | p["clip_duration_s"] = 5.0
361 |
362 | if advanced:
363 | p["overlap_duration_s"] = cols[3].number_input(
364 | "Overlap Duration [s]",
365 | min_value=0.0,
366 | max_value=10.0,
367 | value=0.2,
368 | )
369 | else:
370 | p["overlap_duration_s"] = 0.2
371 |
372 | return p
373 |
374 |
375 | def write_clip_details(
376 | clip_start_times: np.ndarray, clip_duration_s: float, overlap_duration_s: float
377 | ):
378 | """
379 | Write details of the clips to be sliced from an audio segment.
380 | """
381 | clip_details_text = (
382 | f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
383 | f"with overlap {overlap_duration_s}s"
384 | )
385 |
386 | with st.expander(clip_details_text):
387 | st.dataframe(
388 | {
389 | "Start Time [s]": clip_start_times,
390 | "End Time [s]": clip_start_times + clip_duration_s,
391 | "Duration [s]": clip_duration_s,
392 | }
393 | )
394 |
395 |
396 | def slice_audio_into_clips(
397 | segment: pydub.AudioSegment, clip_start_times: T.Sequence[float], clip_duration_s: float
398 | ) -> T.List[pydub.AudioSegment]:
399 | """
400 | Slice an audio segment into a list of clips of a given duration at the given start times.
401 | """
402 | clip_segments: T.List[pydub.AudioSegment] = []
403 | for i, clip_start_time_s in enumerate(clip_start_times):
404 | clip_start_time_ms = int(clip_start_time_s * 1000)
405 | clip_duration_ms = int(clip_duration_s * 1000)
406 | clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
407 |
408 | # TODO(hayk): I don't think this is working properly
409 | if i == len(clip_start_times) - 1:
410 | silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
411 | if silence_ms > 0:
412 | clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
413 |
414 | clip_segments.append(clip_segment)
415 |
416 | return clip_segments
417 |
418 |
419 | def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
420 | """
421 | Scale an image to a size that is a multiple of 32.
422 | """
423 | closest_width = int(np.ceil(image.width / 32) * 32)
424 | closest_height = int(np.ceil(image.height / 32) * 32)
425 |     return image.resize((closest_width, closest_height), Image.Resampling.BICUBIC)
426 |
--------------------------------------------------------------------------------
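A minimal sketch of the clip slicing and stitching used above, runnable outside of Streamlit. The input file name `song.mp3` and the clip timings are illustrative, not part of the repository:

import numpy as np
import pydub

from riffusion.streamlit.tasks.audio_to_audio import slice_audio_into_clips
from riffusion.util import audio_util

# Load a local audio file (hypothetical path).
segment = pydub.AudioSegment.from_file("song.mp3")

# Mirror the layout used by render(): 5 s clips with 0.2 s of overlap.
clip_duration_s = 5.0
overlap_duration_s = 0.2
clip_start_times = np.arange(
    0.0, segment.duration_seconds - clip_duration_s, clip_duration_s - overlap_duration_s
)

# Cut the clips, then crossfade them back together over the overlap region.
clips = slice_audio_into_clips(
    segment=segment,
    clip_start_times=clip_start_times,
    clip_duration_s=clip_duration_s,
)
stitched = audio_util.stitch_segments(clips, crossfade_s=overlap_duration_s)
stitched.export("stitched.mp3", format="mp3")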
/riffusion/streamlit/tasks/home.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 |
4 | def render():
5 | st.title("✨🎸 Riffusion Playground 🎸✨")
6 |
7 | st.write("Select a task from the sidebar to get started!")
8 |
9 | left, right = st.columns(2)
10 |
11 | with left:
12 | st.subheader("🌊 Text to Audio")
13 | st.write("Generate audio clips from text prompts.")
14 |
15 | st.subheader("✨ Audio to Audio")
16 | st.write("Upload audio and modify with text prompt (interpolation supported).")
17 |
18 | st.subheader("🎭 Interpolation")
19 | st.write("Interpolate between prompts in the latent space.")
20 |
21 | st.subheader("✂️ Audio Splitter")
22 | st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
23 |
24 | with right:
25 | st.subheader("📜 Text to Audio Batch")
26 | st.write("Generate audio in batch from a JSON file of text prompts.")
27 |
28 | st.subheader("📎 Sample Clips")
29 | st.write("Export short clips from an audio file.")
30 |
31 | st.subheader("⏈ Spectrogram to Audio")
32 | st.write("Reconstruct audio from spectrogram images.")
33 |
--------------------------------------------------------------------------------
/riffusion/streamlit/tasks/image_to_audio.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from pathlib import Path
3 |
4 | import streamlit as st
5 | from PIL import Image
6 |
7 | from riffusion.spectrogram_params import SpectrogramParams
8 | from riffusion.streamlit import util as streamlit_util
9 | from riffusion.util.image_util import exif_from_image
10 |
11 |
12 | def render() -> None:
13 | st.subheader("⏈ Image to Audio")
14 | st.write(
15 | """
16 | Reconstruct audio from spectrogram images.
17 | """
18 | )
19 |
20 | with st.expander("Help", False):
21 | st.write(
22 | """
23 |         This tool takes an existing spectrogram image and reconstructs it into an audio
24 |         waveform. It also displays the EXIF metadata stored inside the image, which can
25 |         contain the parameters used to create the spectrogram image. If the image contains
26 |         no EXIF metadata, default parameters are assumed.
27 | """
28 | )
29 |
30 | device = streamlit_util.select_device(st.sidebar)
31 | extension = streamlit_util.select_audio_extension(st.sidebar)
32 |
33 | use_20k = st.sidebar.checkbox("Use 20kHz", value=False)
34 |
35 | image_file = st.file_uploader(
36 | "Upload a file",
37 | type=streamlit_util.IMAGE_EXTENSIONS,
38 | label_visibility="collapsed",
39 | )
40 | if not image_file:
41 | st.info("Upload an image file to get started")
42 | return
43 |
44 | image = Image.open(image_file)
45 | st.image(image)
46 |
47 | with st.expander("Image metadata", expanded=False):
48 | exif = exif_from_image(image)
49 | st.json(exif)
50 |
51 | try:
52 | params = SpectrogramParams.from_exif(exif=image.getexif())
53 | except KeyError:
54 | st.info("Could not find spectrogram parameters in exif data. Using defaults.")
55 | if use_20k:
56 | params = SpectrogramParams(
57 | min_frequency=10,
58 | max_frequency=20000,
59 | stereo=True,
60 | )
61 | else:
62 | params = SpectrogramParams()
63 |
64 | with st.expander("Spectrogram Parameters", expanded=False):
65 | st.json(dataclasses.asdict(params))
66 |
67 | segment = streamlit_util.audio_segment_from_spectrogram_image(
68 | image=image.copy(),
69 | params=params,
70 | device=device,
71 | )
72 |
73 | streamlit_util.display_and_download_audio(
74 | segment,
75 | name=Path(image_file.name).stem,
76 | extension=extension,
77 | )
78 |
--------------------------------------------------------------------------------
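The reconstruction above can also be done without Streamlit by calling the converter classes directly. A minimal sketch, assuming a spectrogram image `spectrogram.png` on disk (the file names are illustrative):

from PIL import Image

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

image = Image.open("spectrogram.png")

# Prefer the parameters embedded in the image's EXIF; fall back to defaults.
try:
    params = SpectrogramParams.from_exif(exif=image.getexif())
except KeyError:
    params = SpectrogramParams()

converter = SpectrogramImageConverter(params=params, device="cpu")
segment = converter.audio_from_spectrogram_image(image)
segment.export("reconstructed.wav", format="wav")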
/riffusion/streamlit/tasks/interpolation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import io
3 | import typing as T
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | import pydub
8 | import streamlit as st
9 | from PIL import Image
10 |
11 | from riffusion.datatypes import InferenceInput, PromptInput
12 | from riffusion.spectrogram_params import SpectrogramParams
13 | from riffusion.streamlit import util as streamlit_util
14 |
15 |
16 | def render() -> None:
17 | st.subheader("🎭 Interpolation")
18 | st.write(
19 | """
20 | Interpolate between prompts in the latent space.
21 | """
22 | )
23 |
24 | with st.expander("Help", False):
25 | st.write(
26 | """
27 | This tool allows specifying two endpoints and generating a long-form interpolation
28 | between them that traverses the latent space. The interpolation is generated by
29 | the method described at https://www.riffusion.com/about. A seed image is used to
30 | set the beat and tempo of the generated audio, and can be set in the sidebar.
31 |         Usually either the seed or the prompt is changed, but not both at once. You can browse
32 | infinite variations of the same prompt by changing the seed.
33 |
34 | For example, try going from "church bells" to "jazz" with 10 steps and 0.75 denoising.
35 | This will generate a 50 second clip at 5 seconds per step. Then play with the seeds
36 | or denoising to get different variations.
37 | """
38 | )
39 |
40 | # Sidebar params
41 |
42 | device = streamlit_util.select_device(st.sidebar)
43 | extension = streamlit_util.select_audio_extension(st.sidebar)
44 |
45 | num_interpolation_steps = T.cast(
46 | int,
47 | st.sidebar.number_input(
48 | "Interpolation steps",
49 | value=12,
50 | min_value=1,
51 | max_value=20,
52 | help="Number of model generations between the two prompts. Controls the duration.",
53 | ),
54 | )
55 |
56 | num_inference_steps = T.cast(
57 | int,
58 | st.sidebar.number_input(
59 | "Steps per sample", value=50, help="Number of denoising steps per model run"
60 | ),
61 | )
62 |
63 | guidance = st.sidebar.number_input(
64 | "Guidance",
65 | value=7.0,
66 | help="How much the model listens to the text prompt",
67 | )
68 |
69 | init_image_name = st.sidebar.selectbox(
70 | "Seed image",
71 | # TODO(hayk): Read from directory
72 | options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
73 | index=0,
74 | help="Which seed image to use for img2img. Custom allows uploading your own.",
75 | )
76 | assert init_image_name is not None
77 | if init_image_name == "custom":
78 | init_image_file = st.sidebar.file_uploader(
79 | "Upload a custom seed image",
80 | type=streamlit_util.IMAGE_EXTENSIONS,
81 | label_visibility="collapsed",
82 | )
83 | if init_image_file:
84 | st.sidebar.image(init_image_file)
85 |
86 | alpha_power = st.sidebar.number_input("Alpha Power", value=1.0)
87 |
88 | show_individual_outputs = st.sidebar.checkbox(
89 | "Show individual outputs",
90 | value=False,
91 | help="Show each model output",
92 | )
93 | show_images = st.sidebar.checkbox(
94 | "Show individual images",
95 | value=False,
96 | help="Show each generated image",
97 | )
98 |
99 | alphas = np.linspace(0, 1, num_interpolation_steps)
100 |
101 | # Apply power scaling to alphas to customize the interpolation curve
102 | alphas_shifted = alphas * 2 - 1
103 | alphas_shifted = (np.abs(alphas_shifted) ** alpha_power * np.sign(alphas_shifted) + 1) / 2
104 | alphas = alphas_shifted
105 |
106 | alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
107 | st.write(f"**Alphas** : [{alphas_str}]")
108 |
109 | # Prompt inputs A and B in two columns
110 |
111 | with st.form(key="interpolation_form"):
112 | left, right = st.columns(2)
113 |
114 | with left:
115 | st.write("##### Prompt A")
116 | prompt_input_a = PromptInput(
117 | guidance=guidance, **get_prompt_inputs(key="a", denoising_default=0.75)
118 | )
119 |
120 | with right:
121 | st.write("##### Prompt B")
122 | prompt_input_b = PromptInput(
123 | guidance=guidance, **get_prompt_inputs(key="b", denoising_default=0.75)
124 | )
125 |
126 | st.form_submit_button("Generate", type="primary")
127 |
128 | if not prompt_input_a.prompt or not prompt_input_b.prompt:
129 | st.info("Enter both prompts to interpolate between them")
130 | return
131 |
132 | if init_image_name == "custom":
133 | if not init_image_file:
134 | st.info("Upload a custom seed image")
135 | return
136 | init_image = Image.open(init_image_file).convert("RGB")
137 | else:
138 | init_image_path = (
139 | Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
140 | )
141 | init_image = Image.open(str(init_image_path)).convert("RGB")
142 |
143 | # TODO(hayk): Move this code into a shared place and add to riffusion.cli
144 | image_list: T.List[Image.Image] = []
145 | audio_bytes_list: T.List[io.BytesIO] = []
146 | for i, alpha in enumerate(alphas):
147 | inputs = InferenceInput(
148 | alpha=float(alpha),
149 | num_inference_steps=num_inference_steps,
150 | seed_image_id="og_beat",
151 | start=prompt_input_a,
152 | end=prompt_input_b,
153 | )
154 |
155 | if i == 0:
156 | with st.expander("Example input JSON", expanded=False):
157 | st.json(dataclasses.asdict(inputs))
158 |
159 | image, audio_bytes = run_interpolation(
160 | inputs=inputs,
161 | init_image=init_image,
162 | device=device,
163 | extension=extension,
164 | )
165 |
166 | if show_individual_outputs:
167 | st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
168 | if show_images:
169 | st.image(image)
170 | st.audio(audio_bytes)
171 |
172 | image_list.append(image)
173 | audio_bytes_list.append(audio_bytes)
174 |
175 | st.write("#### Final Output")
176 |
177 | # TODO(hayk): Concatenate with overlap and better blending like in audio to audio
178 | audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
179 | concat_segment = audio_segments[0]
180 | for segment in audio_segments[1:]:
181 | concat_segment = concat_segment.append(segment, crossfade=0)
182 |
183 | audio_bytes = io.BytesIO()
184 | concat_segment.export(audio_bytes, format=extension)
185 | audio_bytes.seek(0)
186 |
187 | st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
188 | st.audio(audio_bytes)
189 |
190 | output_name = (
191 | f"{prompt_input_a.prompt.replace(' ', '_')}_"
192 | f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
193 | )
194 | st.download_button(
195 | output_name,
196 | data=audio_bytes,
197 | file_name=output_name,
198 | mime=f"audio/{extension}",
199 | )
200 |
201 |
202 | def get_prompt_inputs(
203 | key: str,
204 | include_negative_prompt: bool = False,
205 | cols: bool = False,
206 | denoising_default: float = 0.5,
207 | ) -> T.Dict[str, T.Any]:
208 | """
209 | Compute prompt inputs from widgets.
210 | """
211 | p: T.Dict[str, T.Any] = {}
212 |
213 | # Optionally use columns
214 | left, right = T.cast(T.Any, st.columns(2) if cols else (st, st))
215 |
216 | visibility = "visible" if cols else "collapsed"
217 | p["prompt"] = left.text_input("Prompt", label_visibility=visibility, key=f"prompt_{key}")
218 |
219 | if include_negative_prompt:
220 | p["negative_prompt"] = right.text_input("Negative Prompt", key=f"negative_prompt_{key}")
221 |
222 | p["seed"] = T.cast(
223 | int,
224 | left.number_input(
225 | "Seed",
226 | value=42,
227 | key=f"seed_{key}",
228 | help="Integer used to generate a random result. Vary this to explore alternatives.",
229 | ),
230 | )
231 |
232 | p["denoising"] = right.number_input(
233 | "Denoising",
234 | value=denoising_default,
235 | key=f"denoising_{key}",
236 | help="How much to modify the seed image",
237 | )
238 |
239 | return p
240 |
241 |
242 | @st.cache_data
243 | def run_interpolation(
244 | inputs: InferenceInput,
245 | init_image: Image.Image,
246 | checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT,
247 | device: str = "cuda",
248 | extension: str = "mp3",
249 | ) -> T.Tuple[Image.Image, io.BytesIO]:
250 | """
251 | Cached function for riffusion interpolation.
252 | """
253 | pipeline = streamlit_util.load_riffusion_checkpoint(
254 | device=device,
255 | checkpoint=checkpoint,
256 | # No trace so we can have variable width
257 | no_traced_unet=True,
258 | )
259 |
260 | image = pipeline.riffuse(
261 | inputs,
262 | init_image=init_image,
263 | mask_image=None,
264 | )
265 |
266 | # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
267 | params = SpectrogramParams(
268 | min_frequency=0,
269 | max_frequency=10000,
270 | )
271 |
272 | # Reconstruct from image to audio
273 | audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
274 | image=image,
275 | params=params,
276 | device=device,
277 | output_format=extension,
278 | )
279 |
280 | return image, audio_bytes
281 |
--------------------------------------------------------------------------------
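The alpha power scaling above reshapes the interpolation curve symmetrically around the midpoint: powers above 1 concentrate the alpha values near 0.5, while powers below 1 push them toward the endpoints. A small numeric sketch of the same formula:

import numpy as np

def power_scaled_alphas(num_steps: int, alpha_power: float) -> np.ndarray:
    # Same transform as in render(): map [0, 1] to [-1, 1], apply an odd power
    # curve, and map back to [0, 1].
    alphas = np.linspace(0, 1, num_steps)
    shifted = alphas * 2 - 1
    shifted = (np.abs(shifted) ** alpha_power * np.sign(shifted) + 1) / 2
    return shifted

print(np.round(power_scaled_alphas(5, 1.0), 2))  # [0.   0.25 0.5  0.75 1.  ]
print(np.round(power_scaled_alphas(5, 2.0), 2))  # [0.   0.38 0.5  0.62 1.  ]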
/riffusion/streamlit/tasks/sample_clips.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import typing as T
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pydub
7 | import streamlit as st
8 |
9 | from riffusion.spectrogram_params import SpectrogramParams
10 | from riffusion.streamlit import util as streamlit_util
11 |
12 |
13 | def render() -> None:
14 | st.subheader("📎 Sample Clips")
15 | st.write(
16 | """
17 | Export short clips from an audio file.
18 | """
19 | )
20 |
21 | with st.expander("Help", False):
22 | st.write(
23 | """
24 | This tool simply allows uploading an audio file and randomly sampling short clips
25 | from it. It's useful for generating a large number of short clips from a single
26 | audio file. Outputs can be saved to a given directory with a given audio extension.
27 | """
28 | )
29 |
30 | audio_file = st.file_uploader(
31 | "Upload a file",
32 | type=streamlit_util.AUDIO_EXTENSIONS,
33 | label_visibility="collapsed",
34 | )
35 | if not audio_file:
36 | st.info("Upload an audio file to get started")
37 | return
38 |
39 | st.audio(audio_file)
40 |
41 | segment = pydub.AudioSegment.from_file(audio_file)
42 | st.write(
43 | " \n".join(
44 | [
45 | f"**Duration**: {segment.duration_seconds:.3f} seconds",
46 | f"**Channels**: {segment.channels}",
47 | f"**Sample rate**: {segment.frame_rate} Hz",
48 | f"**Sample width**: {segment.sample_width} bytes",
49 | ]
50 | )
51 | )
52 |
53 | device = streamlit_util.select_device(st.sidebar)
54 | extension = streamlit_util.select_audio_extension(st.sidebar)
55 | save_to_disk = st.sidebar.checkbox("Save to Disk", False)
56 | export_as_mono = st.sidebar.checkbox("Export as Mono", False)
57 | compute_spectrograms = st.sidebar.checkbox("Compute Spectrograms", False)
58 |
59 | row = st.columns(4)
60 | num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3))
61 | duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000))
62 | seed = T.cast(int, row[2].number_input("Seed", value=42))
63 |
64 | counter = streamlit_util.StreamlitCounter()
65 | st.button("Sample Clips", type="primary", on_click=counter.increment)
66 | if counter.value == 0:
67 | return
68 |
69 | # Optionally pick an output directory
70 | if save_to_disk:
71 | output_dir = tempfile.mkdtemp(prefix="sample_clips_")
72 | output_path = Path(output_dir)
73 | output_path.mkdir(parents=True, exist_ok=True)
74 | st.info(f"Output directory: `{output_dir}`")
75 |
76 | if compute_spectrograms:
77 | images_dir = output_path / "images"
78 | images_dir.mkdir(parents=True, exist_ok=True)
79 |
80 | if seed >= 0:
81 | np.random.seed(seed)
82 |
83 | if export_as_mono and segment.channels > 1:
84 | segment = segment.set_channels(1)
85 |
86 | if save_to_disk:
87 | st.info(f"Writing {num_clips} clip(s) to `{str(output_path)}`")
88 |
89 | # TODO(hayk): Share code with riffusion.cli.sample_clips.
90 | segment_duration_ms = int(segment.duration_seconds * 1000)
91 | for i in range(num_clips):
92 | clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
93 | clip = segment[clip_start_ms : clip_start_ms + duration_ms]
94 |
95 | clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms"
96 |
97 | st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")
98 |
99 | streamlit_util.display_and_download_audio(
100 | clip,
101 | name=clip_name,
102 | extension=extension,
103 | )
104 |
105 | if save_to_disk:
106 | clip_path = output_path / f"{clip_name}.{extension}"
107 | clip.export(clip_path, format=extension)
108 |
109 | if compute_spectrograms:
110 | params = SpectrogramParams()
111 |
112 | image = streamlit_util.spectrogram_image_from_audio(
113 | clip,
114 | params=params,
115 | device=device,
116 | )
117 |
118 | st.image(image)
119 |
120 | if save_to_disk:
121 | image_path = images_dir / f"{clip_name}.jpeg"
122 | image.save(image_path)
123 |
124 | if save_to_disk:
125 | st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
126 |
--------------------------------------------------------------------------------
/riffusion/streamlit/tasks/split_audio.py:
--------------------------------------------------------------------------------
1 | import typing as T
2 | from pathlib import Path
3 |
4 | import pydub
5 | import streamlit as st
6 |
7 | from riffusion.audio_splitter import split_audio
8 | from riffusion.streamlit import util as streamlit_util
9 | from riffusion.util import audio_util
10 |
11 |
12 | def render() -> None:
13 | st.subheader("✂️ Audio Splitter")
14 | st.write(
15 | """
16 | Split audio into individual instrument stems.
17 | """
18 | )
19 |
20 | with st.expander("Help", False):
21 | st.write(
22 | """
23 | This tool allows uploading an audio file of arbitrary length and splits it into
24 | stems of vocals, drums, bass, and other. It does this using a deep network that
25 |         sweeps over the audio in clips, extracts the stems, and then crossfades the clips
26 | back together to construct the full length stems. It's particularly useful in
27 | combination with audio_to_audio, for example to split and preserve vocals while
28 | modifying the rest of the track with a prompt. Or, to pull out drums to add later
29 | in a DAW.
30 | """
31 | )
32 |
33 | device = streamlit_util.select_device(st.sidebar)
34 |
35 | extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"]
36 | extension = st.sidebar.selectbox(
37 | "Output format",
38 | options=extension_options,
39 | index=extension_options.index("mp3"),
40 | )
41 | assert extension is not None
42 |
43 | audio_file = st.file_uploader(
44 | "Upload audio",
45 | type=extension_options,
46 | label_visibility="collapsed",
47 | )
48 |
49 | stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
50 | recombine = st.sidebar.multiselect(
51 | "Recombine",
52 | options=stem_options,
53 | default=[],
54 | help="Recombine these stems at the end",
55 | )
56 |
57 | if not audio_file:
58 | st.info("Upload audio to get started")
59 | return
60 |
61 | st.write("#### Original")
62 | st.audio(audio_file)
63 |
64 | counter = streamlit_util.StreamlitCounter()
65 | st.button("Split", type="primary", on_click=counter.increment)
66 | if counter.value == 0:
67 | return
68 |
69 | segment = streamlit_util.load_audio_file(audio_file)
70 |
71 | # Split
72 | stems = split_audio_cached(segment, device=device)
73 |
74 | input_name = Path(audio_file.name).stem
75 |
76 | # Display each
77 | for name in stem_options:
78 | stem = stems[name.lower()]
79 | st.write(f"#### Stem: {name}")
80 |
81 | output_name = f"{input_name}_{name.lower()}"
82 | streamlit_util.display_and_download_audio(stem, output_name, extension=extension)
83 |
84 | if recombine:
85 | recombine_lower = [r.lower() for r in recombine]
86 | segments = [s for name, s in stems.items() if name in recombine_lower]
87 | recombined = audio_util.overlay_segments(segments)
88 |
89 | # Display
90 | st.write(f"#### Recombined: {', '.join(recombine)}")
91 | output_name = f"{input_name}_{'_'.join(recombine_lower)}"
92 | streamlit_util.display_and_download_audio(recombined, output_name, extension=extension)
93 |
94 |
95 | @st.cache_data
96 | def split_audio_cached(
97 | segment: pydub.AudioSegment, device: str = "cuda"
98 | ) -> T.Dict[str, pydub.AudioSegment]:
99 | return split_audio(segment, device=device)
100 |
--------------------------------------------------------------------------------
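A minimal non-Streamlit sketch of the same split-and-recombine flow. The file name `track.mp3` and the chosen stems are illustrative; the stem keys follow the lowercase names used in render() above, and a CUDA device is assumed:

import pydub

from riffusion.audio_splitter import split_audio
from riffusion.util import audio_util

segment = pydub.AudioSegment.from_file("track.mp3")

# Split into stems; returns a dict keyed by stem name (e.g. "vocals", "drums").
stems = split_audio(segment, device="cuda")

# Recombine a subset of stems by overlaying them.
recombined = audio_util.overlay_segments([stems["vocals"], stems["drums"]])
recombined.export("vocals_plus_drums.mp3", format="mp3")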
/riffusion/streamlit/tasks/text_to_audio.py:
--------------------------------------------------------------------------------
1 | import typing as T
2 |
3 | import streamlit as st
4 |
5 | from riffusion.spectrogram_params import SpectrogramParams
6 | from riffusion.streamlit import util as streamlit_util
7 |
8 |
9 | def render() -> None:
10 | st.subheader("🌊 Text to Audio")
11 | st.write(
12 | """
13 | Generate audio from text prompts.
14 | """
15 | )
16 |
17 | with st.expander("Help", False):
18 | st.write(
19 | """
20 | This tool runs riffusion in the simplest text to image form to generate an audio
21 | clip from a text prompt. There is no seed image or interpolation here. This mode
22 | allows more diversity and creativity than when using a seed image, but it also
23 |         gives you less control. Play with the seed to get infinite variations.
24 | """
25 | )
26 |
27 | device = streamlit_util.select_device(st.sidebar)
28 | extension = streamlit_util.select_audio_extension(st.sidebar)
29 | checkpoint = streamlit_util.select_checkpoint(st.sidebar)
30 |
31 | with st.form("Inputs"):
32 | prompt = st.text_input("Prompt")
33 | negative_prompt = st.text_input("Negative prompt")
34 |
35 | row = st.columns(4)
36 | num_clips = T.cast(
37 | int,
38 | row[0].number_input(
39 | "Number of clips",
40 | value=1,
41 | min_value=1,
42 | max_value=25,
43 | help="How many outputs to generate (seed gets incremented)",
44 | ),
45 | )
46 | starting_seed = T.cast(
47 | int,
48 | row[1].number_input(
49 | "Seed",
50 | value=42,
51 | help="Change this to generate different variations",
52 | ),
53 | )
54 |
55 | st.form_submit_button("Riff", type="primary")
56 |
57 | with st.sidebar:
58 | num_inference_steps = T.cast(int, st.number_input("Inference steps", value=30))
59 | width = T.cast(int, st.number_input("Width", value=512))
60 | guidance = st.number_input(
61 | "Guidance", value=7.0, help="How much the model listens to the text prompt"
62 | )
63 | scheduler = st.selectbox(
64 | "Scheduler",
65 | options=streamlit_util.SCHEDULER_OPTIONS,
66 | index=0,
67 | help="Which diffusion scheduler to use",
68 | )
69 | assert scheduler is not None
70 |
71 | use_20k = st.checkbox("Use 20kHz", value=False)
72 |
73 | if not prompt:
74 | st.info("Enter a prompt")
75 | return
76 |
77 | if use_20k:
78 | params = SpectrogramParams(
79 | min_frequency=10,
80 | max_frequency=20000,
81 | sample_rate=44100,
82 | stereo=True,
83 | )
84 | else:
85 | params = SpectrogramParams(
86 | min_frequency=0,
87 | max_frequency=10000,
88 | stereo=False,
89 | )
90 |
91 | seed = starting_seed
92 | for i in range(1, num_clips + 1):
93 | st.write(f"#### Riff {i} / {num_clips} - Seed {seed}")
94 |
95 | image = streamlit_util.run_txt2img(
96 | prompt=prompt,
97 | num_inference_steps=num_inference_steps,
98 | guidance=guidance,
99 | negative_prompt=negative_prompt,
100 | seed=seed,
101 | width=width,
102 | height=512,
103 | checkpoint=checkpoint,
104 | device=device,
105 | scheduler=scheduler,
106 | )
107 | st.image(image)
108 |
109 | segment = streamlit_util.audio_segment_from_spectrogram_image(
110 | image=image,
111 | params=params,
112 | device=device,
113 | )
114 |
115 | streamlit_util.display_and_download_audio(
116 | segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension
117 | )
118 |
119 | seed += 1
120 |
--------------------------------------------------------------------------------
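The same text-to-audio path can be scripted without Streamlit by pairing a diffusers text-to-image pipeline with the spectrogram converter. A minimal sketch, assuming a CUDA device and the default checkpoint; the prompt and output file name are illustrative:

import torch
from diffusers import StableDiffusionPipeline

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

device = "cuda"
pipeline = StableDiffusionPipeline.from_pretrained(
    "riffusion/riffusion-model-v1", torch_dtype=torch.float16
).to(device)

generator = torch.Generator(device=device).manual_seed(42)
image = pipeline(
    prompt="church bells",
    num_inference_steps=30,
    guidance_scale=7.0,
    width=512,
    height=512,
    generator=generator,
).images[0]

# Convert the spectrogram image back to audio with the model's frequency range.
params = SpectrogramParams(min_frequency=0, max_frequency=10000)
converter = SpectrogramImageConverter(params=params, device=device)
segment = converter.audio_from_spectrogram_image(image)
segment.export("church_bells.mp3", format="mp3")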
/riffusion/streamlit/tasks/text_to_audio_batch.py:
--------------------------------------------------------------------------------
1 | import json
2 | import typing as T
3 | from pathlib import Path
4 |
5 | import streamlit as st
6 |
7 | from riffusion.spectrogram_params import SpectrogramParams
8 | from riffusion.streamlit import util as streamlit_util
9 |
10 | # Example input json file to process in batch
11 | EXAMPLE_INPUT = """
12 | {
13 | "params": {
14 | "checkpoint": "riffusion/riffusion-model-v1",
15 | "scheduler": "DPMSolverMultistepScheduler",
16 | "num_inference_steps": 50,
17 | "guidance": 7.0,
18 | "width": 512,
19 | },
20 | "entries": [
21 | {
22 | "prompt": "Church bells",
23 | "seed": 42
24 | },
25 | {
26 | "prompt": "electronic beats",
27 | "negative_prompt": "drums",
28 | "seed": 100
29 | },
30 | {
31 | "prompt": "classical violin concerto",
32 | "seed": 4
33 | }
34 | ]
35 | }
36 | """
37 |
38 |
39 | def render() -> None:
40 | st.subheader("📜 Text to Audio Batch")
41 | st.write(
42 | """
43 | Generate audio in batch from a JSON file of text prompts.
44 | """
45 | )
46 |
47 | with st.expander("Help", False):
48 | st.write(
49 | """
50 | This tool is a batch form of text_to_audio, where the inputs are read in from a JSON
51 | file. The input file contains a global params block and a list of entries with positive
52 | and negative prompts. It's useful for automating a larger set of generations. See the
53 | example inputs below for the format of the file.
54 | """
55 | )
56 |
57 | device = streamlit_util.select_device(st.sidebar)
58 |
59 | # Upload a JSON file
60 | json_file = st.file_uploader(
61 | "JSON file",
62 | type=["json"],
63 | label_visibility="collapsed",
64 | )
65 |
66 | # Handle the null case
67 | if json_file is None:
68 | st.info("Upload a JSON file containing params and prompts")
69 | with st.expander("Example inputs.json", expanded=False):
70 | st.code(EXAMPLE_INPUT)
71 | return
72 |
73 | # Read in and print it
74 | data = json.loads(json_file.read())
75 | with st.expander("Input Data", expanded=False):
76 | st.json(data)
77 |
78 | # Params can either be a list or a single entry
79 | if isinstance(data["params"], list):
80 | param_sets = data["params"]
81 | else:
82 | param_sets = [data["params"]]
83 |
84 | entries = data["entries"]
85 |
86 | show_images = st.sidebar.checkbox("Show Images", True)
87 | num_seeds = st.sidebar.number_input(
88 | "Num Seeds",
89 | value=1,
90 | min_value=1,
91 | max_value=10,
92 | help="When > 1, increments the seed and runs multiple for each entry",
93 | )
94 |
95 | # Optionally specify an output directory
96 | output_dir = st.sidebar.text_input("Output Directory", "")
97 | output_path: T.Optional[Path] = None
98 | if output_dir:
99 | output_path = Path(output_dir)
100 | output_path.mkdir(parents=True, exist_ok=True)
101 |
102 | # Write title cards for each param set
103 | title_cols = st.columns(len(param_sets))
104 | for i, params in enumerate(param_sets):
105 | col = title_cols[i]
106 |
107 | if "name" not in params:
108 | params["name"] = f"params[{i}]"
109 |
110 | col.write(f"## Param Set {i}")
111 | col.json(params)
112 |
113 | for entry_i, entry in enumerate(entries):
114 | st.write("---")
115 | print(entry)
116 | prompt = entry["prompt"]
117 | negative_prompt = entry.get("negative_prompt", None)
118 |
119 | base_seed = entry.get("seed", 42)
120 |
121 | text = f"##### ({base_seed}) {prompt}"
122 | if negative_prompt:
123 | text += f" \n**Negative**: {negative_prompt}"
124 | st.write(text)
125 |
126 | for seed in range(base_seed, base_seed + num_seeds):
127 | cols = st.columns(len(param_sets))
128 | for i, params in enumerate(param_sets):
129 | col = cols[i]
130 | col.write(params["name"])
131 |
132 | image = streamlit_util.run_txt2img(
133 | prompt=prompt,
134 | negative_prompt=negative_prompt,
135 | seed=seed,
136 | num_inference_steps=params.get("num_inference_steps", 50),
137 | guidance=params.get("guidance", 7.0),
138 | width=params.get("width", 512),
139 | checkpoint=params.get("checkpoint", streamlit_util.DEFAULT_CHECKPOINT),
140 | scheduler=params.get("scheduler", streamlit_util.SCHEDULER_OPTIONS[0]),
141 | height=512,
142 | device=device,
143 | )
144 |
145 | if show_images:
146 | col.image(image)
147 |
148 | # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
149 | p_spectrogram = SpectrogramParams(
150 | min_frequency=0,
151 | max_frequency=10000,
152 | )
153 |
154 | output_format = "mp3"
155 | audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
156 | image=image,
157 | params=p_spectrogram,
158 | device=device,
159 | output_format=output_format,
160 | )
161 | col.audio(audio_bytes)
162 |
163 | if output_path:
164 | prompt_slug = entry["prompt"].replace(" ", "_")
165 | negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
166 |
167 | image_path = (
168 | output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
169 | )
170 | image.save(image_path, format="JPEG")
171 | entry["image_path"] = str(image_path)
172 |
173 | audio_path = (
174 | output_path
175 | / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
176 | )
177 | audio_path.write_bytes(audio_bytes.getbuffer())
178 | entry["audio_path"] = str(audio_path)
179 |
180 | if output_path:
181 | output_json_path = output_path / "index.json"
182 | output_json_path.write_text(json.dumps(data, indent=4))
183 | st.info(f"Output written to {str(output_path)}")
184 | else:
185 | st.info("Enter output directory in sidebar to save to disk")
186 |
--------------------------------------------------------------------------------
/riffusion/streamlit/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Streamlit utilities (mostly cached wrappers around riffusion code).
3 | """
4 | import io
5 | import threading
6 | import typing as T
7 |
8 | import pydub
9 | import streamlit as st
10 | import torch
11 | from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
12 | from PIL import Image
13 |
14 | from riffusion.audio_splitter import AudioSplitter
15 | from riffusion.riffusion_pipeline import RiffusionPipeline
16 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter
17 | from riffusion.spectrogram_params import SpectrogramParams
18 |
19 | # TODO(hayk): Add URL params
20 |
21 | DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"
22 |
23 | AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
24 | IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
25 |
26 | SCHEDULER_OPTIONS = [
27 | "DPMSolverMultistepScheduler",
28 | "PNDMScheduler",
29 | "DDIMScheduler",
30 | "LMSDiscreteScheduler",
31 | "EulerDiscreteScheduler",
32 | "EulerAncestralDiscreteScheduler",
33 | ]
34 |
35 |
36 | @st.cache_resource
37 | def load_riffusion_checkpoint(
38 | checkpoint: str = DEFAULT_CHECKPOINT,
39 | no_traced_unet: bool = False,
40 | device: str = "cuda",
41 | ) -> RiffusionPipeline:
42 | """
43 | Load the riffusion pipeline.
44 | """
45 | return RiffusionPipeline.load_checkpoint(
46 | checkpoint=checkpoint,
47 | use_traced_unet=not no_traced_unet,
48 | device=device,
49 | )
50 |
51 |
52 | @st.cache_resource
53 | def load_stable_diffusion_pipeline(
54 | checkpoint: str = DEFAULT_CHECKPOINT,
55 | device: str = "cuda",
56 | dtype: torch.dtype = torch.float16,
57 | scheduler: str = SCHEDULER_OPTIONS[0],
58 | ) -> StableDiffusionPipeline:
59 | """
60 |     Load the stable diffusion text-to-image pipeline.
61 |
62 | TODO(hayk): Merge this into RiffusionPipeline to just load one model.
63 | """
64 | if device == "cpu" or device.lower().startswith("mps"):
65 | print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
66 | dtype = torch.float32
67 |
68 | pipeline = StableDiffusionPipeline.from_pretrained(
69 | checkpoint,
70 | revision="main",
71 | torch_dtype=dtype,
72 | safety_checker=lambda images, **kwargs: (images, False),
73 | ).to(device)
74 |
75 | pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
76 |
77 | return pipeline
78 |
79 |
80 | def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
81 | """
82 | Construct a denoising scheduler from a string.
83 | """
84 | if scheduler == "PNDMScheduler":
85 | from diffusers import PNDMScheduler
86 |
87 | return PNDMScheduler.from_config(config)
88 | elif scheduler == "DPMSolverMultistepScheduler":
89 | from diffusers import DPMSolverMultistepScheduler
90 |
91 | return DPMSolverMultistepScheduler.from_config(config)
92 | elif scheduler == "DDIMScheduler":
93 | from diffusers import DDIMScheduler
94 |
95 | return DDIMScheduler.from_config(config)
96 | elif scheduler == "LMSDiscreteScheduler":
97 | from diffusers import LMSDiscreteScheduler
98 |
99 | return LMSDiscreteScheduler.from_config(config)
100 | elif scheduler == "EulerDiscreteScheduler":
101 | from diffusers import EulerDiscreteScheduler
102 |
103 | return EulerDiscreteScheduler.from_config(config)
104 | elif scheduler == "EulerAncestralDiscreteScheduler":
105 | from diffusers import EulerAncestralDiscreteScheduler
106 |
107 | return EulerAncestralDiscreteScheduler.from_config(config)
108 | else:
109 | raise ValueError(f"Unknown scheduler {scheduler}")
110 |
111 |
112 | @st.cache_resource
113 | def pipeline_lock() -> threading.Lock:
114 | """
115 | Singleton lock used to prevent concurrent access to any model pipeline.
116 | """
117 | return threading.Lock()
118 |
119 |
120 | @st.cache_resource
121 | def load_stable_diffusion_img2img_pipeline(
122 | checkpoint: str = DEFAULT_CHECKPOINT,
123 | device: str = "cuda",
124 | dtype: torch.dtype = torch.float16,
125 | scheduler: str = SCHEDULER_OPTIONS[0],
126 | ) -> StableDiffusionImg2ImgPipeline:
127 | """
128 | Load the image to image pipeline.
129 |
130 | TODO(hayk): Merge this into RiffusionPipeline to just load one model.
131 | """
132 | if device == "cpu" or device.lower().startswith("mps"):
133 | print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
134 | dtype = torch.float32
135 |
136 | pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
137 | checkpoint,
138 | revision="main",
139 | torch_dtype=dtype,
140 | safety_checker=lambda images, **kwargs: (images, False),
141 | ).to(device)
142 |
143 | pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
144 |
145 | return pipeline
146 |
147 |
148 | @st.cache_data(persist=True)
149 | def run_txt2img(
150 | prompt: str,
151 | num_inference_steps: int,
152 | guidance: float,
153 | negative_prompt: str,
154 | seed: int,
155 | width: int,
156 | height: int,
157 | checkpoint: str = DEFAULT_CHECKPOINT,
158 | device: str = "cuda",
159 | scheduler: str = SCHEDULER_OPTIONS[0],
160 | ) -> Image.Image:
161 | """
162 | Run the text to image pipeline with caching.
163 | """
164 | with pipeline_lock():
165 | pipeline = load_stable_diffusion_pipeline(
166 | checkpoint=checkpoint,
167 | device=device,
168 | scheduler=scheduler,
169 | )
170 |
171 | generator_device = "cpu" if device.lower().startswith("mps") else device
172 | generator = torch.Generator(device=generator_device).manual_seed(seed)
173 |
174 | output = pipeline(
175 | prompt=prompt,
176 | num_inference_steps=num_inference_steps,
177 | guidance_scale=guidance,
178 | negative_prompt=negative_prompt or None,
179 | generator=generator,
180 | width=width,
181 | height=height,
182 | )
183 |
184 | return output["images"][0]
185 |
186 |
187 | @st.cache_resource
188 | def spectrogram_image_converter(
189 | params: SpectrogramParams,
190 | device: str = "cuda",
191 | ) -> SpectrogramImageConverter:
192 | return SpectrogramImageConverter(params=params, device=device)
193 |
194 |
195 | @st.cache_data
196 | def spectrogram_image_from_audio(
197 | segment: pydub.AudioSegment,
198 | params: SpectrogramParams,
199 | device: str = "cuda",
200 | ) -> Image.Image:
201 | converter = spectrogram_image_converter(params=params, device=device)
202 | return converter.spectrogram_image_from_audio(segment)
203 |
204 |
205 | @st.cache_data
206 | def audio_segment_from_spectrogram_image(
207 | image: Image.Image,
208 | params: SpectrogramParams,
209 | device: str = "cuda",
210 | ) -> pydub.AudioSegment:
211 | converter = spectrogram_image_converter(params=params, device=device)
212 | return converter.audio_from_spectrogram_image(image)
213 |
214 |
215 | @st.cache_data
216 | def audio_bytes_from_spectrogram_image(
217 | image: Image.Image,
218 | params: SpectrogramParams,
219 | device: str = "cuda",
220 | output_format: str = "mp3",
221 | ) -> io.BytesIO:
222 | segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device)
223 |
224 | audio_bytes = io.BytesIO()
225 | segment.export(audio_bytes, format=output_format)
226 |     audio_bytes.seek(0)
227 |     return audio_bytes
228 |
229 |
230 | def select_device(container: T.Any = st.sidebar) -> str:
231 | """
232 | Dropdown to select a torch device, with an intelligent default.
233 | """
234 | default_device = "cpu"
235 | if torch.cuda.is_available():
236 | default_device = "cuda"
237 | elif torch.backends.mps.is_available():
238 | default_device = "mps"
239 |
240 | device_options = ["cuda", "cpu", "mps"]
241 |     device = container.selectbox(
242 | "Device",
243 | options=device_options,
244 | index=device_options.index(default_device),
245 | help="Which compute device to use. CUDA is recommended.",
246 | )
247 | assert device is not None
248 |
249 | return device
250 |
251 |
252 | def select_audio_extension(container: T.Any = st.sidebar) -> str:
253 | """
254 | Dropdown to select an audio extension, with an intelligent default.
255 | """
256 | default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
257 | extension = container.selectbox(
258 | "Output format",
259 | options=AUDIO_EXTENSIONS,
260 | index=AUDIO_EXTENSIONS.index(default),
261 | )
262 | assert extension is not None
263 | return extension
264 |
265 |
266 | def select_scheduler(container: T.Any = st.sidebar) -> str:
267 | """
268 | Dropdown to select a scheduler.
269 | """
270 |     scheduler = container.selectbox(
271 | "Scheduler",
272 | options=SCHEDULER_OPTIONS,
273 | index=0,
274 | help="Which diffusion scheduler to use",
275 | )
276 | assert scheduler is not None
277 | return scheduler
278 |
279 |
280 | def select_checkpoint(container: T.Any = st.sidebar) -> str:
281 | """
282 | Provide a custom model checkpoint.
283 | """
284 | return container.text_input(
285 | "Custom Checkpoint",
286 | value=DEFAULT_CHECKPOINT,
287 | help="Provide a custom model checkpoint",
288 | )
289 |
290 |
291 | @st.cache_data
292 | def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
293 | return pydub.AudioSegment.from_file(audio_file)
294 |
295 |
296 | @st.cache_resource
297 | def get_audio_splitter(device: str = "cuda"):
298 | return AudioSplitter(device=device)
299 |
300 |
301 | @st.cache_resource
302 | def load_magic_mix_pipeline(
303 | checkpoint: str = DEFAULT_CHECKPOINT,
304 | device: str = "cuda",
305 | scheduler: str = SCHEDULER_OPTIONS[0],
306 | ):
307 | pipeline = DiffusionPipeline.from_pretrained(
308 | checkpoint,
309 | custom_pipeline="magic_mix",
310 | ).to(device)
311 |
312 | pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
313 |
314 | return pipeline
315 |
316 |
317 | @st.cache_data
318 | def run_img2img_magic_mix(
319 | prompt: str,
320 | init_image: Image.Image,
321 | num_inference_steps: int,
322 | guidance_scale: float,
323 | seed: int,
324 | kmin: float,
325 | kmax: float,
326 | mix_factor: float,
327 | checkpoint: str = DEFAULT_CHECKPOINT,
328 | device: str = "cuda",
329 | scheduler: str = SCHEDULER_OPTIONS[0],
330 | ):
331 | """
332 | Run the magic mix pipeline for img2img.
333 | """
334 | with pipeline_lock():
335 | pipeline = load_magic_mix_pipeline(
336 | checkpoint=checkpoint,
337 | device=device,
338 | scheduler=scheduler,
339 | )
340 |
341 | return pipeline(
342 | init_image,
343 | prompt=prompt,
344 | kmin=kmin,
345 | kmax=kmax,
346 | mix_factor=mix_factor,
347 | seed=seed,
348 | guidance_scale=guidance_scale,
349 | steps=num_inference_steps,
350 | )
351 |
352 |
353 | @st.cache_data
354 | def run_img2img(
355 | prompt: str,
356 | init_image: Image.Image,
357 | denoising_strength: float,
358 | num_inference_steps: int,
359 | guidance_scale: float,
360 | seed: int,
361 | negative_prompt: T.Optional[str] = None,
362 | checkpoint: str = DEFAULT_CHECKPOINT,
363 | device: str = "cuda",
364 | scheduler: str = SCHEDULER_OPTIONS[0],
365 | progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
366 | ) -> Image.Image:
367 | with pipeline_lock():
368 | pipeline = load_stable_diffusion_img2img_pipeline(
369 | checkpoint=checkpoint,
370 | device=device,
371 | scheduler=scheduler,
372 | )
373 |
374 | generator_device = "cpu" if device.lower().startswith("mps") else device
375 | generator = torch.Generator(device=generator_device).manual_seed(seed)
376 |
377 | num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)
378 |
379 | def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None:
380 | if progress_callback is not None:
381 | progress_callback(step / num_expected_steps)
382 |
383 | result = pipeline(
384 | prompt=prompt,
385 | image=init_image,
386 | strength=denoising_strength,
387 | num_inference_steps=num_inference_steps,
388 | guidance_scale=guidance_scale,
389 | negative_prompt=negative_prompt or None,
390 | num_images_per_prompt=1,
391 | generator=generator,
392 | callback=callback,
393 | callback_steps=1,
394 | )
395 |
396 | return result.images[0]
397 |
398 |
399 | class StreamlitCounter:
400 | """
401 | Simple counter stored in streamlit session state.
402 | """
403 |
404 | def __init__(self, key="_counter"):
405 | self.key = key
406 | if not st.session_state.get(self.key):
407 | st.session_state[self.key] = 0
408 |
409 | def increment(self):
410 | st.session_state[self.key] += 1
411 |
412 | @property
413 | def value(self):
414 | return st.session_state[self.key]
415 |
416 |
417 | def display_and_download_audio(
418 | segment: pydub.AudioSegment,
419 | name: str,
420 | extension: str = "mp3",
421 | ) -> None:
422 | """
423 | Display the given audio segment and provide a button to download it with
424 | a proper file name, since st.audio doesn't support that.
425 | """
426 | mime_type = f"audio/{extension}"
427 | audio_bytes = io.BytesIO()
428 | segment.export(audio_bytes, format=extension)
429 | st.audio(audio_bytes, format=mime_type)
430 |
431 | st.download_button(
432 | f"{name}.{extension}",
433 | data=audio_bytes,
434 | file_name=f"{name}.{extension}",
435 | mime=mime_type,
436 | )
437 |
--------------------------------------------------------------------------------
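get_scheduler builds a scheduler from its class name using an existing scheduler's config, which is how the sidebar selection is applied to a loaded pipeline. A minimal sketch, assuming a CUDA device and the default checkpoint:

from riffusion.streamlit import util as streamlit_util

# Load the cached text-to-image pipeline, then swap its denoising scheduler
# in place; the scheduler config carries over.
pipeline = streamlit_util.load_stable_diffusion_pipeline(device="cuda")
pipeline.scheduler = streamlit_util.get_scheduler(
    "EulerDiscreteScheduler", config=pipeline.scheduler.config
)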
/riffusion/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/util/__init__.py
--------------------------------------------------------------------------------
/riffusion/util/audio_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Audio utility functions.
3 | """
4 |
5 | import io
6 | import typing as T
7 |
8 | import numpy as np
9 | import pydub
10 | from scipy.io import wavfile
11 |
12 |
13 | def audio_from_waveform(
14 | samples: np.ndarray, sample_rate: int, normalize: bool = False
15 | ) -> pydub.AudioSegment:
16 | """
17 | Convert a numpy array of samples of a waveform to an audio segment.
18 |
19 | Args:
20 | samples: (channels, samples) array
21 | """
22 | # Normalize volume to fit in int16
23 | if normalize:
24 | samples *= np.iinfo(np.int16).max / np.max(np.abs(samples))
25 |
26 | # Transpose and convert to int16
27 | samples = samples.transpose(1, 0)
28 | samples = samples.astype(np.int16)
29 |
30 | # Write to the bytes of a WAV file
31 | wav_bytes = io.BytesIO()
32 | wavfile.write(wav_bytes, sample_rate, samples)
33 | wav_bytes.seek(0)
34 |
35 | # Read into pydub
36 | return pydub.AudioSegment.from_wav(wav_bytes)
37 |
38 |
39 | def apply_filters(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment:
40 | """
41 |     Apply post-processing filters to the audio segment: gain staging, optional
42 |     compression, and normalization of the output level.
43 | """
44 | # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
45 | # TODO(hayk): Is this going to make audio unbalanced between sequential clips?
46 |
47 | if compression:
48 | segment = pydub.effects.normalize(
49 | segment,
50 | headroom=0.1,
51 | )
52 |
53 | segment = segment.apply_gain(-10 - segment.dBFS)
54 |
55 | # TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU
56 | segment = pydub.effects.compress_dynamic_range(
57 | segment,
58 | threshold=-20.0,
59 | ratio=4.0,
60 | attack=5.0,
61 | release=50.0,
62 | )
63 |
64 | desired_db = -12
65 | segment = segment.apply_gain(desired_db - segment.dBFS)
66 |
67 | segment = pydub.effects.normalize(
68 | segment,
69 | headroom=0.1,
70 | )
71 |
72 | return segment
73 |
74 |
75 | def stitch_segments(
76 | segments: T.Sequence[pydub.AudioSegment], crossfade_s: float
77 | ) -> pydub.AudioSegment:
78 | """
79 | Stitch together a sequence of audio segments with a crossfade between each segment.
80 | """
81 | crossfade_ms = int(crossfade_s * 1000)
82 | combined_segment = segments[0]
83 | for segment in segments[1:]:
84 | combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
85 | return combined_segment
86 |
87 |
88 | def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment:
89 | """
90 | Overlay a sequence of audio segments on top of each other.
91 | """
92 | assert len(segments) > 0
93 |     output: T.Optional[pydub.AudioSegment] = None
94 | for segment in segments:
95 | if output is None:
96 | output = segment
97 | else:
98 | output = output.overlay(segment)
99 | return output
100 |
--------------------------------------------------------------------------------
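audio_from_waveform expects a (channels, samples) array and builds a pydub segment via an in-memory WAV file. A minimal sketch that renders one second of a 440 Hz sine tone (the output file name is illustrative):

import numpy as np

from riffusion.util import audio_util

sample_rate = 44100
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
sine = np.sin(2 * np.pi * 440.0 * t)

# Stack to (channels, samples); normalize=True scales to the int16 range.
samples = np.stack([sine, sine])
segment = audio_util.audio_from_waveform(samples, sample_rate=sample_rate, normalize=True)
segment.export("sine_440hz.wav", format="wav")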
/riffusion/util/base64_util.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 |
4 |
5 | def encode(buffer: io.BytesIO) -> str:
6 | """
7 | Encode the given buffer as base64.
8 | """
9 | return base64.encodebytes(buffer.getvalue()).decode("ascii")
10 |
--------------------------------------------------------------------------------
/riffusion/util/fft_util.py:
--------------------------------------------------------------------------------
1 | """
2 | FFT tools to analyze frequency content of audio segments. This is not code for
3 | dealing with spectrogram images, but for analysis of waveforms.
4 | """
5 | import struct
6 | import typing as T
7 |
8 | import numpy as np
9 | import plotly.graph_objects as go
10 | import pydub
11 | from scipy.fft import rfft, rfftfreq
12 |
13 |
14 | def plot_ffts(
15 | segments: T.Dict[str, pydub.AudioSegment],
16 | title: str = "FFT",
17 | min_frequency: float = 20,
18 | max_frequency: float = 20000,
19 | ) -> None:
20 | """
21 | Plot an FFT analysis of the given audio segments.
22 | """
23 | ffts = {name: compute_fft(seg) for name, seg in segments.items()}
24 |
25 | fig = go.Figure(
26 | data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()],
27 | layout={"title": title},
28 | )
29 | fig.update_xaxes(
30 | range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)],
31 | type="log",
32 | title="Frequency",
33 | )
34 | fig.update_yaxes(title="Value")
35 | fig.show()
36 |
37 |
38 | def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]:
39 | """
40 | Compute the FFT of the given audio segment as a mono signal.
41 |
42 | Returns:
43 | frequencies: FFT computed frequencies
44 | amplitudes: Amplitudes of each frequency
45 | """
46 | # Convert to mono if needed.
47 | if sound.channels > 1:
48 | sound = sound.set_channels(1)
49 |
50 | sample_rate = sound.frame_rate
51 |
52 | num_samples = int(sound.frame_count())
53 | samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data)
54 |
55 | fft_values = rfft(samples)
56 | amplitudes = np.abs(fft_values)
57 |
58 | frequencies = rfftfreq(n=num_samples, d=1 / sample_rate)
59 |
60 | return frequencies, amplitudes
61 |
--------------------------------------------------------------------------------
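compute_fft returns frequency bins and amplitudes for a mono version of the segment, which makes it easy to check the dominant frequency of a clip. A minimal sketch, reusing the sine tone idea from the audio_util example above:

import numpy as np

from riffusion.util import audio_util, fft_util

sample_rate = 44100
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
sine = np.sin(2 * np.pi * 440.0 * t)
segment = audio_util.audio_from_waveform(np.stack([sine, sine]), sample_rate, normalize=True)

frequencies, amplitudes = fft_util.compute_fft(segment)
peak_hz = frequencies[np.argmax(amplitudes)]
print(f"Peak frequency: {peak_hz:.1f} Hz")  # ~440 Hz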
/riffusion/util/image_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Module for converting between spectrograms tensors and spectrogram images, as well as
3 | general helpers for operating on pillow images.
4 | """
5 | import typing as T
6 |
7 | import numpy as np
8 | from PIL import Image
9 |
10 | from riffusion.spectrogram_params import SpectrogramParams
11 |
12 |
13 | def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
14 | """
15 | Compute a spectrogram image from a spectrogram magnitude array.
16 |
17 | This is the inverse of spectrogram_from_image, except for discretization error from
18 | quantizing to uint8.
19 |
20 | Args:
21 | spectrogram: (channels, frequency, time)
22 | power: A power curve to apply to the spectrogram to preserve contrast
23 |
24 | Returns:
25 | image: (frequency, time, channels)
26 | """
27 | # Rescale to 0-1
28 | max_value = np.max(spectrogram)
29 | data = spectrogram / max_value
30 |
31 | # Apply the power curve
32 | data = np.power(data, power)
33 |
34 | # Rescale to 0-255
35 | data = data * 255
36 |
37 | # Invert
38 | data = 255 - data
39 |
40 | # Convert to uint8
41 | data = data.astype(np.uint8)
42 |
43 | # Munge channels into a PIL image
44 | if data.shape[0] == 1:
45 | # TODO(hayk): Do we want to write single channel to disk instead?
46 | image = Image.fromarray(data[0], mode="L").convert("RGB")
47 | elif data.shape[0] == 2:
48 | data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
49 | image = Image.fromarray(data, mode="RGB")
50 | else:
51 | raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")
52 |
53 | # Flip Y
54 | image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
55 |
56 | return image
57 |
58 |
59 | def spectrogram_from_image(
60 | image: Image.Image,
61 | power: float = 0.25,
62 | stereo: bool = False,
63 | max_value: float = 30e6,
64 | ) -> np.ndarray:
65 | """
66 | Compute a spectrogram magnitude array from a spectrogram image.
67 |
68 | This is the inverse of image_from_spectrogram, except for discretization error from
69 | quantizing to uint8.
70 |
71 | Args:
72 | image: (frequency, time, channels)
73 | power: The power curve applied to the spectrogram
74 | stereo: Whether the spectrogram encodes stereo data
75 | max_value: The max value of the original spectrogram. In practice doesn't matter.
76 |
77 | Returns:
78 | spectrogram: (channels, frequency, time)
79 | """
80 | # Convert to RGB if single channel
81 | if image.mode in ("P", "L"):
82 | image = image.convert("RGB")
83 |
84 | # Flip Y
85 | image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
86 |
87 | # Munge channels into a numpy array of (channels, frequency, time)
88 | data = np.array(image).transpose(2, 0, 1)
89 | if stereo:
90 | # Take the G and B channels as done in image_from_spectrogram
91 | data = data[[1, 2], :, :]
92 | else:
93 | data = data[0:1, :, :]
94 |
95 | # Convert to floats
96 | data = data.astype(np.float32)
97 |
98 | # Invert
99 | data = 255 - data
100 |
101 | # Rescale to 0-1
102 | data = data / 255
103 |
104 | # Reverse the power curve
105 | data = np.power(data, 1 / power)
106 |
107 | # Rescale to max value
108 | data = data * max_value
109 |
110 | return data
111 |
112 |
113 | def exif_from_image(pil_image: Image.Image) -> T.Dict[str, T.Any]:
114 | """
115 | Get the EXIF data from a PIL image as a dict.
116 | """
117 | exif = pil_image.getexif()
118 |
119 | if exif is None or len(exif) == 0:
120 | return {}
121 |
122 | return {SpectrogramParams.ExifTags(key).name: val for key, val in exif.items()}
123 |
--------------------------------------------------------------------------------
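image_from_spectrogram and spectrogram_from_image are inverses up to uint8 quantization and the max_value scale, which is easy to check on synthetic data. A minimal sketch for the mono (single channel) case:

import numpy as np

from riffusion.util import image_util

# Fake (channels, frequency, time) spectrogram magnitudes.
spectrogram = np.random.uniform(0, 30e6, size=(1, 512, 1024)).astype(np.float32)

image = image_util.image_from_spectrogram(spectrogram, power=0.25)
recovered = image_util.spectrogram_from_image(image, power=0.25, stereo=False, max_value=30e6)

# Shapes round-trip exactly; values agree up to quantization and scaling error.
print(spectrogram.shape, recovered.shape)  # (1, 512, 1024) (1, 512, 1024)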
/riffusion/util/torch_util.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def check_device(device: str, backup: str = "cpu") -> str:
8 | """
9 |     Check that the device is valid and available. If not, fall back to the backup device.
10 | """
11 | cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
12 | mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()
13 |
14 | if cuda_not_found or mps_not_found:
15 | warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
16 | return backup
17 |
18 | return device
19 |
20 |
21 | def slerp(
22 | t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995
23 | ) -> torch.Tensor:
24 | """
25 |     Helper function to spherically interpolate two arrays v0 and v1.
26 |     """
27 |     inputs_are_torch = not isinstance(v0, np.ndarray)
28 |     if inputs_are_torch:
29 |         input_device = v0.device
30 | v0 = v0.cpu().numpy()
31 | v1 = v1.cpu().numpy()
32 |
33 | dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
34 | if np.abs(dot) > dot_threshold:
35 | v2 = (1 - t) * v0 + t * v1
36 | else:
37 | theta_0 = np.arccos(dot)
38 | sin_theta_0 = np.sin(theta_0)
39 | theta_t = theta_0 * t
40 | sin_theta_t = np.sin(theta_t)
41 | s0 = np.sin(theta_0 - theta_t) / sin_theta_0
42 | s1 = sin_theta_t / sin_theta_0
43 | v2 = s0 * v0 + s1 * v1
44 |
45 | if inputs_are_torch:
46 | v2 = torch.from_numpy(v2).to(input_device)
47 |
48 | return v2
49 |
--------------------------------------------------------------------------------
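slerp interpolates along the arc between two latent vectors rather than along a straight line, which keeps interpolated latents at a similar norm. A minimal sketch on small torch tensors:

import torch

from riffusion.util.torch_util import slerp

v0 = torch.tensor([1.0, 0.0])
v1 = torch.tensor([0.0, 1.0])

# Halfway along the arc between the two unit vectors.
midpoint = slerp(0.5, v0, v1)
print(midpoint)  # tensor([0.7071, 0.7071])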
/seed_images/agile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/agile.png
--------------------------------------------------------------------------------
/seed_images/marim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/marim.png
--------------------------------------------------------------------------------
/seed_images/mask_beat_lines_80.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_beat_lines_80.png
--------------------------------------------------------------------------------
/seed_images/mask_gradient_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_gradient_dark.png
--------------------------------------------------------------------------------
/seed_images/mask_gradient_top_70.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_gradient_top_70.png
--------------------------------------------------------------------------------
/seed_images/mask_gradient_top_fifth_75.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_gradient_top_fifth_75.png
--------------------------------------------------------------------------------
/seed_images/mask_top_third_75.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_top_third_75.png
--------------------------------------------------------------------------------
/seed_images/mask_top_third_95.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_top_third_95.png
--------------------------------------------------------------------------------
/seed_images/motorway.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/motorway.png
--------------------------------------------------------------------------------
/seed_images/og_beat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/og_beat.png
--------------------------------------------------------------------------------
/seed_images/vibes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/vibes.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from setuptools import find_packages, setup
4 |
5 | # Load the version from file
6 | __version__ = Path("VERSION").read_text().strip()
7 |
8 | # Load dependencies from requirements.txt
9 | packages = Path("requirements.txt").read_text().splitlines()
10 |
11 | setup(
12 | name="riffusion",
13 | packages=find_packages(exclude=[]),
14 | version=__version__,
15 | license="MIT",
16 | description="Riffusion - Stable diffusion for real-time music generation",
17 | author="Hayk Martiros",
18 | author_email="hayk.mart@gmail.com",
19 | long_description_content_type="text/markdown",
20 | url="https://github.com/riffusion/riffusion",
21 | keywords=[
22 | "artificial intelligence",
23 | "audio generation",
24 | "music",
25 | "diffusion",
26 | "riffusion",
27 | "deep learning",
28 | "transformers",
29 | ],
30 | install_requires=packages,
31 | package_data={
32 | "riffusion": ["py.typed"],
33 | },
34 | classifiers=[
35 | "Development Status :: 4 - Beta",
36 | "Intended Audience :: Developers",
37 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
38 | "License :: OSI Approved :: MIT License",
39 | "Programming Language :: Python :: 3.9",
40 | ],
41 | )
42 |
--------------------------------------------------------------------------------
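
Because install_requires is read straight from requirements.txt and the version comes from the VERSION file, the package is typically installed in editable mode for development (for example, pip install -e . from the repository root). A minimal check of the installed distribution, assuming it has been installed under the name riffusion:

    from importlib.metadata import version

    # Should print the same string as the VERSION file once the package is installed.
    print(version("riffusion"))
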
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/__init__.py
--------------------------------------------------------------------------------
/test/audio_to_image_test.py:
--------------------------------------------------------------------------------
1 | import typing as T
2 |
3 | import numpy as np
4 | from PIL import Image
5 |
6 | from riffusion.cli import audio_to_image
7 | from riffusion.spectrogram_params import SpectrogramParams
8 |
9 | from .test_case import TestCase
10 |
11 |
12 | class AudioToImageTest(TestCase):
13 | """
14 | Test riffusion.cli audio-to-image
15 | """
16 |
17 | @classmethod
18 | def default_params(cls) -> T.Dict:
19 | return dict(
20 | step_size_ms=10,
21 | num_frequencies=512,
22 | # TODO(hayk): Change these to [20, 20000] once a model is updated
23 | min_frequency=0,
24 | max_frequency=10000,
25 | stereo=False,
26 | device=cls.DEVICE,
27 | )
28 |
29 | def test_audio_to_image(self) -> None:
30 | """
31 | Test audio-to-image with default params.
32 | """
33 | params = self.default_params()
34 | self.helper_test_with_params(params)
35 |
36 | def test_stereo(self) -> None:
37 | """
38 | Test audio-to-image with stereo=True.
39 | """
40 | params = self.default_params()
41 | params["stereo"] = True
42 | self.helper_test_with_params(params)
43 |
44 | def helper_test_with_params(self, params: T.Dict) -> None:
45 | audio_path = (
46 | self.TEST_DATA_PATH
47 | / "tired_traveler"
48 | / "clips"
49 | / "clip_2_start_103694_ms_duration_5678_ms.wav"
50 | )
51 | output_dir = self.get_tmp_dir("audio_to_image_")
52 |
53 | if params["stereo"]:
54 | stem = f"{audio_path.stem}_stereo"
55 | else:
56 | stem = audio_path.stem
57 |
58 | image_path = output_dir / f"{stem}.png"
59 |
60 | audio_to_image(audio=str(audio_path), image=str(image_path), **params)
61 |
62 | # Check that the image exists
63 | self.assertTrue(image_path.exists())
64 |
65 | pil_image = Image.open(image_path)
66 |
67 | # Check the image mode
68 | self.assertEqual(pil_image.mode, "RGB")
69 |
70 | # Check the image dimensions
71 | duration_ms = 5678
72 | self.assertTrue(str(duration_ms) in audio_path.name)
73 | expected_image_width = round(duration_ms / params["step_size_ms"])
74 | self.assertEqual(pil_image.width, expected_image_width)
75 | self.assertEqual(pil_image.height, params["num_frequencies"])
76 |
77 | # Get channels as numpy arrays
78 | channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))]
79 | self.assertEqual(len(channels), 3)
80 |
81 | if params["stereo"]:
82 | # Check that the first channel is zero
83 | self.assertTrue(np.all(channels[0] == 0))
84 | else:
85 | # Check that all channels are the same
86 | self.assertTrue(np.all(channels[0] == channels[1]))
87 | self.assertTrue(np.all(channels[0] == channels[2]))
88 |
89 | # Check that the image has exif data
90 | exif = pil_image.getexif()
91 | self.assertIsNotNone(exif)
92 | params_from_exif = SpectrogramParams.from_exif(exif)
93 | expected_params = SpectrogramParams(
94 | stereo=params["stereo"],
95 | step_size_ms=params["step_size_ms"],
96 | num_frequencies=params["num_frequencies"],
97 | max_frequency=params["max_frequency"],
98 | )
99 | self.assertTrue(params_from_exif == expected_params)
100 |
--------------------------------------------------------------------------------
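
The dimension checks above encode how the CLI lays out a spectrogram image: the width is the clip duration divided by the STFT hop (step_size_ms) and the height is the number of frequency bins, with mono audio duplicated across all three RGB channels. A small arithmetic sketch for the checked-in test clip, using the values from the test parameters above:

    # Expected image size for the 5678 ms test clip with the default test params.
    duration_ms = 5678       # encoded in the clip filename
    step_size_ms = 10        # hop size between spectrogram frames
    num_frequencies = 512    # frequency bins -> image height

    expected_width = round(duration_ms / step_size_ms)   # 568 pixels
    expected_height = num_frequencies                    # 512 pixels
    print(expected_width, expected_height)
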
/test/image_to_audio_test.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pydub
4 |
5 | from riffusion.cli import image_to_audio
6 |
7 | from .test_case import TestCase
8 |
9 |
10 | class ImageToAudioTest(TestCase):
11 | """
12 | Test riffusion.cli image-to-audio
13 | """
14 |
15 | def test_image_to_audio_mono(self) -> None:
16 | self.helper_image_to_audio(
17 | song_dir=self.TEST_DATA_PATH / "tired_traveler",
18 | clip_name="clip_2_start_103694_ms_duration_5678_ms",
19 | stereo=False,
20 | )
21 |
22 | def test_image_to_audio_stereo(self) -> None:
23 | self.helper_image_to_audio(
24 | song_dir=self.TEST_DATA_PATH / "tired_traveler",
25 | clip_name="clip_2_start_103694_ms_duration_5678_ms",
26 | stereo=True,
27 | )
28 |
29 | def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None:
30 | if stereo:
31 | image_stem = clip_name + "_stereo"
32 | else:
33 | image_stem = clip_name
34 |
35 | image_path = song_dir / "images" / f"{image_stem}.png"
36 | output_dir = self.get_tmp_dir("image_to_audio_")
37 | audio_path = output_dir / f"{image_path.stem}.wav"
38 |
39 | image_to_audio(
40 | image=str(image_path),
41 | audio=str(audio_path),
42 | device=self.DEVICE,
43 | )
44 |
45 | # Check that the audio exists
46 | self.assertTrue(audio_path.exists())
47 |
48 | # Load the reconstructed audio and the original clip
49 | segment = pydub.AudioSegment.from_file(str(audio_path))
50 | expected_segment = pydub.AudioSegment.from_file(
51 | str(song_dir / "clips" / f"{clip_name}.wav")
52 | )
53 |
54 | # Check sample rate
55 | self.assertEqual(segment.frame_rate, expected_segment.frame_rate)
56 |
57 | # Check duration
58 | actual_duration_ms = round(segment.duration_seconds * 1000)
59 | expected_duration_ms = round(expected_segment.duration_seconds * 1000)
60 | self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10)
61 |
62 | # Check the number of channels
63 | self.assertEqual(expected_segment.channels, 2)
64 | if stereo:
65 | self.assertEqual(segment.channels, 2)
66 | else:
67 | self.assertEqual(segment.channels, 1)
68 |
69 |
70 | if __name__ == "__main__":
71 | TestCase.main()
72 |
--------------------------------------------------------------------------------
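
The same conversion can be run outside the unittest harness by calling the CLI function directly. A minimal sketch, assuming it is run from the repository root; the output path is illustrative:

    from riffusion.cli import image_to_audio

    # Checked-in test image; write the reconstructed audio to a path of your choosing.
    image_to_audio(
        image="test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms.png",
        audio="/tmp/clip_2_reconstructed.wav",
        device="cpu",
    )
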
/test/image_util_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pydub
3 |
4 | from riffusion.spectrogram_converter import SpectrogramConverter
5 | from riffusion.spectrogram_params import SpectrogramParams
6 | from riffusion.util import image_util
7 |
8 | from .test_case import TestCase
9 |
10 |
11 | class ImageUtilTest(TestCase):
12 | """
13 | Test riffusion.util.image_util
14 | """
15 |
16 | def test_spectrogram_to_image_round_trip(self) -> None:
17 | audio_path = (
18 | self.TEST_DATA_PATH
19 | / "tired_traveler"
20 | / "clips"
21 | / "clip_2_start_103694_ms_duration_5678_ms.wav"
22 | )
23 |
24 | # Load up the audio file
25 | segment = pydub.AudioSegment.from_file(audio_path)
26 |
27 | # Convert to mono
28 | segment = segment.set_channels(1)
29 |
30 | # Compute a spectrogram with default params
31 | params = SpectrogramParams(sample_rate=segment.frame_rate)
32 | converter = SpectrogramConverter(params=params, device=self.DEVICE)
33 | spectrogram = converter.spectrogram_from_audio(segment)
34 |
35 | # Compute the image from the spectrogram
36 | image = image_util.image_from_spectrogram(
37 | spectrogram=spectrogram,
38 | power=params.power_for_image,
39 | )
40 |
41 | # Save the max value
42 | max_value = np.max(spectrogram)
43 |
44 | # Compute the spectrogram from the image
45 | spectrogram_reversed = image_util.spectrogram_from_image(
46 | image=image,
47 | max_value=max_value,
48 | power=params.power_for_image,
49 | stereo=params.stereo,
50 | )
51 |
52 | # Check the shapes
53 | self.assertEqual(spectrogram.shape, spectrogram_reversed.shape)
54 |
55 | # Check the max values
56 | self.assertEqual(np.max(spectrogram), np.max(spectrogram_reversed))
57 |
58 | # Check the median values
59 | self.assertTrue(
60 | np.allclose(np.median(spectrogram), np.median(spectrogram_reversed), rtol=0.05)
61 | )
62 |
63 | # Make sure all values are somewhat similar, but allow for discretization error
64 | # TODO(hayk): Investigate error more closely
65 | self.assertTrue(np.allclose(spectrogram, spectrogram_reversed, rtol=0.15))
66 |
--------------------------------------------------------------------------------
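
The tolerances above exist because the image encoding is lossy: the spectrogram's dynamic range is compressed with a power law and quantized to 8-bit pixels, then inverted using the saved max_value. A small numpy sketch of that idea, as an illustration of the quantization error rather than the library's exact pixel encoding:

    import numpy as np

    rng = np.random.default_rng(0)
    spectrogram = rng.random((512, 568)).astype(np.float32) * 30.0
    power = 0.25
    max_value = float(spectrogram.max())

    # Forward: normalize, compress dynamic range, quantize to 8-bit values.
    pixels = np.round((spectrogram / max_value) ** power * 255.0)

    # Inverse: undo the quantization and power compression using max_value.
    reconstructed = (pixels / 255.0) ** (1.0 / power) * max_value

    # Small but nonzero error, hence the relative tolerances in the test above.
    print(np.max(np.abs(spectrogram - reconstructed)))
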
/test/print_exif_test.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import io
3 |
4 | from riffusion.cli import print_exif
5 |
6 | from .test_case import TestCase
7 |
8 |
9 | class PrintExifTest(TestCase):
10 | """
11 | Test riffusion.cli print-exif
12 | """
13 |
14 | def test_print_exif(self) -> None:
15 | """
16 | Test print-exif.
17 | """
18 | image_path = (
19 | self.TEST_DATA_PATH
20 | / "tired_traveler"
21 | / "images"
22 | / "clip_2_start_103694_ms_duration_5678_ms.png"
23 | )
24 |
25 | # Redirect stdout
26 | stdout = io.StringIO()
27 | with contextlib.redirect_stdout(stdout):
28 | print_exif(image=str(image_path))
29 |
30 | # Check that a couple of values are printed
31 | self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
32 | self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())
33 |
--------------------------------------------------------------------------------
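
print-exif is the command-line counterpart of reading the embedded parameters back programmatically. A minimal sketch using the same checked-in image, with PIL and SpectrogramParams.from_exif as also exercised in audio_to_image_test.py:

    from PIL import Image

    from riffusion.spectrogram_params import SpectrogramParams

    image = Image.open(
        "test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms.png"
    )
    params = SpectrogramParams.from_exif(image.getexif())
    print(params.num_frequencies, params.sample_rate)
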
/test/sample_clips_test.py:
--------------------------------------------------------------------------------
1 | import typing as T
2 |
3 | import pydub
4 |
5 | from riffusion.cli import sample_clips
6 |
7 | from .test_case import TestCase
8 |
9 |
10 | class SampleClipsTest(TestCase):
11 | """
12 | Test riffusion.cli sample-clips
13 | """
14 |
15 | @staticmethod
16 | def default_params() -> T.Dict:
17 | return dict(
18 | num_clips=3,
19 | duration_ms=5678,
20 | mono=False,
21 | extension="wav",
22 | seed=42,
23 | )
24 |
25 | def test_sample_clips(self) -> None:
26 | """
27 | Test sample-clips with default params.
28 | """
29 | params = self.default_params()
30 | self.helper_test_with_params(params)
31 |
32 | def test_mono(self) -> None:
33 | """
34 | Test sample-clips with mono=True.
35 | """
36 | params = self.default_params()
37 | params["mono"] = True
38 | params["num_clips"] = 1
39 | self.helper_test_with_params(params)
40 |
41 | def test_mp3(self) -> None:
42 | """
43 | Test sample-clips with extension=mp3.
44 | """
45 | if pydub.AudioSegment.converter is None:
46 | self.skipTest("skipping, ffmpeg not found")
47 |
48 | params = self.default_params()
49 | params["extension"] = "mp3"
50 | params["num_clips"] = 1
51 | self.helper_test_with_params(params)
52 |
53 | def helper_test_with_params(self, params: T.Dict) -> None:
54 | """
55 | Test sample-clips with the given params.
56 | """
57 | audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
58 | output_dir = self.get_tmp_dir("sample_clips_")
59 |
60 | sample_clips(
61 | audio=str(audio_path),
62 | output_dir=str(output_dir),
63 | **params,
64 | )
65 |
66 | # For each file in output dir
67 | counter = 0
68 | for clip_path in output_dir.iterdir():
69 | # Check that it has the right extension
70 | self.assertEqual(clip_path.suffix, f".{params['extension']}")
71 |
72 | # Check that it has the right duration
73 | segment = pydub.AudioSegment.from_file(clip_path)
74 | self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])
75 |
76 | # Check that it has the right number of channels
77 | if params["mono"]:
78 | self.assertEqual(segment.channels, 1)
79 | else:
80 | self.assertEqual(segment.channels, 2)
81 |
82 | counter += 1
83 |
84 | self.assertEqual(counter, params["num_clips"])
85 |
86 |
87 | if __name__ == "__main__":
88 | TestCase.main()
89 |
--------------------------------------------------------------------------------
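
The sample-clips command draws fixed-length clips at random start offsets from a longer track. A standalone sketch of that idea using pydub directly (a hypothetical snippet, not the CLI's implementation; decoding the mp3 requires ffmpeg):

    import random

    import pydub

    segment = pydub.AudioSegment.from_file("test/test_data/tired_traveler/tired_traveler.mp3")
    duration_ms = 5678
    rng = random.Random(42)

    for i in range(3):
        # len(segment) is the track duration in milliseconds.
        start_ms = rng.randint(0, len(segment) - duration_ms)
        clip = segment[start_ms : start_ms + duration_ms]
        clip.export(f"/tmp/clip_{i}_start_{start_ms}_ms_duration_{duration_ms}_ms.wav", format="wav")
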
/test/spectrogram_converter_test.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import typing as T
3 |
4 | import pydub
5 |
6 | from riffusion.spectrogram_converter import SpectrogramConverter
7 | from riffusion.spectrogram_params import SpectrogramParams
8 | from riffusion.util import fft_util
9 |
10 | from .test_case import TestCase
11 |
12 |
13 | class SpectrogramConverterTest(TestCase):
14 | """
15 | Test going from audio to spectrogram to audio, without converting to
16 | an image, to check the quality loss of the reconstruction.
17 |
18 | This test allows comparing multiple sets of spectrogram params by listening to output audio
19 | and by plotting their FFTs.
20 | """
21 |
22 | # TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.
23 |
24 | def test_round_trip(self) -> None:
25 | audio_path = (
26 | self.TEST_DATA_PATH
27 | / "tired_traveler"
28 | / "clips"
29 | / "clip_2_start_103694_ms_duration_5678_ms.wav"
30 | )
31 | output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")
32 |
33 | # Load up the audio file
34 | segment = pydub.AudioSegment.from_file(audio_path)
35 |
36 | # Convert to mono if desired
37 | use_stereo = False
38 | if use_stereo:
39 | assert segment.channels == 2
40 | else:
41 | segment = segment.set_channels(1)
42 |
43 | # Define named sets of parameters
44 | param_sets: T.Dict[str, SpectrogramParams] = {}
45 |
46 | param_sets["default"] = SpectrogramParams(
47 | sample_rate=segment.frame_rate,
48 | stereo=use_stereo,
49 | step_size_ms=10,
50 | min_frequency=20,
51 | max_frequency=20000,
52 | num_frequencies=512,
53 | )
54 |
55 | if self.DEBUG:
56 | param_sets["freq_0_to_10k"] = dataclasses.replace(
57 | param_sets["default"],
58 | min_frequency=0,
59 | max_frequency=10000,
60 | )
61 |
62 | segments: T.Dict[str, pydub.AudioSegment] = {
63 | "original": segment,
64 | }
65 | for name, params in param_sets.items():
66 | converter = SpectrogramConverter(params=params, device=self.DEVICE)
67 | spectrogram = converter.spectrogram_from_audio(segment)
68 | segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
69 |
70 | # Save segments to disk
71 | for name, segment in segments.items():
72 | audio_out = output_dir / f"{name}.wav"
73 | segment.export(audio_out, format="wav")
74 | print(f"Saved {audio_out}")
75 |
76 | # Check params
77 | self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
78 | self.assertEqual(segments["original"].channels, segments["default"].channels)
79 | self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
80 | self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
81 |
82 | # TODO(hayk): Test something more rigorous about the quality of the reconstruction.
83 |
84 | # If debugging, load up a browser tab plotting the FFTs
85 | if self.DEBUG:
86 | fft_util.plot_ffts(segments)
87 |
--------------------------------------------------------------------------------
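
Reconstruction quality here is bounded by Griffin-Lim phase estimation, since only magnitudes survive the spectrogram. A standalone sketch of the same round-trip idea with plain torchaudio transforms (illustrative STFT parameters, not the converter's actual configuration):

    import torch
    import torchaudio

    n_fft = 1024
    hop_length = 441  # roughly 10 ms at 44.1 kHz

    to_spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2.0)
    griffin_lim = torchaudio.transforms.GriffinLim(
        n_fft=n_fft, hop_length=hop_length, power=2.0, n_iter=32
    )

    waveform = torch.randn(1, 44100)   # one second of noise as a stand-in for real audio
    spec = to_spec(waveform)           # power spectrogram; phase is discarded here
    reconstructed = griffin_lim(spec)  # phase re-estimated by Griffin-Lim
    print(waveform.shape, reconstructed.shape)
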
/test/spectrogram_image_converter_test.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import typing as T
3 |
4 | import pydub
5 | from PIL import Image
6 |
7 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter
8 | from riffusion.spectrogram_params import SpectrogramParams
9 | from riffusion.util import fft_util
10 |
11 | from .test_case import TestCase
12 |
13 |
14 | class SpectrogramImageConverterTest(TestCase):
15 | """
16 | Test going from audio to spectrogram images to audio, testing the quality loss of the
17 | end-to-end pipeline.
18 |
19 | This test allows comparing multiple sets of spectrogram params by listening to output audio
20 | and by plotting their FFTs.
21 |
22 | See spectrogram_converter_test.py for a similar test that does not convert to images.
23 | """
24 |
25 | def test_round_trip(self) -> None:
26 | audio_path = (
27 | self.TEST_DATA_PATH
28 | / "tired_traveler"
29 | / "clips"
30 | / "clip_2_start_103694_ms_duration_5678_ms.wav"
31 | )
32 | output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_")
33 |
34 | # Load up the audio file
35 | segment = pydub.AudioSegment.from_file(audio_path)
36 |
37 | # Convert to mono if desired
38 | use_stereo = False
39 | if use_stereo:
40 | assert segment.channels == 2
41 | else:
42 | segment = segment.set_channels(1)
43 |
44 | # Define named sets of parameters
45 | param_sets: T.Dict[str, SpectrogramParams] = {}
46 |
47 | param_sets["default"] = SpectrogramParams(
48 | sample_rate=segment.frame_rate,
49 | stereo=use_stereo,
50 | step_size_ms=10,
51 | min_frequency=20,
52 | max_frequency=20000,
53 | num_frequencies=512,
54 | )
55 |
56 | if self.DEBUG:
57 | param_sets["freq_0_to_10k"] = dataclasses.replace(
58 | param_sets["default"],
59 | min_frequency=0,
60 | max_frequency=10000,
61 | )
62 |
63 | segments: T.Dict[str, pydub.AudioSegment] = {
64 | "original": segment,
65 | }
66 | images: T.Dict[str, Image.Image] = {}
67 | for name, params in param_sets.items():
68 | converter = SpectrogramImageConverter(params=params, device=self.DEVICE)
69 | images[name] = converter.spectrogram_image_from_audio(segment)
70 | segments[name] = converter.audio_from_spectrogram_image(
71 | image=images[name],
72 | apply_filters=True,
73 | )
74 |
75 | # Save images to disk
76 | for name, image in images.items():
77 | image_out = output_dir / f"{name}.png"
78 | image.save(image_out, exif=image.getexif(), format="PNG")
79 | print(f"Saved {image_out}")
80 |
81 | # Save segments to disk
82 | for name, segment in segments.items():
83 | audio_out = output_dir / f"{name}.wav"
84 | segment.export(audio_out, format="wav")
85 | print(f"Saved {audio_out}")
86 |
87 | # Check params
88 | self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
89 | self.assertEqual(segments["original"].channels, segments["default"].channels)
90 | self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
91 | self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
92 |
93 | # TODO(hayk): Test something more rigorous about the quality of the reconstruction.
94 |
95 | # If debugging, load up a browser tab plotting the FFTs
96 | if self.DEBUG:
97 | fft_util.plot_ffts(segments)
98 |
--------------------------------------------------------------------------------
/test/test_case.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tempfile
4 | import typing as T
5 | import unittest
6 | import warnings
7 | from pathlib import Path
8 |
9 |
10 | class TestCase(unittest.TestCase):
11 | """
12 | Base class for tests.
13 | """
14 |
15 | # Where checked-in test data is stored
16 | TEST_DATA_PATH = Path(__file__).resolve().parent / "test_data"
17 |
18 | # Whether to run tests in debug mode (e.g. don't clean up temporary directories, show plots)
19 | DEBUG = bool(os.environ.get("RIFFUSION_TEST_DEBUG"))
20 |
21 | # Which torch device to use for tests
22 | DEVICE = os.environ.get("RIFFUSION_TEST_DEVICE", "cuda")
23 |
24 | @staticmethod
25 | def main(*args: T.Any, **kwargs: T.Any) -> None:
26 | """
27 | Run the tests.
28 | """
29 | unittest.main(*args, **kwargs)
30 |
31 | @classmethod
32 | def setUpClass(cls):
33 | warnings.filterwarnings("ignore", category=ResourceWarning)
34 |
35 | def get_tmp_dir(self, prefix: str) -> Path:
36 | """
37 | Create a temporary directory.
38 | """
39 | tmp_dir = tempfile.mkdtemp(prefix=prefix)
40 |
41 | # Clean up the temporary directory if not debugging
42 | if not self.DEBUG:
43 | self.addCleanup(lambda: shutil.rmtree(tmp_dir, ignore_errors=True))
44 |
45 | dir_path = Path(tmp_dir)
46 | assert dir_path.is_dir()
47 |
48 | return dir_path
49 |
--------------------------------------------------------------------------------
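
New tests can subclass this TestCase to inherit the shared test-data path, the RIFFUSION_TEST_DEVICE device selection, and automatic cleanup of temporary directories (kept when RIFFUSION_TEST_DEBUG is set). A minimal hypothetical example, not a file in the repository:

    from .test_case import TestCase


    class ExampleTest(TestCase):
        def test_tmp_dir_and_config(self) -> None:
            tmp_dir = self.get_tmp_dir("example_")
            self.assertTrue(tmp_dir.is_dir())

            # Shared configuration provided by the base class.
            self.assertTrue(self.TEST_DATA_PATH.exists())
            self.assertTrue(isinstance(self.DEVICE, str) and self.DEVICE)


    if __name__ == "__main__":
        TestCase.main()
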
/test/test_data/README.md:
--------------------------------------------------------------------------------
1 | # Test Data
2 |
3 | ### tired_traveler
4 |
5 | * Song: Tired traveler on the way to home
6 | * Artist: Andrew Codeman
7 | * Source: https://freemusicarchive.org/
8 |
--------------------------------------------------------------------------------
/test/test_data/tired_traveler/clips/clip_0_start_15795_ms_duration_5678_ms.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/clips/clip_0_start_15795_ms_duration_5678_ms.wav
--------------------------------------------------------------------------------
/test/test_data/tired_traveler/clips/clip_1_start_860_ms_duration_5678_ms.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/clips/clip_1_start_860_ms_duration_5678_ms.wav
--------------------------------------------------------------------------------
/test/test_data/tired_traveler/clips/clip_2_start_103694_ms_duration_5678_ms.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/clips/clip_2_start_103694_ms_duration_5678_ms.wav
--------------------------------------------------------------------------------
/test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms.png
--------------------------------------------------------------------------------
/test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms_stereo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms_stereo.png
--------------------------------------------------------------------------------
/test/test_data/tired_traveler/tired_traveler.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/tired_traveler.mp3
--------------------------------------------------------------------------------