├── .github └── workflows │ ├── black.yml │ ├── ci.yml │ ├── mypy.yml │ ├── pip.yml │ └── ruff.yml ├── .gitignore ├── CITATION ├── LICENSE ├── README.md ├── VERSION ├── cog.yaml ├── integrations ├── README.md ├── __init__.py ├── baseten.py └── cog_riffusion.py ├── pyproject.toml ├── requirements.txt ├── requirements_all.txt ├── requirements_dev.txt ├── riffusion ├── __init__.py ├── audio_splitter.py ├── cli.py ├── datatypes.py ├── external │ ├── README.md │ ├── __init__.py │ └── prompt_weighting.py ├── py.typed ├── riffusion_pipeline.py ├── server.py ├── spectrogram_converter.py ├── spectrogram_image_converter.py ├── spectrogram_params.py ├── streamlit │ ├── README.md │ ├── __init__.py │ ├── playground.py │ ├── tasks │ │ ├── __init__.py │ │ ├── audio_to_audio.py │ │ ├── home.py │ │ ├── image_to_audio.py │ │ ├── interpolation.py │ │ ├── sample_clips.py │ │ ├── split_audio.py │ │ ├── text_to_audio.py │ │ └── text_to_audio_batch.py │ └── util.py └── util │ ├── __init__.py │ ├── audio_util.py │ ├── base64_util.py │ ├── fft_util.py │ ├── image_util.py │ └── torch_util.py ├── seed_images ├── agile.png ├── marim.png ├── mask_beat_lines_80.png ├── mask_gradient_dark.png ├── mask_gradient_top_70.png ├── mask_gradient_top_fifth_75.png ├── mask_top_third_75.png ├── mask_top_third_95.png ├── motorway.png ├── og_beat.png └── vibes.png ├── setup.py └── test ├── __init__.py ├── audio_to_image_test.py ├── image_to_audio_test.py ├── image_util_test.py ├── print_exif_test.py ├── sample_clips_test.py ├── spectrogram_converter_test.py ├── spectrogram_image_converter_test.py ├── test_case.py └── test_data ├── README.md └── tired_traveler ├── clips ├── clip_0_start_15795_ms_duration_5678_ms.wav ├── clip_1_start_860_ms_duration_5678_ms.wav └── clip_2_start_103694_ms_duration_5678_ms.wav ├── images ├── clip_2_start_103694_ms_duration_5678_ms.png └── clip_2_start_103694_ms_duration_5678_ms_stereo.png └── tired_traveler.mp3 /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: black format 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | types: [opened, synchronize, reopened, ready_for_review] 9 | 10 | jobs: 11 | run: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: psf/black@stable 16 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: python test 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | types: [opened, synchronize, reopened, ready_for_review] 9 | 10 | jobs: 11 | run: 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.9", "3.10"] 17 | 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install system packages 28 | run: | 29 | sudo apt-get update 30 | sudo apt-get install -y ffmpeg libsndfile1 31 | 32 | - name: Upgrade pip 33 | run: | 34 | python -m pip install --upgrade pip 35 | 36 | - name: Install pip packages 37 | run: | 38 | pip install -r requirements_all.txt 39 | 40 | - name: Test with unittest 41 | run: | 42 | RIFFUSION_TEST_DEVICE=cpu python -m unittest test/*_test.py 43 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: 
-------------------------------------------------------------------------------- 1 | name: mypy type check 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | types: [opened, synchronize, reopened, ready_for_review] 9 | 10 | jobs: 11 | run: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | with: 16 | submodules: recursive 17 | 18 | - uses: jpetrucciani/mypy-check@master 19 | with: 20 | mypy_flags: '--config-file pyproject.toml' 21 | requirements_file: 'requirements_all.txt' 22 | -------------------------------------------------------------------------------- /.github/workflows/pip.yml: -------------------------------------------------------------------------------- 1 | name: pip install 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | types: [opened, synchronize, reopened, ready_for_review] 9 | 10 | jobs: 11 | run: 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.9", "3.10"] 17 | 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Editable Pip Install 28 | run: python -m pip install -e . 29 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: ruff lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | types: [opened, synchronize, reopened, ready_for_review] 9 | 10 | jobs: 11 | run: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: jpetrucciani/ruff-check@main 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # VSCode 10 | .vscode 11 | 12 | # Cog 13 | .cog/ 14 | 15 | # Random stuff I don't care about 16 | .graveyard/ 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # OSX cruft 40 | .DS_Store 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | -------------------------------------------------------------------------------- /CITATION: -------------------------------------------------------------------------------- 1 | @article{Forsgren_Martiros_2022, 2 | author = {Forsgren, Seth* and Martiros, Hayk*}, 3 | title = {{Riffusion - Stable diffusion for real-time music generation}}, 4 | url = {https://riffusion.com/about}, 5 | year = {2022} 6 | } 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 Hayk Martiros and Seth Forsgren 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 4 | associated documentation files (the "Software"), to deal in the Software without restriction, 5 | including without limitation the rights to use, copy, modify, merge, publish, distribute, 6 | sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is 7 | furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial 10 | portions of the Software. 11 | 12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT 13 | NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 14 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES 15 | OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :guitar: Riffusion (hobby) 2 | 3 | :no_entry: This project is no longer actively maintained. 4 | 5 | CI status 6 | Python 3.9 | 3.10 7 | MIT License 8 | 9 | Riffusion is a library for real-time music and audio generation with stable diffusion. 10 | 11 | Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/. 12 | 13 | This is the core repository for riffusion image and audio processing code. 14 | 15 | * Diffusion pipeline that performs prompt interpolation combined with image conditioning 16 | * Conversions between spectrogram images and audio clips 17 | * Command-line interface for common tasks 18 | * Interactive app using streamlit 19 | * Flask server to provide model inference via API 20 | * Various third party integrations 21 | 22 | Related repositories: 23 | * Web app: https://github.com/riffusion/riffusion-app 24 | * Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1 25 | 26 | ## Citation 27 | 28 | If you build on this work, please cite it as follows: 29 | 30 | ``` 31 | @article{Forsgren_Martiros_2022, 32 | author = {Forsgren, Seth* and Martiros, Hayk*}, 33 | title = {{Riffusion - Stable diffusion for real-time music generation}}, 34 | url = {https://riffusion.com/about}, 35 | year = {2022} 36 | } 37 | ``` 38 | 39 | ## Install 40 | 41 | Tested in CI with Python 3.9 and 3.10. 42 | 43 | It's highly recommended to set up a virtual Python environment with `conda` or `virtualenv`: 44 | ``` 45 | conda create --name riffusion python=3.9 46 | conda activate riffusion 47 | ``` 48 | 49 | Install Python dependencies: 50 | ``` 51 | python -m pip install -r requirements.txt 52 | ``` 53 | 54 | In order to use audio formats other than WAV, [ffmpeg](https://ffmpeg.org/download.html) is required. 55 | ``` 56 | sudo apt-get install ffmpeg # linux 57 | brew install ffmpeg # mac 58 | conda install -c conda-forge ffmpeg # conda 59 | ``` 60 | 61 | If torchaudio has no backend, you may need to install `libsndfile`. See [this issue](https://github.com/riffusion/riffusion/issues/12). 62 | 63 | If you have an issue, try upgrading [diffusers](https://github.com/huggingface/diffusers). Tested with 0.9 - 0.11. 64 | 65 | Guides: 66 | * [Simple Install Guide for Windows](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/) 67 | 68 | ## Backends 69 | 70 | ### CPU 71 | `cpu` is supported but is quite slow. 72 | 73 | ### CUDA 74 | `cuda` is the recommended and most performant backend. 75 | 76 | To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the 77 | [install guide](https://pytorch.org/get-started/locally/) or 78 | [stable wheels](https://download.pytorch.org/whl/torch_stable.html). 79 | 80 | To generate audio in real-time, you need a GPU that can run stable diffusion with approximately 50 81 | steps in under five seconds, such as a 3090 or A10G. 82 | 83 | Test availability with: 84 | 85 | ```python3 86 | import torch 87 | torch.cuda.is_available() 88 | ``` 89 | 90 | ### MPS 91 | The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU, 92 | particularly for audio processing. You may need to set 93 | `PYTORCH_ENABLE_MPS_FALLBACK=1`. 94 | 95 | In addition, this backend is not deterministic. 
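If you prefer to enable the fallback from Python rather than your shell, here is a minimal sketch (assuming the variable is read when torch first initializes MPS, so it is safest to set it before importing torch):

```python3
import os

# Assumption: set before importing torch so the MPS fallback is picked up
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
```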
96 | 97 | Test availability with: 98 | 99 | ```python3 100 | import torch 101 | torch.backends.mps.is_available() 102 | ``` 103 | 104 | ## Command-line interface 105 | 106 | Riffusion comes with a command line interface for performing common tasks. 107 | 108 | See available commands: 109 | ``` 110 | python -m riffusion.cli -h 111 | ``` 112 | 113 | Get help for a specific command: 114 | ``` 115 | python -m riffusion.cli image-to-audio -h 116 | ``` 117 | 118 | Execute: 119 | ``` 120 | python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav 121 | ``` 122 | 123 | ## Riffusion Playground 124 | 125 | Riffusion contains a [streamlit](https://streamlit.io/) app for interactive use and exploration. 126 | 127 | Run with: 128 | ``` 129 | python -m riffusion.streamlit.playground 130 | ``` 131 | 132 | And access at http://127.0.0.1:8501/ 133 | 134 | Riffusion Playground 135 | 136 | ## Run the model server 137 | 138 | Riffusion can be run as a flask server that provides inference via API. This server enables the [web app](https://github.com/riffusion/riffusion-app) to run locally. 139 | 140 | Run with: 141 | 142 | ``` 143 | python -m riffusion.server --host 127.0.0.1 --port 3013 144 | ``` 145 | 146 | You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format. 147 | 148 | Use the `--device` argument to specify the torch device to use. 149 | 150 | The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request. 151 | 152 | Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API): 153 | ``` 154 | { 155 | "alpha": 0.75, 156 | "num_inference_steps": 50, 157 | "seed_image_id": "og_beat", 158 | 159 | "start": { 160 | "prompt": "church bells on sunday", 161 | "seed": 42, 162 | "denoising": 0.75, 163 | "guidance": 7.0 164 | }, 165 | 166 | "end": { 167 | "prompt": "jazz with piano", 168 | "seed": 123, 169 | "denoising": 0.75, 170 | "guidance": 7.0 171 | } 172 | } 173 | ``` 174 | 175 | Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L54) for the API): 176 | ``` 177 | { 178 | "image": "< base64 encoded JPEG image >", 179 | "audio": "< base64 encoded MP3 clip >" 180 | } 181 | ``` 182 | 183 | ## Tests 184 | Tests live in the `test/` directory and are implemented with `unittest`. 185 | 186 | To run all tests: 187 | ``` 188 | python -m unittest test/*_test.py 189 | ``` 190 | 191 | To run a single test: 192 | ``` 193 | python -m unittest test.audio_to_image_test 194 | ``` 195 | 196 | To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`: 197 | ``` 198 | RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test 199 | ``` 200 | 201 | To run a single test case within a test: 202 | ``` 203 | python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo 204 | ``` 205 | 206 | To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`. Tests should pass with 207 | `cpu`, `cuda`, and `mps` backends. 208 | 209 | ## Development Guide 210 | Install additional packages for dev with `python -m pip install -r requirements_dev.txt`. 211 | 212 | * Linter: `ruff` 213 | * Formatter: `black` 214 | * Type checker: `mypy` 215 | 216 | These are configured in `pyproject.toml`. 217 | 218 | The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR. 219 | 220 | CI is run through GitHub Actions from `.github/workflows/ci.yml`. 
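As a companion to the model server section above, here is a minimal sketch of a client for the `/run_inference` endpoint. It assumes the server is running locally on port 3013 as shown earlier and uses the `requests` package, which is not one of this repository's dependencies:

```python3
import base64

import requests  # assumption: installed separately, not in requirements.txt

payload = {
    "alpha": 0.75,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference", json=payload, timeout=600)
response.raise_for_status()
output = response.json()

# The audio is a base64 encoded MP3 clip; strip a possible data-URI prefix before decoding
with open("output.mp3", "wb") as f:
    f.write(base64.b64decode(output["audio"].split(",")[-1]))
```

The `image` field can be decoded the same way to recover the spectrogram JPEG.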
221 | 222 | Contributions are welcome through pull requests. 223 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.3.1 2 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | 8 | # a list of ubuntu apt packages to install 9 | system_packages: 10 | - "ffmpeg" 11 | - "libsndfile1" 12 | 13 | # python version in the form '3.8' or '3.8.12' 14 | python_version: "3.9" 15 | 16 | # a list of packages in the format == 17 | python_packages: 18 | - "accelerate==0.15.0" 19 | - "argh==0.26.2" 20 | - "dacite==1.6.0" 21 | - "diffusers==0.10.2" 22 | - "flask_cors==3.0.10" 23 | - "flask==1.1.2" 24 | - "numpy==1.19.4" 25 | - "pillow==9.1.0" 26 | - "pydub==0.25.1" 27 | - "scipy==1.6.3" 28 | - "torch==1.13.0" 29 | - "torchaudio==0.13.0" 30 | - "transformers==4.25.1" 31 | 32 | # commands run after the environment is setup 33 | # run: 34 | # - "echo env is ready!" 35 | # - "echo another command if needed" 36 | 37 | # predict.py defines how predictions are run on your model 38 | predict: "integrations/cog_riffusion.py:RiffusionPredictor" 39 | -------------------------------------------------------------------------------- /integrations/README.md: -------------------------------------------------------------------------------- 1 | # Integrations 2 | 3 | This package contains integrations of Riffusion into third party apps and deployments. 4 | 5 | ## Baseten 6 | 7 | [Baseten](https://baseten.com) is a platform for building and deploying machine learning models. 8 | 9 | ## Replicate 10 | 11 | To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and 12 | download the model weights: 13 | 14 | cog run python -m integrations.cog_riffusion --download_weights 15 | 16 | Then you can run predictions: 17 | 18 | cog predict -i prompt_a="funky synth solo" 19 | 20 | You can also view the model on Replicate [here](https://replicate.com/riffusion/riffusion). Owners 21 | can push an updated version of the model like so: 22 | 23 | # download weights locally if you haven't already 24 | cog run python -m integrations.cog_riffusion --download_weights 25 | 26 | cog login 27 | cog push r8.im/riffusion/riffusion 28 | -------------------------------------------------------------------------------- /integrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/integrations/__init__.py -------------------------------------------------------------------------------- /integrations/baseten.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file can be used to build a Truss for deployment with Baseten. 3 | If used, it should be renamed to model.py and placed alongside the other 4 | files from /riffusion in the standard /model directory of the Truss. 
5 | 6 | For more on the Truss file format, see https://truss.baseten.co/ 7 | """ 8 | 9 | import typing as T 10 | 11 | import dacite 12 | import torch 13 | from huggingface_hub import snapshot_download 14 | 15 | from riffusion.datatypes import InferenceInput 16 | from riffusion.riffusion_pipeline import RiffusionPipeline 17 | from riffusion.server import compute_request 18 | 19 | 20 | class Model: 21 | """ 22 | Baseten Truss model class for riffusion. 23 | 24 | See: https://truss.baseten.co/reference/structure#model.py 25 | """ 26 | 27 | def __init__(self, **kwargs) -> None: 28 | self._data_dir = kwargs["data_dir"] 29 | self._config = kwargs["config"] 30 | self._pipeline = None 31 | self._vae = None 32 | 33 | self.checkpoint_name = "riffusion/riffusion-model-v1" 34 | 35 | # Download entire seed image folder from huggingface hub 36 | self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png") 37 | 38 | def load(self): 39 | """ 40 | Load the model. Guaranteed to be called before `predict`. 41 | """ 42 | self._pipeline = RiffusionPipeline.load_checkpoint( 43 | checkpoint=self.checkpoint_name, 44 | use_traced_unet=True, 45 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 46 | ) 47 | 48 | def preprocess(self, request: T.Dict) -> T.Dict: 49 | """ 50 | Incorporate pre-processing required by the model if desired here. 51 | 52 | These might be feature transformations that are tightly coupled to the model. 53 | """ 54 | return request 55 | 56 | def predict(self, request: T.Dict) -> T.Dict[str, T.List]: 57 | """ 58 | This is the main function that is called. 59 | """ 60 | assert self._pipeline is not None, "Model pipeline not loaded" 61 | 62 | try: 63 | inputs = dacite.from_dict(InferenceInput, request) 64 | except dacite.exceptions.WrongTypeError as exception: 65 | return str(exception), 400 66 | except dacite.exceptions.MissingValueError as exception: 67 | return str(exception), 400 68 | 69 | # NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4 70 | with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False): 71 | response = compute_request( 72 | inputs=inputs, 73 | pipeline=self._pipeline, 74 | seed_images_dir=self._seed_images_dir, 75 | ) 76 | 77 | return response 78 | 79 | def postprocess(self, request: T.Dict) -> T.Dict: 80 | """ 81 | Incorporate post-processing required by the model if desired here. 
82 | """ 83 | return request 84 | -------------------------------------------------------------------------------- /integrations/cog_riffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prediction interface for Cog ⚙️ 3 | https://github.com/replicate/cog/blob/main/docs/python.md 4 | """ 5 | 6 | import argparse 7 | import os 8 | import shutil 9 | import typing as T 10 | 11 | import numpy as np 12 | import PIL 13 | import torch 14 | from cog import BaseModel, BasePredictor, Input, Path 15 | 16 | from riffusion.datatypes import InferenceInput, PromptInput 17 | from riffusion.riffusion_pipeline import RiffusionPipeline 18 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter 19 | from riffusion.spectrogram_params import SpectrogramParams 20 | 21 | MODEL_ID = "riffusion/riffusion-model-v1" 22 | MODEL_CACHE = "riffusion-cache" 23 | 24 | # Where built-in seed images are stored 25 | SEED_IMAGES_DIR = Path("./seed_images") 26 | SEED_IMAGES = [val.split(".")[0] for val in os.listdir(SEED_IMAGES_DIR) if "png" in val] 27 | SEED_IMAGES.sort() 28 | 29 | 30 | class Output(BaseModel): 31 | """ 32 | Output class for riffusion predictions 33 | """ 34 | 35 | audio: Path 36 | spectrogram: Path 37 | error: T.Optional[str] = None 38 | 39 | 40 | class RiffusionPredictor(BasePredictor): 41 | """ 42 | Implementation of cog predictor object s.t. we can run riffusion predictions w/cog. 43 | 44 | See README & https://github.com/replicate/cog for details 45 | """ 46 | 47 | def setup(self, local_files_only=True): 48 | """ 49 | Loads the model onto GPU from local cache. 50 | """ 51 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 52 | 53 | self.model = RiffusionPipeline.load_checkpoint( 54 | checkpoint=MODEL_ID, 55 | use_traced_unet=True, 56 | device=self.device, 57 | local_files_only=local_files_only, 58 | cache_dir=MODEL_CACHE, 59 | ) 60 | 61 | def predict( 62 | self, 63 | prompt_a: str = Input(description="The prompt for your audio", default="funky synth solo"), 64 | denoising: float = Input( 65 | description="How much to transform input spectrogram", 66 | default=0.75, 67 | ge=0, 68 | le=1, 69 | ), 70 | prompt_b: str = Input( 71 | description="The second prompt to interpolate with the first," 72 | "leave blank if no interpolation", 73 | default=None, 74 | ), 75 | alpha: float = Input( 76 | description="Interpolation alpha if using two prompts." 77 | "A value of 0 uses prompt_a fully, a value of 1 uses prompt_b fully", 78 | default=0.5, 79 | ge=0, 80 | le=1, 81 | ), 82 | num_inference_steps: int = Input( 83 | description="Number of steps to run the diffusion model", default=50, ge=1 84 | ), 85 | seed_image_id: str = Input( 86 | description="Seed spectrogram to use", default="vibes", choices=SEED_IMAGES 87 | ), 88 | ) -> Output: 89 | """ 90 | Runs riffusion inference. 
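Returns an Output with file paths to the generated audio clip (WAV) and the spectrogram image (JPEG), or an Output carrying only an error message if the seed image ID is invalid.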
91 | """ 92 | # Load the seed image by ID 93 | init_image_path = Path(SEED_IMAGES_DIR, f"{seed_image_id}.png") 94 | if not init_image_path.is_file(): 95 | return Output(error=f"Invalid seed image: {seed_image_id}") 96 | init_image = PIL.Image.open(str(init_image_path)).convert("RGB") 97 | 98 | # fake max ints 99 | seed_a = np.random.randint(0, 2147483647) 100 | seed_b = np.random.randint(0, 2147483647) 101 | 102 | start = PromptInput(prompt=prompt_a, seed=seed_a, denoising=denoising) 103 | if not prompt_b: # no transition 104 | prompt_b = prompt_a 105 | alpha = 0 106 | end = PromptInput(prompt=prompt_b, seed=seed_b, denoising=denoising) 107 | riffusion_input = InferenceInput( 108 | start=start, 109 | end=end, 110 | alpha=alpha, 111 | num_inference_steps=num_inference_steps, 112 | seed_image_id=seed_image_id, 113 | ) 114 | 115 | # Execute the model to get the spectrogram image 116 | image = self.model.riffuse(riffusion_input, init_image=init_image, mask_image=None) 117 | 118 | # Reconstruct audio from the image 119 | params = SpectrogramParams() 120 | converter = SpectrogramImageConverter(params=params, device=self.device) 121 | segment = converter.audio_from_spectrogram_image(image) 122 | 123 | if not os.path.exists("out/"): 124 | os.mkdir("out") 125 | 126 | out_img_path = "out/spectrogram.jpg" 127 | image.save("out/spectrogram.jpg", exif=image.getexif()) 128 | 129 | out_wav_path = "out/gen_sound.wav" 130 | segment.export(out_wav_path, format="wav") 131 | 132 | return Output(audio=Path(out_wav_path), spectrogram=Path(out_img_path)) 133 | 134 | 135 | # TODO(hayk): Can we get rid of the below functions and incorporate into 136 | # RiffusionPipeline.load_checkpoint? 137 | 138 | 139 | def download_weights(): 140 | """ 141 | Clears local cache & downloads riffusion weights 142 | """ 143 | if os.path.exists(MODEL_CACHE): 144 | shutil.rmtree(MODEL_CACHE) 145 | os.makedirs(MODEL_CACHE) 146 | 147 | pred = RiffusionPredictor() 148 | pred.setup(local_files_only=False) 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument( 154 | "--download_weights", action="store_true", help="Download and cache weights" 155 | ) 156 | args = parser.parse_args() 157 | if args.download_weights: 158 | download_weights() 159 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | [tool.ruff] 5 | line-length = 100 6 | 7 | # Which rules to run 8 | select = [ 9 | # Pyflakes 10 | "F", 11 | # Pycodestyle 12 | "E", 13 | "W", 14 | # isort 15 | "I001" 16 | ] 17 | 18 | ignore = [] 19 | 20 | # Exclude a variety of commonly ignored directories. 21 | exclude = [ 22 | ".bzr", 23 | ".direnv", 24 | ".eggs", 25 | ".git", 26 | ".hg", 27 | ".mypy_cache", 28 | ".nox", 29 | ".pants.d", 30 | ".ruff_cache", 31 | ".svn", 32 | ".tox", 33 | ".venv", 34 | "__pypackages__", 35 | "_build", 36 | "buck-out", 37 | "build", 38 | "dist", 39 | "node_modules", 40 | "venv", 41 | ] 42 | per-file-ignores = {} 43 | 44 | # Allow unused variables when underscore-prefixed. 45 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 46 | 47 | # Assume Python 3.10. 48 | target-version = "py310" 49 | 50 | [tool.ruff.mccabe] 51 | # Unlike Flake8, default to a complexity level of 10. 
52 | max-complexity = 10 53 | 54 | [tool.mypy] 55 | python_version = "3.10" 56 | 57 | [[tool.mypy.overrides]] 58 | module = "argh.*" 59 | ignore_missing_imports = true 60 | 61 | [[tool.mypy.overrides]] 62 | module = "cog.*" 63 | ignore_missing_imports = true 64 | 65 | [[tool.mypy.overrides]] 66 | module = "diffusers.*" 67 | ignore_missing_imports = true 68 | 69 | [[tool.mypy.overrides]] 70 | module = "huggingface_hub.*" 71 | ignore_missing_imports = true 72 | 73 | [[tool.mypy.overrides]] 74 | module = "numpy.*" 75 | ignore_missing_imports = true 76 | 77 | [[tool.mypy.overrides]] 78 | module = "plotly.*" 79 | ignore_missing_imports = true 80 | 81 | [[tool.mypy.overrides]] 82 | module = "pydub.*" 83 | ignore_missing_imports = true 84 | 85 | [[tool.mypy.overrides]] 86 | module = "scipy.fft.*" 87 | ignore_missing_imports = true 88 | 89 | [[tool.mypy.overrides]] 90 | module = "scipy.io.*" 91 | ignore_missing_imports = true 92 | 93 | [[tool.mypy.overrides]] 94 | module = "torchaudio.*" 95 | ignore_missing_imports = true 96 | 97 | [[tool.mypy.overrides]] 98 | module = "transformers.*" 99 | ignore_missing_imports = true 100 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | argh 3 | dacite 4 | demucs 5 | diffusers==0.9.0 6 | flask 7 | flask_cors 8 | numpy 9 | pillow>=9.1.0 10 | plotly 11 | pydub 12 | pysoundfile 13 | scipy 14 | soundfile 15 | sox 16 | streamlit>=1.18.0 17 | torch 18 | torchaudio 19 | torchvision 20 | transformers 21 | -------------------------------------------------------------------------------- /requirements_all.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | -r requirements_dev.txt 3 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | ipdb 3 | mypy 4 | ruff 5 | types-Flask-Cors 6 | types-Pillow 7 | types-requests 8 | types-setuptools 9 | types-tqdm 10 | -------------------------------------------------------------------------------- /riffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/__init__.py -------------------------------------------------------------------------------- /riffusion/audio_splitter.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import subprocess 3 | import tempfile 4 | import typing as T 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import pydub 9 | import torch 10 | import torchaudio 11 | from torchaudio.transforms import Fade 12 | 13 | from riffusion.util import audio_util 14 | 15 | 16 | def split_audio( 17 | segment: pydub.AudioSegment, 18 | model_name: str = "htdemucs_6s", 19 | extension: str = "wav", 20 | jobs: int = 4, 21 | device: str = "cuda", 22 | ) -> T.Dict[str, pydub.AudioSegment]: 23 | """ 24 | Split audio into stems using demucs. 
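Returns a dictionary of pydub segments keyed by stem name, as written out by the selected demucs model.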
25 | """ 26 | tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_")) 27 | 28 | # Save the audio to a temporary file 29 | audio_path = tmp_dir / "audio.mp3" 30 | segment.export(audio_path, format="mp3") 31 | 32 | # Assemble command 33 | command = [ 34 | "demucs", 35 | str(audio_path), 36 | "--name", 37 | model_name, 38 | "--out", 39 | str(tmp_dir), 40 | "--jobs", 41 | str(jobs), 42 | "--device", 43 | device if device != "mps" else "cpu", 44 | ] 45 | print(" ".join(command)) 46 | 47 | if extension == "mp3": 48 | command.append("--mp3") 49 | 50 | # Run demucs 51 | subprocess.run( 52 | command, 53 | check=True, 54 | ) 55 | 56 | # Load the stems 57 | stems = {} 58 | for stem_path in tmp_dir.glob(f"{model_name}/audio/*.{extension}"): 59 | stem = pydub.AudioSegment.from_file(stem_path) 60 | stems[stem_path.stem] = stem 61 | 62 | # Delete tmp dir 63 | shutil.rmtree(tmp_dir) 64 | 65 | return stems 66 | 67 | 68 | class AudioSplitter: 69 | """ 70 | Split audio into instrument stems like {drums, bass, vocals, etc.} 71 | 72 | NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer 73 | model in the demucs repo. See the function above. Probably just delete this. 74 | 75 | See: 76 | https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html 77 | """ 78 | 79 | def __init__( 80 | self, 81 | segment_length_s: float = 10.0, 82 | overlap_s: float = 0.1, 83 | device: str = "cuda", 84 | ): 85 | self.segment_length_s = segment_length_s 86 | self.overlap_s = overlap_s 87 | self.device = device 88 | 89 | self.model = self.load_model().to(device) 90 | 91 | @staticmethod 92 | def load_model(model_path: str = "models/hdemucs_high_trained.pt") -> torchaudio.models.HDemucs: 93 | """ 94 | Load the trained HDEMUCS pytorch model. 95 | """ 96 | # NOTE(hayk): The sources are baked into the pretrained model and can't be changed 97 | model = torchaudio.models.hdemucs_high(sources=["drums", "bass", "other", "vocals"]) 98 | 99 | path = torchaudio.utils.download_asset(model_path) 100 | state_dict = torch.load(path) 101 | model.load_state_dict(state_dict) 102 | model.eval() 103 | 104 | return model 105 | 106 | def split(self, audio: pydub.AudioSegment) -> T.Dict[str, pydub.AudioSegment]: 107 | """ 108 | Split the given audio segment into instrument stems. 
109 | """ 110 | if audio.channels == 1: 111 | audio_stereo = audio.set_channels(2) 112 | elif audio.channels == 2: 113 | audio_stereo = audio 114 | else: 115 | raise ValueError(f"Audio must be stereo, but got {audio.channels} channels") 116 | 117 | # Get as (samples, channels) float numpy array 118 | waveform_np = np.array(audio_stereo.get_array_of_samples()) 119 | waveform_np = waveform_np.reshape(-1, audio_stereo.channels) 120 | waveform_np_float = waveform_np.astype(np.float32) 121 | 122 | # To torch and channels-first 123 | waveform = torch.from_numpy(waveform_np_float).to(self.device) 124 | waveform = waveform.transpose(1, 0) 125 | 126 | # Normalize 127 | ref = waveform.mean(0) 128 | waveform = (waveform - ref.mean()) / ref.std() 129 | 130 | # Split 131 | sources = self.separate_sources( 132 | waveform[None], 133 | sample_rate=audio.frame_rate, 134 | )[0] 135 | 136 | # De-normalize 137 | sources = sources * ref.std() + ref.mean() 138 | 139 | # To numpy 140 | sources_np = sources.cpu().numpy().astype(waveform_np.dtype) 141 | 142 | # Convert to pydub 143 | stem_segments = [ 144 | audio_util.audio_from_waveform(waveform, audio.frame_rate) for waveform in sources_np 145 | ] 146 | 147 | # Convert back to mono if necessary 148 | if audio.channels == 1: 149 | stem_segments = [stem.set_channels(1) for stem in stem_segments] 150 | 151 | return dict(zip(self.model.sources, stem_segments)) 152 | 153 | def separate_sources( 154 | self, 155 | waveform: torch.Tensor, 156 | sample_rate: int = 44100, 157 | ): 158 | """ 159 | Apply model to a given waveform in chunks. Use fade and overlap to smooth the edges. 160 | """ 161 | batch, channels, length = waveform.shape 162 | 163 | chunk_len = int(sample_rate * self.segment_length_s * (1 + self.overlap_s)) 164 | start = 0 165 | end = chunk_len 166 | overlap_frames = self.overlap_s * sample_rate 167 | fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear") 168 | 169 | final = torch.zeros(batch, len(self.model.sources), channels, length, device=self.device) 170 | 171 | # TODO(hayk): Improve this code, which came from the torchaudio docs 172 | while start < length - overlap_frames: 173 | chunk = waveform[:, :, start:end] 174 | with torch.no_grad(): 175 | out = self.model.forward(chunk) 176 | out = fade(out) 177 | final[:, :, :, start:end] += out 178 | if start == 0: 179 | fade.fade_in_len = int(overlap_frames) 180 | start += int(chunk_len - overlap_frames) 181 | else: 182 | start += chunk_len 183 | end += chunk_len 184 | if end >= length: 185 | fade.fade_out_len = 0 186 | 187 | return final 188 | -------------------------------------------------------------------------------- /riffusion/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command line tools for riffusion. 
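Run `python -m riffusion.cli -h` to list the available commands, or add `-h` after a command name to see its options.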
3 | """ 4 | 5 | import random 6 | import typing as T 7 | from multiprocessing.pool import ThreadPool 8 | from pathlib import Path 9 | 10 | import argh 11 | import numpy as np 12 | import pydub 13 | import tqdm 14 | from PIL import Image 15 | 16 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter 17 | from riffusion.spectrogram_params import SpectrogramParams 18 | from riffusion.util import image_util 19 | 20 | 21 | @argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image") 22 | @argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image") 23 | def audio_to_image( 24 | *, 25 | audio: str, 26 | image: str, 27 | step_size_ms: int = 10, 28 | num_frequencies: int = 512, 29 | min_frequency: int = 0, 30 | max_frequency: int = 10000, 31 | window_duration_ms: int = 100, 32 | padded_duration_ms: int = 400, 33 | power_for_image: float = 0.25, 34 | stereo: bool = False, 35 | device: str = "cuda", 36 | ): 37 | """ 38 | Compute a spectrogram image from a waveform. 39 | """ 40 | segment = pydub.AudioSegment.from_file(audio) 41 | 42 | params = SpectrogramParams( 43 | sample_rate=segment.frame_rate, 44 | stereo=stereo, 45 | window_duration_ms=window_duration_ms, 46 | padded_duration_ms=padded_duration_ms, 47 | step_size_ms=step_size_ms, 48 | min_frequency=min_frequency, 49 | max_frequency=max_frequency, 50 | num_frequencies=num_frequencies, 51 | power_for_image=power_for_image, 52 | ) 53 | 54 | converter = SpectrogramImageConverter(params=params, device=device) 55 | 56 | pil_image = converter.spectrogram_image_from_audio(segment) 57 | 58 | pil_image.save(image, exif=pil_image.getexif(), format="PNG") 59 | print(f"Wrote {image}") 60 | 61 | 62 | def print_exif(*, image: str) -> None: 63 | """ 64 | Print the params of a spectrogram image as saved in the exif data. 65 | """ 66 | pil_image = Image.open(image) 67 | exif_data = image_util.exif_from_image(pil_image) 68 | 69 | for name, value in exif_data.items(): 70 | print(f"{name:<20} = {value:>15}") 71 | 72 | 73 | def image_to_audio(*, image: str, audio: str, device: str = "cuda"): 74 | """ 75 | Reconstruct an audio clip from a spectrogram image. 76 | """ 77 | pil_image = Image.open(image) 78 | 79 | # Get parameters from image exif 80 | img_exif = pil_image.getexif() 81 | assert img_exif is not None 82 | 83 | try: 84 | params = SpectrogramParams.from_exif(exif=img_exif) 85 | except (KeyError, AttributeError): 86 | print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.") 87 | params = SpectrogramParams() 88 | 89 | converter = SpectrogramImageConverter(params=params, device=device) 90 | segment = converter.audio_from_spectrogram_image(pil_image) 91 | 92 | extension = Path(audio).suffix[1:] 93 | segment.export(audio, format=extension) 94 | 95 | print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)") 96 | 97 | 98 | def sample_clips( 99 | *, 100 | audio: str, 101 | output_dir: str, 102 | num_clips: int = 1, 103 | duration_ms: int = 5120, 104 | mono: bool = False, 105 | extension: str = "wav", 106 | seed: int = -1, 107 | ): 108 | """ 109 | Slice an audio file into clips of the given duration. 
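Example invocation (illustrative paths): `python -m riffusion.cli sample-clips --audio song.mp3 --output-dir clips --num-clips 3`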
110 | """ 111 | if seed >= 0: 112 | np.random.seed(seed) 113 | 114 | segment = pydub.AudioSegment.from_file(audio) 115 | 116 | if mono: 117 | segment = segment.set_channels(1) 118 | 119 | output_dir_path = Path(output_dir) 120 | if not output_dir_path.exists(): 121 | output_dir_path.mkdir(parents=True) 122 | 123 | segment_duration_ms = int(segment.duration_seconds * 1000) 124 | for i in range(num_clips): 125 | clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms) 126 | clip = segment[clip_start_ms : clip_start_ms + duration_ms] 127 | 128 | clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}" 129 | clip_path = output_dir_path / clip_name 130 | clip.export(clip_path, format=extension) 131 | print(f"Wrote {clip_path}") 132 | 133 | 134 | def audio_to_images_batch( 135 | *, 136 | audio_dir: str, 137 | output_dir: str, 138 | image_extension: str = "jpg", 139 | step_size_ms: int = 10, 140 | num_frequencies: int = 512, 141 | min_frequency: int = 0, 142 | max_frequency: int = 10000, 143 | power_for_image: float = 0.25, 144 | mono: bool = False, 145 | sample_rate: int = 44100, 146 | device: str = "cuda", 147 | num_threads: T.Optional[int] = None, 148 | limit: int = -1, 149 | ): 150 | """ 151 | Process audio clips into spectrograms in batch, multi-threaded. 152 | """ 153 | audio_paths = list(Path(audio_dir).glob("*")) 154 | audio_paths.sort() 155 | 156 | if limit > 0: 157 | audio_paths = audio_paths[:limit] 158 | 159 | output_path = Path(output_dir) 160 | output_path.mkdir(parents=True, exist_ok=True) 161 | 162 | params = SpectrogramParams( 163 | step_size_ms=step_size_ms, 164 | num_frequencies=num_frequencies, 165 | min_frequency=min_frequency, 166 | max_frequency=max_frequency, 167 | power_for_image=power_for_image, 168 | stereo=not mono, 169 | sample_rate=sample_rate, 170 | ) 171 | 172 | converter = SpectrogramImageConverter(params=params, device=device) 173 | 174 | def process_one(audio_path: Path) -> None: 175 | # Load 176 | try: 177 | segment = pydub.AudioSegment.from_file(str(audio_path)) 178 | except Exception: 179 | return 180 | 181 | # TODO(hayk): Sanity checks on clip 182 | 183 | if mono and segment.channels != 1: 184 | segment = segment.set_channels(1) 185 | elif not mono and segment.channels != 2: 186 | segment = segment.set_channels(2) 187 | 188 | # Frame rate 189 | if segment.frame_rate != params.sample_rate: 190 | segment = segment.set_frame_rate(params.sample_rate) 191 | 192 | # Convert 193 | image = converter.spectrogram_image_from_audio(segment) 194 | 195 | # Save 196 | image_path = output_path / f"{audio_path.stem}.{image_extension}" 197 | image_format = {"jpg": "JPEG", "jpeg": "JPEG", "png": "PNG"}[image_extension] 198 | image.save(image_path, exif=image.getexif(), format=image_format) 199 | 200 | # Create thread pool 201 | pool = ThreadPool(processes=num_threads) 202 | with tqdm.tqdm(total=len(audio_paths)) as pbar: 203 | for i, _ in enumerate(pool.imap_unordered(process_one, audio_paths)): 204 | pbar.update() 205 | 206 | 207 | def sample_clips_batch( 208 | *, 209 | audio_dir: str, 210 | output_dir: str, 211 | num_clips_per_file: int = 1, 212 | duration_ms: int = 5120, 213 | mono: bool = False, 214 | extension: str = "mp3", 215 | num_threads: T.Optional[int] = None, 216 | glob: str = "*", 217 | limit: int = -1, 218 | seed: int = -1, 219 | ): 220 | """ 221 | Sample short clips from a directory of audio files, multi-threaded. 
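Clips are written to `output_dir` with names that encode the source file stem, clip index, start offset, and duration.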
222 | """ 223 | audio_paths = list(Path(audio_dir).glob(glob)) 224 | audio_paths.sort() 225 | 226 | # Exclude json 227 | audio_paths = [p for p in audio_paths if p.suffix != ".json"] 228 | 229 | if limit > 0: 230 | audio_paths = audio_paths[:limit] 231 | 232 | output_path = Path(output_dir) 233 | output_path.mkdir(parents=True, exist_ok=True) 234 | 235 | if seed >= 0: 236 | random.seed(seed) 237 | 238 | def process_one(audio_path: Path) -> None: 239 | try: 240 | segment = pydub.AudioSegment.from_file(str(audio_path)) 241 | except Exception: 242 | return 243 | 244 | if mono: 245 | segment = segment.set_channels(1) 246 | 247 | segment_duration_ms = int(segment.duration_seconds * 1000) 248 | for i in range(num_clips_per_file): 249 | try: 250 | clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms) 251 | except ValueError: 252 | continue 253 | 254 | clip = segment[clip_start_ms : clip_start_ms + duration_ms] 255 | 256 | clip_name = ( 257 | f"{audio_path.stem}_{i}_" 258 | f"start_{clip_start_ms}_ms_dur_{duration_ms}_ms.{extension}" 259 | ) 260 | clip.export(output_path / clip_name, format=extension) 261 | 262 | pool = ThreadPool(processes=num_threads) 263 | with tqdm.tqdm(total=len(audio_paths)) as pbar: 264 | for result in pool.imap_unordered(process_one, audio_paths): 265 | pbar.update() 266 | 267 | 268 | if __name__ == "__main__": 269 | argh.dispatch_commands( 270 | [ 271 | audio_to_image, 272 | image_to_audio, 273 | sample_clips, 274 | print_exif, 275 | audio_to_images_batch, 276 | sample_clips_batch, 277 | ] 278 | ) 279 | -------------------------------------------------------------------------------- /riffusion/datatypes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data model for the riffusion API. 3 | """ 4 | from __future__ import annotations 5 | 6 | import typing as T 7 | from dataclasses import dataclass 8 | 9 | 10 | @dataclass(frozen=True) 11 | class PromptInput: 12 | """ 13 | Parameters for one end of interpolation. 14 | """ 15 | 16 | # Text prompt fed into a CLIP model 17 | prompt: str 18 | 19 | # Random seed for denoising 20 | seed: int 21 | 22 | # Negative prompt to avoid (optional) 23 | negative_prompt: T.Optional[str] = None 24 | 25 | # Denoising strength 26 | denoising: float = 0.75 27 | 28 | # Classifier-free guidance strength 29 | guidance: float = 7.0 30 | 31 | 32 | @dataclass(frozen=True) 33 | class InferenceInput: 34 | """ 35 | Parameters for a single run of the riffusion model, interpolating between 36 | a start and end set of PromptInputs. This is the API required for a request 37 | to the model server. 38 | """ 39 | 40 | # Start point of interpolation 41 | start: PromptInput 42 | 43 | # End point of interpolation 44 | end: PromptInput 45 | 46 | # Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1 47 | # uses end fully. 48 | alpha: float 49 | 50 | # Number of inner loops of the diffusion model 51 | num_inference_steps: int = 50 52 | 53 | # Which seed image to use 54 | seed_image_id: str = "og_beat" 55 | 56 | # ID of mask image to use 57 | mask_image_id: T.Optional[str] = None 58 | 59 | 60 | @dataclass(frozen=True) 61 | class InferenceOutput: 62 | """ 63 | Response from the model inference server. 
64 | """ 65 | 66 | # base64 encoded spectrogram image as a JPEG 67 | image: str 68 | 69 | # base64 encoded audio clip as an MP3 70 | audio: str 71 | 72 | # The duration of the audio clip 73 | duration_s: float 74 | -------------------------------------------------------------------------------- /riffusion/external/README.md: -------------------------------------------------------------------------------- 1 | # external 2 | 3 | This package contains scripts and tools from external sources. 4 | -------------------------------------------------------------------------------- /riffusion/external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/external/__init__.py -------------------------------------------------------------------------------- /riffusion/external/prompt_weighting.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is taken from the diffusers community pipeline: 3 | 4 | https://github.com/huggingface/diffusers/blob/f242eba4fdc5b76dc40d3a9c01ba49b2c74b9796/examples/community/lpw_stable_diffusion.py 5 | 6 | License: Apache 2.0 7 | """ 8 | # ruff: noqa 9 | # mypy: ignore-errors 10 | 11 | import logging 12 | import re 13 | import typing as T 14 | 15 | import torch 16 | 17 | from diffusers import StableDiffusionPipeline 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | re_attention = re.compile( 24 | r""" 25 | \\\(| 26 | \\\)| 27 | \\\[| 28 | \\]| 29 | \\\\| 30 | \\| 31 | \(| 32 | \[| 33 | :([+-]?[.\d]+)\)| 34 | \)| 35 | ]| 36 | [^\\()\[\]:]+| 37 | : 38 | """, 39 | re.X, 40 | ) 41 | 42 | 43 | def parse_prompt_attention(text): 44 | """ 45 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
46 | Accepted tokens are: 47 | (abc) - increases attention to abc by a multiplier of 1.1 48 | (abc:3.12) - increases attention to abc by a multiplier of 3.12 49 | [abc] - decreases attention to abc by a multiplier of 1.1 50 | \( - literal character '(' 51 | \[ - literal character '[' 52 | \) - literal character ')' 53 | \] - literal character ']' 54 | \\ - literal character '\' 55 | anything else - just text 56 | >>> parse_prompt_attention('normal text') 57 | [['normal text', 1.0]] 58 | >>> parse_prompt_attention('an (important) word') 59 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]] 60 | >>> parse_prompt_attention('(unbalanced') 61 | [['unbalanced', 1.1]] 62 | >>> parse_prompt_attention('\(literal\]') 63 | [['(literal]', 1.0]] 64 | >>> parse_prompt_attention('(unnecessary)(parens)') 65 | [['unnecessaryparens', 1.1]] 66 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') 67 | [['a ', 1.0], 68 | ['house', 1.5730000000000004], 69 | [' ', 1.1], 70 | ['on', 1.0], 71 | [' a ', 1.1], 72 | ['hill', 0.55], 73 | [', sun, ', 1.1], 74 | ['sky', 1.4641000000000006], 75 | ['.', 1.1]] 76 | """ 77 | 78 | res = [] 79 | round_brackets = [] 80 | square_brackets = [] 81 | 82 | round_bracket_multiplier = 1.1 83 | square_bracket_multiplier = 1 / 1.1 84 | 85 | def multiply_range(start_position, multiplier): 86 | for p in range(start_position, len(res)): 87 | res[p][1] *= multiplier 88 | 89 | for m in re_attention.finditer(text): 90 | text = m.group(0) 91 | weight = m.group(1) 92 | 93 | if text.startswith("\\"): 94 | res.append([text[1:], 1.0]) 95 | elif text == "(": 96 | round_brackets.append(len(res)) 97 | elif text == "[": 98 | square_brackets.append(len(res)) 99 | elif weight is not None and len(round_brackets) > 0: 100 | multiply_range(round_brackets.pop(), float(weight)) 101 | elif text == ")" and len(round_brackets) > 0: 102 | multiply_range(round_brackets.pop(), round_bracket_multiplier) 103 | elif text == "]" and len(square_brackets) > 0: 104 | multiply_range(square_brackets.pop(), square_bracket_multiplier) 105 | else: 106 | res.append([text, 1.0]) 107 | 108 | for pos in round_brackets: 109 | multiply_range(pos, round_bracket_multiplier) 110 | 111 | for pos in square_brackets: 112 | multiply_range(pos, square_bracket_multiplier) 113 | 114 | if len(res) == 0: 115 | res = [["", 1.0]] 116 | 117 | # merge runs of identical weights 118 | i = 0 119 | while i + 1 < len(res): 120 | if res[i][1] == res[i + 1][1]: 121 | res[i][0] += res[i + 1][0] 122 | res.pop(i + 1) 123 | else: 124 | i += 1 125 | 126 | return res 127 | 128 | 129 | def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int): 130 | r""" 131 | Tokenize a list of prompts and return its tokens with weights of each token. 132 | No padding, starting or ending token is included. 
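Prompts longer than `max_length` tokens are truncated and a warning is logged.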
133 | """ 134 | tokens = [] 135 | weights = [] 136 | truncated = False 137 | for text in prompt: 138 | texts_and_weights = parse_prompt_attention(text) 139 | text_token = [] 140 | text_weight = [] 141 | for word, weight in texts_and_weights: 142 | # tokenize and discard the starting and the ending token 143 | token = pipe.tokenizer(word).input_ids[1:-1] 144 | text_token += token 145 | # copy the weight by length of token 146 | text_weight += [weight] * len(token) 147 | # stop if the text is too long (longer than truncation limit) 148 | if len(text_token) > max_length: 149 | truncated = True 150 | break 151 | # truncate 152 | if len(text_token) > max_length: 153 | truncated = True 154 | text_token = text_token[:max_length] 155 | text_weight = text_weight[:max_length] 156 | tokens.append(text_token) 157 | weights.append(text_weight) 158 | if truncated: 159 | logger.warning( 160 | "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples" 161 | ) 162 | return tokens, weights 163 | 164 | 165 | def pad_tokens_and_weights( 166 | tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77 167 | ): 168 | r""" 169 | Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. 170 | """ 171 | max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) 172 | weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length 173 | for i in range(len(tokens)): 174 | tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) 175 | if no_boseos_middle: 176 | weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) 177 | else: 178 | w = [] 179 | if len(weights[i]) == 0: 180 | w = [1.0] * weights_length 181 | else: 182 | for j in range(max_embeddings_multiples): 183 | w.append(1.0) # weight for starting token in this chunk 184 | w += weights[i][ 185 | j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2)) 186 | ] 187 | w.append(1.0) # weight for ending token in this chunk 188 | w += [1.0] * (weights_length - len(w)) 189 | weights[i] = w[:] 190 | 191 | return tokens, weights 192 | 193 | 194 | def get_unweighted_text_embeddings( 195 | pipe: StableDiffusionPipeline, 196 | text_input: torch.Tensor, 197 | chunk_length: int, 198 | no_boseos_middle: T.Optional[bool] = True, 199 | ) -> torch.FloatTensor: 200 | """ 201 | When the length of tokens is a multiple of the capacity of the text encoder, 202 | it should be split into chunks and sent to the text encoder individually. 
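Each chunk is re-wrapped with the starting and ending tokens before encoding, and the duplicated boundary embeddings are optionally dropped via `no_boseos_middle`.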
203 | """ 204 | max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) 205 | if max_embeddings_multiples > 1: 206 | text_embeddings = [] 207 | for i in range(max_embeddings_multiples): 208 | # extract the i-th chunk 209 | text_input_chunk = text_input[ 210 | :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 211 | ].clone() 212 | 213 | # cover the head and the tail by the starting and the ending tokens 214 | text_input_chunk[:, 0] = text_input[0, 0] 215 | text_input_chunk[:, -1] = text_input[0, -1] 216 | text_embedding = pipe.text_encoder(text_input_chunk)[0] 217 | 218 | if no_boseos_middle: 219 | if i == 0: 220 | # discard the ending token 221 | text_embedding = text_embedding[:, :-1] 222 | elif i == max_embeddings_multiples - 1: 223 | # discard the starting token 224 | text_embedding = text_embedding[:, 1:] 225 | else: 226 | # discard both starting and ending tokens 227 | text_embedding = text_embedding[:, 1:-1] 228 | 229 | text_embeddings.append(text_embedding) 230 | text_embeddings = torch.concat(text_embeddings, axis=1) 231 | else: 232 | text_embeddings = pipe.text_encoder(text_input)[0] 233 | return text_embeddings 234 | 235 | 236 | def get_weighted_text_embeddings( 237 | pipe: StableDiffusionPipeline, 238 | prompt: T.Union[str, T.List[str]], 239 | uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None, 240 | max_embeddings_multiples: T.Optional[int] = 3, 241 | no_boseos_middle: T.Optional[bool] = False, 242 | skip_parsing: T.Optional[bool] = False, 243 | skip_weighting: T.Optional[bool] = False, 244 | **kwargs, 245 | ) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]: 246 | r""" 247 | Prompts can be assigned with local weights using brackets. For example, 248 | prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', 249 | and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. 250 | Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. 251 | Args: 252 | pipe (`StableDiffusionPipeline`): 253 | Pipe to provide access to the tokenizer and the text encoder. 254 | prompt (`str` or `T.List[str]`): 255 | The prompt or prompts to guide the image generation. 256 | uncond_prompt (`str` or `T.List[str]`): 257 | The unconditional prompt or prompts for guide the image generation. If unconditional prompt 258 | is provided, the embeddings of prompt and uncond_prompt are concatenated. 259 | max_embeddings_multiples (`int`, *optional*, defaults to `3`): 260 | The max multiple length of prompt embeddings compared to the max output length of text encoder. 261 | no_boseos_middle (`bool`, *optional*, defaults to `False`): 262 | If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and 263 | ending token in each of the chunk in the middle. 264 | skip_parsing (`bool`, *optional*, defaults to `False`): 265 | Skip the parsing of brackets. 266 | skip_weighting (`bool`, *optional*, defaults to `False`): 267 | Skip the weighting. When the parsing is skipped, it is forced True. 
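Returns:
    A tuple of the weighted embeddings for `prompt` and, if `uncond_prompt` is provided, the
    embeddings for the unconditional prompt (otherwise `None`).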
268 | """ 269 | max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 270 | if isinstance(prompt, str): 271 | prompt = [prompt] 272 | 273 | if not skip_parsing: 274 | prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) 275 | 276 | if uncond_prompt is not None: 277 | if isinstance(uncond_prompt, str): 278 | uncond_prompt = [uncond_prompt] 279 | uncond_tokens, uncond_weights = get_prompts_with_weights( 280 | pipe, uncond_prompt, max_length - 2 281 | ) 282 | else: 283 | prompt_tokens = [ 284 | token[1:-1] 285 | for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids 286 | ] 287 | prompt_weights = [[1.0] * len(token) for token in prompt_tokens] 288 | if uncond_prompt is not None: 289 | if isinstance(uncond_prompt, str): 290 | uncond_prompt = [uncond_prompt] 291 | uncond_tokens = [ 292 | token[1:-1] 293 | for token in pipe.tokenizer( 294 | uncond_prompt, max_length=max_length, truncation=True 295 | ).input_ids 296 | ] 297 | uncond_weights = [[1.0] * len(token) for token in uncond_tokens] 298 | 299 | # round up the longest length of tokens to a multiple of (model_max_length - 2) 300 | max_length = max([len(token) for token in prompt_tokens]) 301 | if uncond_prompt is not None: 302 | max_length = max(max_length, max([len(token) for token in uncond_tokens])) 303 | 304 | max_embeddings_multiples = min( 305 | max_embeddings_multiples, 306 | (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, 307 | ) 308 | max_embeddings_multiples = max(1, max_embeddings_multiples) 309 | max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 310 | 311 | # pad the length of tokens and weights 312 | bos = pipe.tokenizer.bos_token_id 313 | eos = pipe.tokenizer.eos_token_id 314 | prompt_tokens, prompt_weights = pad_tokens_and_weights( 315 | prompt_tokens, 316 | prompt_weights, 317 | max_length, 318 | bos, 319 | eos, 320 | no_boseos_middle=no_boseos_middle, 321 | chunk_length=pipe.tokenizer.model_max_length, 322 | ) 323 | prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) 324 | if uncond_prompt is not None: 325 | uncond_tokens, uncond_weights = pad_tokens_and_weights( 326 | uncond_tokens, 327 | uncond_weights, 328 | max_length, 329 | bos, 330 | eos, 331 | no_boseos_middle=no_boseos_middle, 332 | chunk_length=pipe.tokenizer.model_max_length, 333 | ) 334 | uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) 335 | 336 | # get the embeddings 337 | text_embeddings = get_unweighted_text_embeddings( 338 | pipe, 339 | prompt_tokens, 340 | pipe.tokenizer.model_max_length, 341 | no_boseos_middle=no_boseos_middle, 342 | ) 343 | prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) 344 | if uncond_prompt is not None: 345 | uncond_embeddings = get_unweighted_text_embeddings( 346 | pipe, 347 | uncond_tokens, 348 | pipe.tokenizer.model_max_length, 349 | no_boseos_middle=no_boseos_middle, 350 | ) 351 | uncond_weights = torch.tensor( 352 | uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device 353 | ) 354 | 355 | # assign weights to the prompts and normalize in the sense of mean 356 | # TODO: should we normalize by chunk or in a whole (current implementation)? 
357 | if (not skip_parsing) and (not skip_weighting): 358 | previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) 359 | text_embeddings *= prompt_weights.unsqueeze(-1) 360 | current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) 361 | text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) 362 | if uncond_prompt is not None: 363 | previous_mean = ( 364 | uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) 365 | ) 366 | uncond_embeddings *= uncond_weights.unsqueeze(-1) 367 | current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) 368 | uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) 369 | 370 | if uncond_prompt is not None: 371 | return text_embeddings, uncond_embeddings 372 | return text_embeddings, None 373 | -------------------------------------------------------------------------------- /riffusion/py.typed: -------------------------------------------------------------------------------- 1 | # https://peps.python.org/pep-0561/ 2 | -------------------------------------------------------------------------------- /riffusion/riffusion_pipeline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Riffusion inference pipeline. 3 | """ 4 | from __future__ import annotations 5 | 6 | import dataclasses 7 | import functools 8 | import inspect 9 | import typing as T 10 | 11 | import numpy as np 12 | import torch 13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 14 | from diffusers.pipeline_utils import DiffusionPipeline 15 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 16 | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler 17 | from diffusers.utils import logging 18 | from huggingface_hub import hf_hub_download 19 | from PIL import Image 20 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 21 | 22 | from riffusion.datatypes import InferenceInput 23 | from riffusion.external.prompt_weighting import get_weighted_text_embeddings 24 | from riffusion.util import torch_util 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | 29 | class RiffusionPipeline(DiffusionPipeline): 30 | """ 31 | Diffusers pipeline for doing a controlled img2img interpolation for audio generation. 32 | 33 | # TODO(hayk): Document more 34 | 35 | Part of this code was adapted from the non-img2img interpolation pipeline at: 36 | 37 | https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py 38 | 39 | Check the documentation for DiffusionPipeline for full information. 
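    Example (sketch, assuming an InferenceInput `inputs` and a PIL seed image `init_image`):

        pipeline = RiffusionPipeline.load_checkpoint("riffusion/riffusion-model-v1")
        image = pipeline.riffuse(inputs, init_image=init_image)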
40 | """ 41 | 42 | def __init__( 43 | self, 44 | vae: AutoencoderKL, 45 | text_encoder: CLIPTextModel, 46 | tokenizer: CLIPTokenizer, 47 | unet: UNet2DConditionModel, 48 | scheduler: T.Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], 49 | safety_checker: StableDiffusionSafetyChecker, 50 | feature_extractor: CLIPFeatureExtractor, 51 | ): 52 | super().__init__() 53 | self.register_modules( 54 | vae=vae, 55 | text_encoder=text_encoder, 56 | tokenizer=tokenizer, 57 | unet=unet, 58 | scheduler=scheduler, 59 | safety_checker=safety_checker, 60 | feature_extractor=feature_extractor, 61 | ) 62 | 63 | @classmethod 64 | def load_checkpoint( 65 | cls, 66 | checkpoint: str, 67 | use_traced_unet: bool = True, 68 | channels_last: bool = False, 69 | dtype: torch.dtype = torch.float16, 70 | device: str = "cuda", 71 | local_files_only: bool = False, 72 | low_cpu_mem_usage: bool = False, 73 | cache_dir: T.Optional[str] = None, 74 | ) -> RiffusionPipeline: 75 | """ 76 | Load the riffusion model pipeline. 77 | 78 | Args: 79 | checkpoint: Model checkpoint on disk in diffusers format 80 | use_traced_unet: Whether to use the traced unet for speedups 81 | device: Device to load the model on 82 | channels_last: Whether to use channels_last memory format 83 | local_files_only: Don't download, only use local files 84 | low_cpu_mem_usage: Attempt to use less memory on CPU 85 | """ 86 | device = torch_util.check_device(device) 87 | 88 | if device == "cpu" or device.lower().startswith("mps"): 89 | print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported") 90 | dtype = torch.float32 91 | 92 | pipeline = RiffusionPipeline.from_pretrained( 93 | checkpoint, 94 | revision="main", 95 | torch_dtype=dtype, 96 | # Disable the NSFW filter, causes incorrect false positives 97 | # TODO(hayk): Disable the "you have passed a non-standard module" warning from this. 98 | safety_checker=lambda images, **kwargs: (images, False), 99 | low_cpu_mem_usage=low_cpu_mem_usage, 100 | local_files_only=local_files_only, 101 | cache_dir=cache_dir, 102 | ).to(device) 103 | 104 | if channels_last: 105 | pipeline.unet.to(memory_format=torch.channels_last) 106 | 107 | # Optionally load a traced unet 108 | if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet: 109 | traced_unet = cls.load_traced_unet( 110 | checkpoint=checkpoint, 111 | subfolder="unet_traced", 112 | filename="unet_traced.pt", 113 | in_channels=pipeline.unet.in_channels, 114 | dtype=dtype, 115 | device=device, 116 | local_files_only=local_files_only, 117 | cache_dir=cache_dir, 118 | ) 119 | 120 | if traced_unet is not None: 121 | pipeline.unet = traced_unet 122 | 123 | model = pipeline.to(device) 124 | 125 | return model 126 | 127 | @staticmethod 128 | def load_traced_unet( 129 | checkpoint: str, 130 | subfolder: str, 131 | filename: str, 132 | in_channels: int, 133 | dtype: torch.dtype, 134 | device: str = "cuda", 135 | local_files_only=False, 136 | cache_dir: T.Optional[str] = None, 137 | ) -> T.Optional[torch.nn.Module]: 138 | """ 139 | Load a traced unet from the huggingface hub. This can improve performance. 
140 | """ 141 | if device == "cpu" or device.lower().startswith("mps"): 142 | print("WARNING: Traced UNet only available for CUDA, skipping") 143 | return None 144 | 145 | # Download and load the traced unet 146 | unet_file = hf_hub_download( 147 | checkpoint, 148 | subfolder=subfolder, 149 | filename=filename, 150 | local_files_only=local_files_only, 151 | cache_dir=cache_dir, 152 | ) 153 | unet_traced = torch.jit.load(unet_file) 154 | 155 | # Wrap it in a torch module 156 | class TracedUNet(torch.nn.Module): 157 | @dataclasses.dataclass 158 | class UNet2DConditionOutput: 159 | sample: torch.FloatTensor 160 | 161 | def __init__(self): 162 | super().__init__() 163 | self.in_channels = device 164 | self.device = device 165 | self.dtype = dtype 166 | 167 | def forward(self, latent_model_input, t, encoder_hidden_states): 168 | sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0] 169 | return self.UNet2DConditionOutput(sample=sample) 170 | 171 | return TracedUNet() 172 | 173 | @property 174 | def device(self) -> str: 175 | return str(self.vae.device) 176 | 177 | @functools.lru_cache() 178 | def embed_text(self, text) -> torch.FloatTensor: 179 | """ 180 | Takes in text and turns it into text embeddings. 181 | """ 182 | text_input = self.tokenizer( 183 | text, 184 | padding="max_length", 185 | max_length=self.tokenizer.model_max_length, 186 | truncation=True, 187 | return_tensors="pt", 188 | ) 189 | with torch.no_grad(): 190 | embed = self.text_encoder(text_input.input_ids.to(self.device))[0] 191 | return embed 192 | 193 | @functools.lru_cache() 194 | def embed_text_weighted(self, text) -> torch.FloatTensor: 195 | """ 196 | Get text embedding with weights. 197 | """ 198 | return get_weighted_text_embeddings( 199 | pipe=self, 200 | prompt=text, 201 | uncond_prompt=None, 202 | max_embeddings_multiples=3, 203 | no_boseos_middle=False, 204 | skip_parsing=False, 205 | skip_weighting=False, 206 | )[0] 207 | 208 | @torch.no_grad() 209 | def riffuse( 210 | self, 211 | inputs: InferenceInput, 212 | init_image: Image.Image, 213 | mask_image: T.Optional[Image.Image] = None, 214 | use_reweighting: bool = True, 215 | ) -> Image.Image: 216 | """ 217 | Runs inference using interpolation with both img2img and text conditioning. 218 | 219 | Args: 220 | inputs: Parameter dataclass 221 | init_image: Image used for conditioning 222 | mask_image: White pixels in the mask will be replaced by noise and therefore repainted, 223 | while black pixels will be preserved. It will be converted to a single 224 | channel (luminance) before use. 225 | use_reweighting: Use prompt reweighting 226 | """ 227 | alpha = inputs.alpha 228 | start = inputs.start 229 | end = inputs.end 230 | 231 | guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha 232 | 233 | # TODO(hayk): Always generate the seed on CPU? 
234 | if self.device.lower().startswith("mps"): 235 | generator_start = torch.Generator(device="cpu").manual_seed(start.seed) 236 | generator_end = torch.Generator(device="cpu").manual_seed(end.seed) 237 | else: 238 | generator_start = torch.Generator(device=self.device).manual_seed(start.seed) 239 | generator_end = torch.Generator(device=self.device).manual_seed(end.seed) 240 | 241 | # Text encodings 242 | if use_reweighting: 243 | embed_start = self.embed_text_weighted(start.prompt) 244 | embed_end = self.embed_text_weighted(end.prompt) 245 | else: 246 | embed_start = self.embed_text(start.prompt) 247 | embed_end = self.embed_text(end.prompt) 248 | 249 | text_embedding = embed_start + alpha * (embed_end - embed_start) 250 | 251 | # Image latents 252 | init_image_torch = preprocess_image(init_image).to( 253 | device=self.device, dtype=embed_start.dtype 254 | ) 255 | init_latent_dist = self.vae.encode(init_image_torch).latent_dist 256 | # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The 257 | # result is so close no matter the seed that it doesn't really add variety. 258 | if self.device.lower().startswith("mps"): 259 | generator = torch.Generator(device="cpu").manual_seed(start.seed) 260 | else: 261 | generator = torch.Generator(device=self.device).manual_seed(start.seed) 262 | 263 | init_latents = init_latent_dist.sample(generator=generator) 264 | init_latents = 0.18215 * init_latents 265 | 266 | # Prepare mask latent 267 | mask: T.Optional[torch.Tensor] = None 268 | if mask_image: 269 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 270 | mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to( 271 | device=self.device, dtype=embed_start.dtype 272 | ) 273 | 274 | outputs = self.interpolate_img2img( 275 | text_embeddings=text_embedding, 276 | init_latents=init_latents, 277 | mask=mask, 278 | generator_a=generator_start, 279 | generator_b=generator_end, 280 | interpolate_alpha=alpha, 281 | strength_a=start.denoising, 282 | strength_b=end.denoising, 283 | num_inference_steps=inputs.num_inference_steps, 284 | guidance_scale=guidance_scale, 285 | ) 286 | 287 | return outputs["images"][0] 288 | 289 | @torch.no_grad() 290 | def interpolate_img2img( 291 | self, 292 | text_embeddings: torch.Tensor, 293 | init_latents: torch.Tensor, 294 | generator_a: torch.Generator, 295 | generator_b: torch.Generator, 296 | interpolate_alpha: float, 297 | mask: T.Optional[torch.Tensor] = None, 298 | strength_a: float = 0.8, 299 | strength_b: float = 0.8, 300 | num_inference_steps: int = 50, 301 | guidance_scale: float = 7.5, 302 | negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None, 303 | num_images_per_prompt: int = 1, 304 | eta: T.Optional[float] = 0.0, 305 | output_type: T.Optional[str] = "pil", 306 | **kwargs, 307 | ): 308 | """ 309 | TODO 310 | """ 311 | batch_size = text_embeddings.shape[0] 312 | 313 | # set timesteps 314 | self.scheduler.set_timesteps(num_inference_steps) 315 | 316 | # duplicate text embeddings for each generation per prompt, using mps friendly method 317 | bs_embed, seq_len, _ = text_embeddings.shape 318 | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) 319 | text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 320 | 321 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 322 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 323 | # corresponds to doing no classifier free guidance. 
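        # Concretely, the denoising loop below combines the two predictions as
        #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # so larger values push the result further toward the text-conditioned prediction.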
324 | do_classifier_free_guidance = guidance_scale > 1.0 325 | # get unconditional embeddings for classifier free guidance 326 | if do_classifier_free_guidance: 327 | if negative_prompt is None: 328 | uncond_tokens = [""] 329 | elif isinstance(negative_prompt, str): 330 | uncond_tokens = [negative_prompt] 331 | elif batch_size != len(negative_prompt): 332 | raise ValueError("The length of `negative_prompt` should be equal to batch_size.") 333 | else: 334 | uncond_tokens = negative_prompt 335 | 336 | # max_length = text_input_ids.shape[-1] 337 | uncond_input = self.tokenizer( 338 | uncond_tokens, 339 | padding="max_length", 340 | max_length=self.tokenizer.model_max_length, 341 | truncation=True, 342 | return_tensors="pt", 343 | ) 344 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 345 | 346 | # duplicate unconditional embeddings for each generation per prompt 347 | uncond_embeddings = uncond_embeddings.repeat_interleave( 348 | batch_size * num_images_per_prompt, dim=0 349 | ) 350 | 351 | # For classifier free guidance, we need to do two forward passes. 352 | # Here we concatenate the unconditional and text embeddings into a single batch 353 | # to avoid doing two forward passes 354 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 355 | 356 | latents_dtype = text_embeddings.dtype 357 | 358 | strength = (1 - interpolate_alpha) * strength_a + interpolate_alpha * strength_b 359 | 360 | # get the original timestep using init_timestep 361 | offset = self.scheduler.config.get("steps_offset", 0) 362 | init_timestep = int(num_inference_steps * strength) + offset 363 | init_timestep = min(init_timestep, num_inference_steps) 364 | 365 | timesteps = self.scheduler.timesteps[-init_timestep] 366 | timesteps = torch.tensor( 367 | [timesteps] * batch_size * num_images_per_prompt, device=self.device 368 | ) 369 | 370 | # add noise to latents using the timesteps 371 | noise_a = torch.randn( 372 | init_latents.shape, generator=generator_a, device=self.device, dtype=latents_dtype 373 | ) 374 | noise_b = torch.randn( 375 | init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype 376 | ) 377 | noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b) 378 | init_latents_orig = init_latents 379 | init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) 380 | 381 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same args 382 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
383 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 384 | # and should be between [0, 1] 385 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 386 | extra_step_kwargs = {} 387 | if accepts_eta: 388 | extra_step_kwargs["eta"] = eta 389 | 390 | latents = init_latents.clone() 391 | 392 | t_start = max(num_inference_steps - init_timestep + offset, 0) 393 | 394 | # Some schedulers like PNDM have timesteps as arrays 395 | # It's more optimized to move all timesteps to correct device beforehand 396 | timesteps = self.scheduler.timesteps[t_start:].to(self.device) 397 | 398 | for i, t in enumerate(self.progress_bar(timesteps)): 399 | # expand the latents if we are doing classifier free guidance 400 | latent_model_input = ( 401 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 402 | ) 403 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 404 | 405 | # predict the noise residual 406 | noise_pred = self.unet( 407 | latent_model_input, t, encoder_hidden_states=text_embeddings 408 | ).sample 409 | 410 | # perform guidance 411 | if do_classifier_free_guidance: 412 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 413 | noise_pred = noise_pred_uncond + guidance_scale * ( 414 | noise_pred_text - noise_pred_uncond 415 | ) 416 | 417 | # compute the previous noisy sample x_t -> x_t-1 418 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 419 | 420 | if mask is not None: 421 | init_latents_proper = self.scheduler.add_noise( 422 | init_latents_orig, noise, torch.tensor([t]) 423 | ) 424 | # import ipdb; ipdb.set_trace() 425 | latents = (init_latents_proper * mask) + (latents * (1 - mask)) 426 | 427 | latents = 1.0 / 0.18215 * latents 428 | image = self.vae.decode(latents).sample 429 | 430 | image = (image / 2 + 0.5).clamp(0, 1) 431 | image = image.cpu().permute(0, 2, 3, 1).numpy() 432 | 433 | if output_type == "pil": 434 | image = self.numpy_to_pil(image) 435 | 436 | return dict(images=image, latents=latents, nsfw_content_detected=False) 437 | 438 | 439 | def preprocess_image(image: Image.Image) -> torch.Tensor: 440 | """ 441 | Preprocess an image for the model. 442 | """ 443 | w, h = image.size 444 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 445 | image = image.resize((w, h), resample=Image.LANCZOS) 446 | 447 | image_np = np.array(image).astype(np.float32) / 255.0 448 | image_np = image_np[None].transpose(0, 3, 1, 2) 449 | 450 | image_torch = torch.from_numpy(image_np) 451 | 452 | return 2.0 * image_torch - 1.0 453 | 454 | 455 | def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor: 456 | """ 457 | Preprocess a mask for the model. 458 | """ 459 | # Convert to grayscale 460 | mask = mask.convert("L") 461 | 462 | # Resize to integer multiple of 32 463 | w, h = mask.size 464 | w, h = map(lambda x: x - x % 32, (w, h)) 465 | mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST) 466 | 467 | # Convert to numpy array and rescale 468 | mask_np = np.array(mask).astype(np.float32) / 255.0 469 | 470 | # Tile and transpose 471 | mask_np = np.tile(mask_np, (4, 1, 1)) 472 | mask_np = mask_np[None].transpose(0, 1, 2, 3) # what does this step do? 
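    # transpose(0, 1, 2, 3) is the identity permutation, so the line above is effectively a
    # no-op: after np.tile and [None] the mask already has shape (1, 4, height, width), which
    # matches the latent tensor it multiplies in the denoising loop.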
473 | 474 | # Invert to repaint white and keep black 475 | mask_np = 1 - mask_np # repaint white, keep black 476 | 477 | return torch.from_numpy(mask_np) 478 | -------------------------------------------------------------------------------- /riffusion/server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Flask server that serves the riffusion model as an API. 3 | """ 4 | 5 | import dataclasses 6 | import io 7 | import json 8 | import logging 9 | import time 10 | import typing as T 11 | from pathlib import Path 12 | 13 | import dacite 14 | import flask 15 | import PIL 16 | from flask_cors import CORS 17 | 18 | from riffusion.datatypes import InferenceInput, InferenceOutput 19 | from riffusion.riffusion_pipeline import RiffusionPipeline 20 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter 21 | from riffusion.spectrogram_params import SpectrogramParams 22 | from riffusion.util import base64_util 23 | 24 | # Flask app with CORS 25 | app = flask.Flask(__name__) 26 | CORS(app) 27 | 28 | # Log at the INFO level to both stdout and disk 29 | logging.basicConfig(level=logging.INFO) 30 | logging.getLogger().addHandler(logging.FileHandler("server.log")) 31 | 32 | # Global variable for the model pipeline 33 | PIPELINE: T.Optional[RiffusionPipeline] = None 34 | 35 | # Where built-in seed images are stored 36 | SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images") 37 | 38 | 39 | def run_app( 40 | *, 41 | checkpoint: str = "riffusion/riffusion-model-v1", 42 | no_traced_unet: bool = False, 43 | device: str = "cuda", 44 | host: str = "127.0.0.1", 45 | port: int = 3013, 46 | debug: bool = False, 47 | ssl_certificate: T.Optional[str] = None, 48 | ssl_key: T.Optional[str] = None, 49 | ): 50 | """ 51 | Run a flask API that serves the given riffusion model checkpoint. 52 | """ 53 | # Initialize the model 54 | global PIPELINE 55 | PIPELINE = RiffusionPipeline.load_checkpoint( 56 | checkpoint=checkpoint, 57 | use_traced_unet=not no_traced_unet, 58 | device=device, 59 | ) 60 | 61 | args = dict( 62 | debug=debug, 63 | threaded=False, 64 | host=host, 65 | port=port, 66 | ) 67 | 68 | if ssl_certificate: 69 | assert ssl_key is not None 70 | args["ssl_context"] = (ssl_certificate, ssl_key) 71 | 72 | app.run(**args) # type: ignore 73 | 74 | 75 | @app.route("/run_inference/", methods=["POST"]) 76 | def run_inference(): 77 | """ 78 | Execute the riffusion model as an API. 
79 | 80 | Inputs: 81 | Serialized JSON of the InferenceInput dataclass 82 | 83 | Returns: 84 | Serialized JSON of the InferenceOutput dataclass 85 | """ 86 | start_time = time.time() 87 | 88 | # Parse the payload as JSON 89 | json_data = json.loads(flask.request.data) 90 | 91 | # Log the request 92 | logging.info(json_data) 93 | 94 | # Parse an InferenceInput dataclass from the payload 95 | try: 96 | inputs = dacite.from_dict(InferenceInput, json_data) 97 | except dacite.exceptions.WrongTypeError as exception: 98 | logging.info(json_data) 99 | return str(exception), 400 100 | except dacite.exceptions.MissingValueError as exception: 101 | logging.info(json_data) 102 | return str(exception), 400 103 | 104 | response = compute_request( 105 | inputs=inputs, 106 | seed_images_dir=SEED_IMAGES_DIR, 107 | pipeline=PIPELINE, 108 | ) 109 | 110 | # Log the total time 111 | logging.info(f"Request took {time.time() - start_time:.2f} s") 112 | 113 | return response 114 | 115 | 116 | def compute_request( 117 | inputs: InferenceInput, 118 | pipeline: RiffusionPipeline, 119 | seed_images_dir: str, 120 | ) -> T.Union[str, T.Tuple[str, int]]: 121 | """ 122 | Does all the heavy lifting of the request. 123 | 124 | Args: 125 | inputs: The input dataclass 126 | pipeline: The riffusion model pipeline 127 | seed_images_dir: The directory where seed images are stored 128 | """ 129 | # Load the seed image by ID 130 | init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png") 131 | 132 | if not init_image_path.is_file(): 133 | return f"Invalid seed image: {inputs.seed_image_id}", 400 134 | init_image = PIL.Image.open(str(init_image_path)).convert("RGB") 135 | 136 | # Load the mask image by ID 137 | mask_image: T.Optional[PIL.Image.Image] = None 138 | if inputs.mask_image_id: 139 | mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png") 140 | if not mask_image_path.is_file(): 141 | return f"Invalid mask image: {inputs.mask_image_id}", 400 142 | mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB") 143 | 144 | # Execute the model to get the spectrogram image 145 | image = pipeline.riffuse( 146 | inputs, 147 | init_image=init_image, 148 | mask_image=mask_image, 149 | ) 150 | 151 | # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained 152 | params = SpectrogramParams( 153 | min_frequency=0, 154 | max_frequency=10000, 155 | ) 156 | 157 | # Reconstruct audio from the image 158 | # TODO(hayk): It may help performance a bit to cache this object 159 | converter = SpectrogramImageConverter(params=params, device=str(pipeline.device)) 160 | 161 | segment = converter.audio_from_spectrogram_image( 162 | image, 163 | apply_filters=True, 164 | ) 165 | 166 | # Export audio to MP3 bytes 167 | mp3_bytes = io.BytesIO() 168 | segment.export(mp3_bytes, format="mp3") 169 | mp3_bytes.seek(0) 170 | 171 | # Export image to JPEG bytes 172 | image_bytes = io.BytesIO() 173 | image.save(image_bytes, exif=image.getexif(), format="JPEG") 174 | image_bytes.seek(0) 175 | 176 | # Assemble the output dataclass 177 | output = InferenceOutput( 178 | image="data:image/jpeg;base64," + base64_util.encode(image_bytes), 179 | audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes), 180 | duration_s=segment.duration_seconds, 181 | ) 182 | 183 | return json.dumps(dataclasses.asdict(output)) 184 | 185 | 186 | if __name__ == "__main__": 187 | import argh 188 | 189 | argh.dispatch_command(run_app) 190 | -------------------------------------------------------------------------------- 
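A minimal client sketch for the /run_inference/ endpoint above, assuming the server is running
with the default host and port from run_app and that the third-party requests package is
installed. The payload field names are inferred from how InferenceInput and PromptInput are
used in this repository; the authoritative definitions live in riffusion/datatypes.py, and the
literal values below are only illustrative.

import requests

payload = {
    "alpha": 0.5,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "mask_image_id": None,
    "start": {"prompt": "church bells", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference/", json=payload)
output = response.json()  # keys: "image" and "audio" (base64 data URLs), "duration_s"
print(output["duration_s"])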
/riffusion/spectrogram_converter.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import pydub 5 | import torch 6 | import torchaudio 7 | 8 | from riffusion.spectrogram_params import SpectrogramParams 9 | from riffusion.util import audio_util, torch_util 10 | 11 | 12 | class SpectrogramConverter: 13 | """ 14 | Convert between audio segments and spectrogram tensors using torchaudio. 15 | 16 | In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values 17 | that represent the amplitude of the frequency at that time bucket (in the frequency domain). 18 | Frequencies are given in the perceptul Mel scale defined by the params. A more specific term 19 | used in some functions is "mel amplitudes". 20 | 21 | The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only 22 | returns the amplitude, because the phase is chaotic and hard to learn. The function 23 | `audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which 24 | approximates the phase information using the Griffin-Lim algorithm. 25 | 26 | Each channel in the audio is treated independently, and the spectrogram has a batch dimension 27 | equal to the number of channels in the input audio segment. 28 | 29 | Both the Griffin Lim algorithm and the Mel scaling process are lossy. 30 | 31 | For more information, see https://pytorch.org/audio/stable/transforms.html 32 | """ 33 | 34 | def __init__(self, params: SpectrogramParams, device: str = "cuda"): 35 | self.p = params 36 | 37 | self.device = torch_util.check_device(device) 38 | 39 | if device.lower().startswith("mps"): 40 | warnings.warn( 41 | "WARNING: MPS does not support audio operations, falling back to CPU for them", 42 | stacklevel=2, 43 | ) 44 | self.device = "cpu" 45 | 46 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html 47 | self.spectrogram_func = torchaudio.transforms.Spectrogram( 48 | n_fft=params.n_fft, 49 | hop_length=params.hop_length, 50 | win_length=params.win_length, 51 | pad=0, 52 | window_fn=torch.hann_window, 53 | power=None, 54 | normalized=False, 55 | wkwargs=None, 56 | center=True, 57 | pad_mode="reflect", 58 | onesided=True, 59 | ).to(self.device) 60 | 61 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html 62 | self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim( 63 | n_fft=params.n_fft, 64 | n_iter=params.num_griffin_lim_iters, 65 | win_length=params.win_length, 66 | hop_length=params.hop_length, 67 | window_fn=torch.hann_window, 68 | power=1.0, 69 | wkwargs=None, 70 | momentum=0.99, 71 | length=None, 72 | rand_init=True, 73 | ).to(self.device) 74 | 75 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html 76 | self.mel_scaler = torchaudio.transforms.MelScale( 77 | n_mels=params.num_frequencies, 78 | sample_rate=params.sample_rate, 79 | f_min=params.min_frequency, 80 | f_max=params.max_frequency, 81 | n_stft=params.n_fft // 2 + 1, 82 | norm=params.mel_scale_norm, 83 | mel_scale=params.mel_scale_type, 84 | ).to(self.device) 85 | 86 | # https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html 87 | self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale( 88 | n_stft=params.n_fft // 2 + 1, 89 | n_mels=params.num_frequencies, 90 | sample_rate=params.sample_rate, 91 | f_min=params.min_frequency, 92 | f_max=params.max_frequency, 93 | 
max_iter=params.max_mel_iters, 94 | tolerance_loss=1e-5, 95 | tolerance_change=1e-8, 96 | sgdargs=None, 97 | norm=params.mel_scale_norm, 98 | mel_scale=params.mel_scale_type, 99 | ).to(self.device) 100 | 101 | def spectrogram_from_audio( 102 | self, 103 | audio: pydub.AudioSegment, 104 | ) -> np.ndarray: 105 | """ 106 | Compute a spectrogram from an audio segment. 107 | 108 | Args: 109 | audio: Audio segment which must match the sample rate of the params 110 | 111 | Returns: 112 | spectrogram: (channel, frequency, time) 113 | """ 114 | assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params" 115 | 116 | # Get the samples as a numpy array in (batch, samples) shape 117 | waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()]) 118 | 119 | # Convert to floats if necessary 120 | if waveform.dtype != np.float32: 121 | waveform = waveform.astype(np.float32) 122 | 123 | waveform_tensor = torch.from_numpy(waveform).to(self.device) 124 | amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor) 125 | return amplitudes_mel.cpu().numpy() 126 | 127 | def audio_from_spectrogram( 128 | self, 129 | spectrogram: np.ndarray, 130 | apply_filters: bool = True, 131 | ) -> pydub.AudioSegment: 132 | """ 133 | Reconstruct an audio segment from a spectrogram. 134 | 135 | Args: 136 | spectrogram: (batch, frequency, time) 137 | apply_filters: Post-process with normalization and compression 138 | 139 | Returns: 140 | audio: Audio segment with channels equal to the batch dimension 141 | """ 142 | # Move to device 143 | amplitudes_mel = torch.from_numpy(spectrogram).to(self.device) 144 | 145 | # Reconstruct the waveform 146 | waveform = self.waveform_from_mel_amplitudes(amplitudes_mel) 147 | 148 | # Convert to audio segment 149 | segment = audio_util.audio_from_waveform( 150 | samples=waveform.cpu().numpy(), 151 | sample_rate=self.p.sample_rate, 152 | # Normalize the waveform to the range [-1, 1] 153 | normalize=True, 154 | ) 155 | 156 | # Optionally apply post-processing filters 157 | if apply_filters: 158 | segment = audio_util.apply_filters( 159 | segment, 160 | compression=False, 161 | ) 162 | 163 | return segment 164 | 165 | def mel_amplitudes_from_waveform( 166 | self, 167 | waveform: torch.Tensor, 168 | ) -> torch.Tensor: 169 | """ 170 | Torch-only function to compute Mel-scale amplitudes from a waveform. 171 | 172 | Args: 173 | waveform: (batch, samples) 174 | 175 | Returns: 176 | amplitudes_mel: (batch, frequency, time) 177 | """ 178 | # Compute the complex-valued spectrogram 179 | spectrogram_complex = self.spectrogram_func(waveform) 180 | 181 | # Take the magnitude 182 | amplitudes = torch.abs(spectrogram_complex) 183 | 184 | # Convert to mel scale 185 | return self.mel_scaler(amplitudes) 186 | 187 | def waveform_from_mel_amplitudes( 188 | self, 189 | amplitudes_mel: torch.Tensor, 190 | ) -> torch.Tensor: 191 | """ 192 | Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes. 
193 | 194 | Args: 195 | amplitudes_mel: (batch, frequency, time) 196 | 197 | Returns: 198 | waveform: (batch, samples) 199 | """ 200 | # Convert from mel scale to linear 201 | amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel) 202 | 203 | # Run the approximate algorithm to compute the phase and recover the waveform 204 | return self.inverse_spectrogram_func(amplitudes_linear) 205 | -------------------------------------------------------------------------------- /riffusion/spectrogram_image_converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydub 3 | from PIL import Image 4 | 5 | from riffusion.spectrogram_converter import SpectrogramConverter 6 | from riffusion.spectrogram_params import SpectrogramParams 7 | from riffusion.util import image_util 8 | 9 | 10 | class SpectrogramImageConverter: 11 | """ 12 | Convert between spectrogram images and audio segments. 13 | 14 | This is a wrapper around SpectrogramConverter that additionally converts from spectrograms 15 | to images and back. The real audio processing lives in SpectrogramConverter. 16 | """ 17 | 18 | def __init__(self, params: SpectrogramParams, device: str = "cuda"): 19 | self.p = params 20 | self.device = device 21 | self.converter = SpectrogramConverter(params=params, device=device) 22 | 23 | def spectrogram_image_from_audio( 24 | self, 25 | segment: pydub.AudioSegment, 26 | ) -> Image.Image: 27 | """ 28 | Compute a spectrogram image from an audio segment. 29 | 30 | Args: 31 | segment: Audio segment to convert 32 | 33 | Returns: 34 | Spectrogram image (in pillow format) 35 | """ 36 | assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch" 37 | 38 | if self.p.stereo: 39 | if segment.channels == 1: 40 | print("WARNING: Mono audio but stereo=True, cloning channel") 41 | segment = segment.set_channels(2) 42 | elif segment.channels > 2: 43 | print("WARNING: Multi channel audio, reducing to stereo") 44 | segment = segment.set_channels(2) 45 | else: 46 | if segment.channels > 1: 47 | print("WARNING: Stereo audio but stereo=False, setting to mono") 48 | segment = segment.set_channels(1) 49 | 50 | spectrogram = self.converter.spectrogram_from_audio(segment) 51 | 52 | image = image_util.image_from_spectrogram( 53 | spectrogram, 54 | power=self.p.power_for_image, 55 | ) 56 | 57 | # Store conversion params in exif metadata of the image 58 | exif_data = self.p.to_exif() 59 | exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram)) 60 | exif = image.getexif() 61 | exif.update(exif_data.items()) 62 | 63 | return image 64 | 65 | def audio_from_spectrogram_image( 66 | self, 67 | image: Image.Image, 68 | apply_filters: bool = True, 69 | max_value: float = 30e6, 70 | ) -> pydub.AudioSegment: 71 | """ 72 | Reconstruct an audio segment from a spectrogram image. 73 | 74 | Args: 75 | image: Spectrogram image (in pillow format) 76 | apply_filters: Apply post-processing to improve the reconstructed audio 77 | max_value: Scaled max amplitude of the spectrogram. Shouldn't matter. 
78 | """ 79 | spectrogram = image_util.spectrogram_from_image( 80 | image, 81 | max_value=max_value, 82 | power=self.p.power_for_image, 83 | stereo=self.p.stereo, 84 | ) 85 | 86 | segment = self.converter.audio_from_spectrogram( 87 | spectrogram, 88 | apply_filters=apply_filters, 89 | ) 90 | 91 | return segment 92 | -------------------------------------------------------------------------------- /riffusion/spectrogram_params.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing as T 4 | from dataclasses import dataclass 5 | from enum import Enum 6 | 7 | 8 | @dataclass(frozen=True) 9 | class SpectrogramParams: 10 | """ 11 | Parameters for the conversion from audio to spectrograms to images and back. 12 | 13 | Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored 14 | within spectrogram images. 15 | 16 | To understand what these parameters do and to customize them, read `spectrogram_converter.py` 17 | and the linked torchaudio documentation. 18 | """ 19 | 20 | # Whether the audio is stereo or mono 21 | stereo: bool = False 22 | 23 | # FFT parameters 24 | sample_rate: int = 44100 25 | step_size_ms: int = 10 26 | window_duration_ms: int = 100 27 | padded_duration_ms: int = 400 28 | 29 | # Mel scale parameters 30 | num_frequencies: int = 512 31 | # TODO(hayk): Set these to [20, 20000] for newer models 32 | min_frequency: int = 0 33 | max_frequency: int = 10000 34 | mel_scale_norm: T.Optional[str] = None 35 | mel_scale_type: str = "htk" 36 | max_mel_iters: int = 200 37 | 38 | # Griffin Lim parameters 39 | num_griffin_lim_iters: int = 32 40 | 41 | # Image parameterization 42 | power_for_image: float = 0.25 43 | 44 | class ExifTags(Enum): 45 | """ 46 | Custom EXIF tags for the spectrogram image. 47 | """ 48 | 49 | SAMPLE_RATE = 11000 50 | STEREO = 11005 51 | STEP_SIZE_MS = 11010 52 | WINDOW_DURATION_MS = 11020 53 | PADDED_DURATION_MS = 11030 54 | 55 | NUM_FREQUENCIES = 11040 56 | MIN_FREQUENCY = 11050 57 | MAX_FREQUENCY = 11060 58 | 59 | POWER_FOR_IMAGE = 11070 60 | MAX_VALUE = 11080 61 | 62 | @property 63 | def n_fft(self) -> int: 64 | """ 65 | The number of samples in each STFT window, with padding. 66 | """ 67 | return int(self.padded_duration_ms / 1000.0 * self.sample_rate) 68 | 69 | @property 70 | def win_length(self) -> int: 71 | """ 72 | The number of samples in each STFT window. 73 | """ 74 | return int(self.window_duration_ms / 1000.0 * self.sample_rate) 75 | 76 | @property 77 | def hop_length(self) -> int: 78 | """ 79 | The number of samples between each STFT window. 80 | """ 81 | return int(self.step_size_ms / 1000.0 * self.sample_rate) 82 | 83 | def to_exif(self) -> T.Dict[int, T.Any]: 84 | """ 85 | Return a dictionary of EXIF tags for the current values. 
86 | """ 87 | return { 88 | self.ExifTags.SAMPLE_RATE.value: self.sample_rate, 89 | self.ExifTags.STEREO.value: self.stereo, 90 | self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms, 91 | self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms, 92 | self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms, 93 | self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies, 94 | self.ExifTags.MIN_FREQUENCY.value: self.min_frequency, 95 | self.ExifTags.MAX_FREQUENCY.value: self.max_frequency, 96 | self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image), 97 | } 98 | 99 | @classmethod 100 | def from_exif(cls, exif: T.Mapping[int, T.Any]) -> SpectrogramParams: 101 | """ 102 | Create a SpectrogramParams object from the EXIF tags of the given image. 103 | """ 104 | # TODO(hayk): Handle missing tags 105 | return cls( 106 | sample_rate=exif[cls.ExifTags.SAMPLE_RATE.value], 107 | stereo=bool(exif[cls.ExifTags.STEREO.value]), 108 | step_size_ms=exif[cls.ExifTags.STEP_SIZE_MS.value], 109 | window_duration_ms=exif[cls.ExifTags.WINDOW_DURATION_MS.value], 110 | padded_duration_ms=exif[cls.ExifTags.PADDED_DURATION_MS.value], 111 | num_frequencies=exif[cls.ExifTags.NUM_FREQUENCIES.value], 112 | min_frequency=exif[cls.ExifTags.MIN_FREQUENCY.value], 113 | max_frequency=exif[cls.ExifTags.MAX_FREQUENCY.value], 114 | power_for_image=exif[cls.ExifTags.POWER_FOR_IMAGE.value], 115 | ) 116 | -------------------------------------------------------------------------------- /riffusion/streamlit/README.md: -------------------------------------------------------------------------------- 1 | # streamlit 2 | 3 | This package is an interactive streamlit app for riffusion. 4 | -------------------------------------------------------------------------------- /riffusion/streamlit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/streamlit/__init__.py -------------------------------------------------------------------------------- /riffusion/streamlit/playground.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import streamlit as st 4 | import streamlit.web.cli as stcli 5 | from streamlit import runtime 6 | 7 | PAGES = { 8 | "🎛️ Home": "tasks.home", 9 | "🌊 Text to Audio": "tasks.text_to_audio", 10 | "✨ Audio to Audio": "tasks.audio_to_audio", 11 | "🎭 Interpolation": "tasks.interpolation", 12 | "✂️ Audio Splitter": "tasks.split_audio", 13 | "📜 Text to Audio Batch": "tasks.text_to_audio_batch", 14 | "📎 Sample Clips": "tasks.sample_clips", 15 | "⏈ Spectrogram to Audio": "tasks.image_to_audio", 16 | } 17 | 18 | 19 | def render() -> None: 20 | st.set_page_config( 21 | page_title="Riffusion Playground", 22 | page_icon="🎸", 23 | layout="wide", 24 | ) 25 | 26 | page = st.sidebar.selectbox("Page", list(PAGES.keys())) 27 | assert page is not None 28 | module = __import__(PAGES[page], fromlist=["render"]) 29 | module.render() 30 | 31 | 32 | if __name__ == "__main__": 33 | if runtime.exists(): 34 | render() 35 | else: 36 | sys.argv = ["streamlit", "run"] + sys.argv 37 | sys.exit(stcli.main()) 38 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/streamlit/tasks/__init__.py -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/audio_to_audio.py: -------------------------------------------------------------------------------- 1 | import io 2 | import typing as T 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pydub 7 | import streamlit as st 8 | from PIL import Image 9 | 10 | from riffusion.datatypes import InferenceInput, PromptInput 11 | from riffusion.spectrogram_params import SpectrogramParams 12 | from riffusion.streamlit import util as streamlit_util 13 | from riffusion.streamlit.tasks.interpolation import get_prompt_inputs, run_interpolation 14 | from riffusion.util import audio_util 15 | 16 | 17 | def render() -> None: 18 | st.subheader("✨ Audio to Audio") 19 | st.write( 20 | """ 21 | Modify existing audio from a text prompt or interpolate between two. 22 | """ 23 | ) 24 | 25 | with st.expander("Help", False): 26 | st.write( 27 | """ 28 | This tool allows you to upload an audio file of arbitrary length and modify it with 29 | a text prompt. It does this by sweeping over the audio in overlapping clips, doing 30 | img2img style transfer with riffusion, then stitching the clips back together with 31 | cross fading to eliminate seams. 32 | 33 | Try a denoising strength of 0.4 for light modification and 0.55 for more heavy 34 | modification. The best specific denoising depends on how different the prompt is 35 | from the source audio. You can play with the seed to get infinite variations. 36 | Currently the same seed is used for all clips along the track. 37 | 38 | If the Interpolation check box is enabled, supports entering two sets of prompt, 39 | seed, and denoising value and smoothly blends between them along the selected 40 | duration of the audio. This is a great way to create a transition. 41 | """ 42 | ) 43 | 44 | device = streamlit_util.select_device(st.sidebar) 45 | extension = streamlit_util.select_audio_extension(st.sidebar) 46 | checkpoint = streamlit_util.select_checkpoint(st.sidebar) 47 | 48 | use_20k = st.sidebar.checkbox("Use 20kHz", value=False) 49 | use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False) 50 | 51 | with st.sidebar: 52 | num_inference_steps = T.cast( 53 | int, 54 | st.number_input( 55 | "Steps per sample", value=25, help="Number of denoising steps per model run" 56 | ), 57 | ) 58 | 59 | guidance = st.number_input( 60 | "Guidance", 61 | value=7.0, 62 | help="How much the model listens to the text prompt", 63 | ) 64 | 65 | scheduler = st.selectbox( 66 | "Scheduler", 67 | options=streamlit_util.SCHEDULER_OPTIONS, 68 | index=0, 69 | help="Which diffusion scheduler to use", 70 | ) 71 | assert scheduler is not None 72 | 73 | audio_file = st.file_uploader( 74 | "Upload audio", 75 | type=streamlit_util.AUDIO_EXTENSIONS, 76 | label_visibility="collapsed", 77 | ) 78 | 79 | if not audio_file: 80 | st.info("Upload audio to get started") 81 | return 82 | 83 | st.write("#### Original") 84 | st.audio(audio_file) 85 | 86 | segment = streamlit_util.load_audio_file(audio_file) 87 | 88 | # TODO(hayk): Fix 89 | if segment.frame_rate != 44100: 90 | st.warning("Audio must be 44100Hz. 
Converting") 91 | segment = segment.set_frame_rate(44100) 92 | st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz") 93 | 94 | clip_p = get_clip_params() 95 | start_time_s = clip_p["start_time_s"] 96 | clip_duration_s = clip_p["clip_duration_s"] 97 | overlap_duration_s = clip_p["overlap_duration_s"] 98 | 99 | duration_s = min(clip_p["duration_s"], segment.duration_seconds - start_time_s) 100 | increment_s = clip_duration_s - overlap_duration_s 101 | clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s) 102 | 103 | write_clip_details( 104 | clip_start_times=clip_start_times, 105 | clip_duration_s=clip_duration_s, 106 | overlap_duration_s=overlap_duration_s, 107 | ) 108 | 109 | interpolate = st.checkbox( 110 | "Interpolate between two endpoints", 111 | value=False, 112 | help="Interpolate between two prompts, seeds, or denoising values along the" 113 | "duration of the segment", 114 | ) 115 | 116 | counter = streamlit_util.StreamlitCounter() 117 | 118 | denoising_default = 0.55 119 | with st.form("audio to audio form"): 120 | if interpolate: 121 | left, right = st.columns(2) 122 | 123 | with left: 124 | st.write("##### Prompt A") 125 | prompt_input_a = PromptInput( 126 | guidance=guidance, 127 | **get_prompt_inputs(key="a", denoising_default=denoising_default), 128 | ) 129 | 130 | with right: 131 | st.write("##### Prompt B") 132 | prompt_input_b = PromptInput( 133 | guidance=guidance, 134 | **get_prompt_inputs(key="b", denoising_default=denoising_default), 135 | ) 136 | elif use_magic_mix: 137 | prompt = st.text_input("Prompt", key="prompt_a") 138 | 139 | row = st.columns(4) 140 | 141 | seed = T.cast( 142 | int, 143 | row[0].number_input( 144 | "Seed", 145 | value=42, 146 | key="seed_a", 147 | ), 148 | ) 149 | prompt_input_a = PromptInput( 150 | prompt=prompt, 151 | seed=seed, 152 | guidance=guidance, 153 | ) 154 | magic_mix_kmin = row[1].number_input("Kmin", value=0.3) 155 | magic_mix_kmax = row[2].number_input("Kmax", value=0.5) 156 | magic_mix_mix_factor = row[3].number_input("Mix Factor", value=0.5) 157 | else: 158 | prompt_input_a = PromptInput( 159 | guidance=guidance, 160 | **get_prompt_inputs( 161 | key="a", 162 | include_negative_prompt=True, 163 | cols=True, 164 | denoising_default=denoising_default, 165 | ), 166 | ) 167 | 168 | st.form_submit_button("Riff", type="primary", on_click=counter.increment) 169 | 170 | show_clip_details = st.sidebar.checkbox("Show Clip Details", True) 171 | show_difference = st.sidebar.checkbox("Show Difference", False) 172 | 173 | clip_segments = slice_audio_into_clips( 174 | segment=segment, 175 | clip_start_times=clip_start_times, 176 | clip_duration_s=clip_duration_s, 177 | ) 178 | 179 | if not prompt_input_a.prompt: 180 | st.info("Enter a prompt") 181 | return 182 | 183 | if counter.value == 0: 184 | return 185 | 186 | st.write(f"## Counter: {counter.value}") 187 | 188 | if use_20k: 189 | params = SpectrogramParams( 190 | min_frequency=10, 191 | max_frequency=20000, 192 | sample_rate=44100, 193 | stereo=True, 194 | ) 195 | else: 196 | params = SpectrogramParams( 197 | min_frequency=0, 198 | max_frequency=10000, 199 | stereo=False, 200 | ) 201 | 202 | if interpolate: 203 | # TODO(hayk): Make not linspace 204 | alphas = list(np.linspace(0, 1, len(clip_segments))) 205 | alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas]) 206 | st.write(f"**Alphas** : [{alphas_str}]") 207 | 208 | result_images: T.List[Image.Image] = [] 209 | result_segments: T.List[pydub.AudioSegment] = 
[] 210 | for i, clip_segment in enumerate(clip_segments): 211 | st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s") 212 | 213 | audio_bytes = io.BytesIO() 214 | clip_segment.export(audio_bytes, format="wav") 215 | 216 | init_image = streamlit_util.spectrogram_image_from_audio( 217 | clip_segment, 218 | params=params, 219 | device=device, 220 | ) 221 | 222 | # TODO(hayk): Roll this into spectrogram_image_from_audio? 223 | init_image_resized = scale_image_to_32_stride(init_image) 224 | 225 | progress_callback = None 226 | if show_clip_details: 227 | left, right = st.columns(2) 228 | 229 | left.write("##### Source Clip") 230 | left.image(init_image, use_column_width=False) 231 | left.audio(audio_bytes) 232 | 233 | right.write("##### Riffed Clip") 234 | empty_bin = right.empty() 235 | with empty_bin.container(): 236 | st.info("Riffing...") 237 | progress = st.progress(0.0) 238 | progress_callback = progress.progress 239 | 240 | if interpolate: 241 | assert use_magic_mix is False, "Cannot use magic mix and interpolate together" 242 | inputs = InferenceInput( 243 | alpha=float(alphas[i]), 244 | num_inference_steps=num_inference_steps, 245 | seed_image_id="og_beat", 246 | start=prompt_input_a, 247 | end=prompt_input_b, 248 | ) 249 | 250 | image, audio_bytes = run_interpolation( 251 | inputs=inputs, 252 | init_image=init_image_resized, 253 | device=device, 254 | checkpoint=checkpoint, 255 | ) 256 | elif use_magic_mix: 257 | assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix" 258 | image = streamlit_util.run_img2img_magic_mix( 259 | prompt=prompt_input_a.prompt, 260 | init_image=init_image_resized, 261 | num_inference_steps=num_inference_steps, 262 | guidance_scale=guidance, 263 | seed=prompt_input_a.seed, 264 | kmin=magic_mix_kmin, 265 | kmax=magic_mix_kmax, 266 | mix_factor=magic_mix_mix_factor, 267 | device=device, 268 | scheduler=scheduler, 269 | checkpoint=checkpoint, 270 | ) 271 | else: 272 | image = streamlit_util.run_img2img( 273 | prompt=prompt_input_a.prompt, 274 | init_image=init_image_resized, 275 | denoising_strength=prompt_input_a.denoising, 276 | num_inference_steps=num_inference_steps, 277 | guidance_scale=guidance, 278 | negative_prompt=prompt_input_a.negative_prompt, 279 | seed=prompt_input_a.seed, 280 | progress_callback=progress_callback, 281 | device=device, 282 | scheduler=scheduler, 283 | checkpoint=checkpoint, 284 | ) 285 | 286 | # Resize back to original size 287 | image = image.resize(init_image.size, Image.BICUBIC) 288 | 289 | result_images.append(image) 290 | 291 | if show_clip_details: 292 | empty_bin.empty() 293 | right.image(image, use_column_width=False) 294 | 295 | riffed_segment = streamlit_util.audio_segment_from_spectrogram_image( 296 | image=image, 297 | params=params, 298 | device=device, 299 | ) 300 | result_segments.append(riffed_segment) 301 | 302 | audio_bytes = io.BytesIO() 303 | riffed_segment.export(audio_bytes, format="wav") 304 | 305 | if show_clip_details: 306 | right.audio(audio_bytes) 307 | 308 | if show_clip_details and show_difference: 309 | diff_np = np.maximum( 310 | 0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32) 311 | ) 312 | diff_image = Image.fromarray(255 - diff_np.astype(np.uint8)) 313 | diff_segment = streamlit_util.audio_segment_from_spectrogram_image( 314 | image=diff_image, 315 | params=params, 316 | device=device, 317 | ) 318 | 319 | audio_bytes = io.BytesIO() 320 | diff_segment.export(audio_bytes, format=extension) 321 | st.audio(audio_bytes) 322 | 323 | # Combine 
clips with a crossfade based on overlap 324 | combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s) 325 | 326 | st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)") 327 | 328 | input_name = Path(audio_file.name).stem 329 | output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}" 330 | streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension) 331 | 332 | 333 | def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]: 334 | """ 335 | Render the parameters of slicing audio into clips. 336 | """ 337 | p: T.Dict[str, T.Any] = {} 338 | 339 | cols = st.columns(4) 340 | 341 | p["start_time_s"] = cols[0].number_input( 342 | "Start Time [s]", 343 | min_value=0.0, 344 | value=0.0, 345 | ) 346 | p["duration_s"] = cols[1].number_input( 347 | "Duration [s]", 348 | min_value=0.0, 349 | value=20.0, 350 | ) 351 | 352 | if advanced: 353 | p["clip_duration_s"] = cols[2].number_input( 354 | "Clip Duration [s]", 355 | min_value=3.0, 356 | max_value=10.0, 357 | value=5.0, 358 | ) 359 | else: 360 | p["clip_duration_s"] = 5.0 361 | 362 | if advanced: 363 | p["overlap_duration_s"] = cols[3].number_input( 364 | "Overlap Duration [s]", 365 | min_value=0.0, 366 | max_value=10.0, 367 | value=0.2, 368 | ) 369 | else: 370 | p["overlap_duration_s"] = 0.2 371 | 372 | return p 373 | 374 | 375 | def write_clip_details( 376 | clip_start_times: np.ndarray, clip_duration_s: float, overlap_duration_s: float 377 | ): 378 | """ 379 | Write details of the clips to be sliced from an audio segment. 380 | """ 381 | clip_details_text = ( 382 | f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s " 383 | f"with overlap {overlap_duration_s}s" 384 | ) 385 | 386 | with st.expander(clip_details_text): 387 | st.dataframe( 388 | { 389 | "Start Time [s]": clip_start_times, 390 | "End Time [s]": clip_start_times + clip_duration_s, 391 | "Duration [s]": clip_duration_s, 392 | } 393 | ) 394 | 395 | 396 | def slice_audio_into_clips( 397 | segment: pydub.AudioSegment, clip_start_times: T.Sequence[float], clip_duration_s: float 398 | ) -> T.List[pydub.AudioSegment]: 399 | """ 400 | Slice an audio segment into a list of clips of a given duration at the given start times. 401 | """ 402 | clip_segments: T.List[pydub.AudioSegment] = [] 403 | for i, clip_start_time_s in enumerate(clip_start_times): 404 | clip_start_time_ms = int(clip_start_time_s * 1000) 405 | clip_duration_ms = int(clip_duration_s * 1000) 406 | clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms] 407 | 408 | # TODO(hayk): I don't think this is working properly 409 | if i == len(clip_start_times) - 1: 410 | silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000) 411 | if silence_ms > 0: 412 | clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms)) 413 | 414 | clip_segments.append(clip_segment) 415 | 416 | return clip_segments 417 | 418 | 419 | def scale_image_to_32_stride(image: Image.Image) -> Image.Image: 420 | """ 421 | Scale an image to a size that is a multiple of 32. 
422 | """ 423 | closest_width = int(np.ceil(image.width / 32) * 32) 424 | closest_height = int(np.ceil(image.height / 32) * 32) 425 | return image.resize((closest_width, closest_height), Image.BICUBIC) 426 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/home.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | 4 | def render(): 5 | st.title("✨🎸 Riffusion Playground 🎸✨") 6 | 7 | st.write("Select a task from the sidebar to get started!") 8 | 9 | left, right = st.columns(2) 10 | 11 | with left: 12 | st.subheader("🌊 Text to Audio") 13 | st.write("Generate audio clips from text prompts.") 14 | 15 | st.subheader("✨ Audio to Audio") 16 | st.write("Upload audio and modify with text prompt (interpolation supported).") 17 | 18 | st.subheader("🎭 Interpolation") 19 | st.write("Interpolate between prompts in the latent space.") 20 | 21 | st.subheader("✂️ Audio Splitter") 22 | st.write("Split audio into stems like vocals, bass, drums, guitar, etc.") 23 | 24 | with right: 25 | st.subheader("📜 Text to Audio Batch") 26 | st.write("Generate audio in batch from a JSON file of text prompts.") 27 | 28 | st.subheader("📎 Sample Clips") 29 | st.write("Export short clips from an audio file.") 30 | 31 | st.subheader("⏈ Spectrogram to Audio") 32 | st.write("Reconstruct audio from spectrogram images.") 33 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/image_to_audio.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from pathlib import Path 3 | 4 | import streamlit as st 5 | from PIL import Image 6 | 7 | from riffusion.spectrogram_params import SpectrogramParams 8 | from riffusion.streamlit import util as streamlit_util 9 | from riffusion.util.image_util import exif_from_image 10 | 11 | 12 | def render() -> None: 13 | st.subheader("⏈ Image to Audio") 14 | st.write( 15 | """ 16 | Reconstruct audio from spectrogram images. 17 | """ 18 | ) 19 | 20 | with st.expander("Help", False): 21 | st.write( 22 | """ 23 | This tool takes an existing spectrogram image and reconstructs it into an audio 24 | waveform. It also displays the EXIF metadata stored inside the image, which can 25 | contain the parameters used to create the spectrogram image. If no EXIF is contained, 26 | assumes default parameters. 27 | """ 28 | ) 29 | 30 | device = streamlit_util.select_device(st.sidebar) 31 | extension = streamlit_util.select_audio_extension(st.sidebar) 32 | 33 | use_20k = st.sidebar.checkbox("Use 20kHz", value=False) 34 | 35 | image_file = st.file_uploader( 36 | "Upload a file", 37 | type=streamlit_util.IMAGE_EXTENSIONS, 38 | label_visibility="collapsed", 39 | ) 40 | if not image_file: 41 | st.info("Upload an image file to get started") 42 | return 43 | 44 | image = Image.open(image_file) 45 | st.image(image) 46 | 47 | with st.expander("Image metadata", expanded=False): 48 | exif = exif_from_image(image) 49 | st.json(exif) 50 | 51 | try: 52 | params = SpectrogramParams.from_exif(exif=image.getexif()) 53 | except KeyError: 54 | st.info("Could not find spectrogram parameters in exif data. 
Using defaults.") 55 | if use_20k: 56 | params = SpectrogramParams( 57 | min_frequency=10, 58 | max_frequency=20000, 59 | stereo=True, 60 | ) 61 | else: 62 | params = SpectrogramParams() 63 | 64 | with st.expander("Spectrogram Parameters", expanded=False): 65 | st.json(dataclasses.asdict(params)) 66 | 67 | segment = streamlit_util.audio_segment_from_spectrogram_image( 68 | image=image.copy(), 69 | params=params, 70 | device=device, 71 | ) 72 | 73 | streamlit_util.display_and_download_audio( 74 | segment, 75 | name=Path(image_file.name).stem, 76 | extension=extension, 77 | ) 78 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/interpolation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import io 3 | import typing as T 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import pydub 8 | import streamlit as st 9 | from PIL import Image 10 | 11 | from riffusion.datatypes import InferenceInput, PromptInput 12 | from riffusion.spectrogram_params import SpectrogramParams 13 | from riffusion.streamlit import util as streamlit_util 14 | 15 | 16 | def render() -> None: 17 | st.subheader("🎭 Interpolation") 18 | st.write( 19 | """ 20 | Interpolate between prompts in the latent space. 21 | """ 22 | ) 23 | 24 | with st.expander("Help", False): 25 | st.write( 26 | """ 27 | This tool allows specifying two endpoints and generating a long-form interpolation 28 | between them that traverses the latent space. The interpolation is generated by 29 | the method described at https://www.riffusion.com/about. A seed image is used to 30 | set the beat and tempo of the generated audio, and can be set in the sidebar. 31 | Usually the seed is changed or the prompt, but not both at once. You can browse 32 | infinite variations of the same prompt by changing the seed. 33 | 34 | For example, try going from "church bells" to "jazz" with 10 steps and 0.75 denoising. 35 | This will generate a 50 second clip at 5 seconds per step. Then play with the seeds 36 | or denoising to get different variations. 37 | """ 38 | ) 39 | 40 | # Sidebar params 41 | 42 | device = streamlit_util.select_device(st.sidebar) 43 | extension = streamlit_util.select_audio_extension(st.sidebar) 44 | 45 | num_interpolation_steps = T.cast( 46 | int, 47 | st.sidebar.number_input( 48 | "Interpolation steps", 49 | value=12, 50 | min_value=1, 51 | max_value=20, 52 | help="Number of model generations between the two prompts. Controls the duration.", 53 | ), 54 | ) 55 | 56 | num_inference_steps = T.cast( 57 | int, 58 | st.sidebar.number_input( 59 | "Steps per sample", value=50, help="Number of denoising steps per model run" 60 | ), 61 | ) 62 | 63 | guidance = st.sidebar.number_input( 64 | "Guidance", 65 | value=7.0, 66 | help="How much the model listens to the text prompt", 67 | ) 68 | 69 | init_image_name = st.sidebar.selectbox( 70 | "Seed image", 71 | # TODO(hayk): Read from directory 72 | options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"], 73 | index=0, 74 | help="Which seed image to use for img2img. 
Custom allows uploading your own.", 75 | ) 76 | assert init_image_name is not None 77 | if init_image_name == "custom": 78 | init_image_file = st.sidebar.file_uploader( 79 | "Upload a custom seed image", 80 | type=streamlit_util.IMAGE_EXTENSIONS, 81 | label_visibility="collapsed", 82 | ) 83 | if init_image_file: 84 | st.sidebar.image(init_image_file) 85 | 86 | alpha_power = st.sidebar.number_input("Alpha Power", value=1.0) 87 | 88 | show_individual_outputs = st.sidebar.checkbox( 89 | "Show individual outputs", 90 | value=False, 91 | help="Show each model output", 92 | ) 93 | show_images = st.sidebar.checkbox( 94 | "Show individual images", 95 | value=False, 96 | help="Show each generated image", 97 | ) 98 | 99 | alphas = np.linspace(0, 1, num_interpolation_steps) 100 | 101 | # Apply power scaling to alphas to customize the interpolation curve 102 | alphas_shifted = alphas * 2 - 1 103 | alphas_shifted = (np.abs(alphas_shifted) ** alpha_power * np.sign(alphas_shifted) + 1) / 2 104 | alphas = alphas_shifted 105 | 106 | alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas]) 107 | st.write(f"**Alphas** : [{alphas_str}]") 108 | 109 | # Prompt inputs A and B in two columns 110 | 111 | with st.form(key="interpolation_form"): 112 | left, right = st.columns(2) 113 | 114 | with left: 115 | st.write("##### Prompt A") 116 | prompt_input_a = PromptInput( 117 | guidance=guidance, **get_prompt_inputs(key="a", denoising_default=0.75) 118 | ) 119 | 120 | with right: 121 | st.write("##### Prompt B") 122 | prompt_input_b = PromptInput( 123 | guidance=guidance, **get_prompt_inputs(key="b", denoising_default=0.75) 124 | ) 125 | 126 | st.form_submit_button("Generate", type="primary") 127 | 128 | if not prompt_input_a.prompt or not prompt_input_b.prompt: 129 | st.info("Enter both prompts to interpolate between them") 130 | return 131 | 132 | if init_image_name == "custom": 133 | if not init_image_file: 134 | st.info("Upload a custom seed image") 135 | return 136 | init_image = Image.open(init_image_file).convert("RGB") 137 | else: 138 | init_image_path = ( 139 | Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png" 140 | ) 141 | init_image = Image.open(str(init_image_path)).convert("RGB") 142 | 143 | # TODO(hayk): Move this code into a shared place and add to riffusion.cli 144 | image_list: T.List[Image.Image] = [] 145 | audio_bytes_list: T.List[io.BytesIO] = [] 146 | for i, alpha in enumerate(alphas): 147 | inputs = InferenceInput( 148 | alpha=float(alpha), 149 | num_inference_steps=num_inference_steps, 150 | seed_image_id="og_beat", 151 | start=prompt_input_a, 152 | end=prompt_input_b, 153 | ) 154 | 155 | if i == 0: 156 | with st.expander("Example input JSON", expanded=False): 157 | st.json(dataclasses.asdict(inputs)) 158 | 159 | image, audio_bytes = run_interpolation( 160 | inputs=inputs, 161 | init_image=init_image, 162 | device=device, 163 | extension=extension, 164 | ) 165 | 166 | if show_individual_outputs: 167 | st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}") 168 | if show_images: 169 | st.image(image) 170 | st.audio(audio_bytes) 171 | 172 | image_list.append(image) 173 | audio_bytes_list.append(audio_bytes) 174 | 175 | st.write("#### Final Output") 176 | 177 | # TODO(hayk): Concatenate with overlap and better blending like in audio to audio 178 | audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list] 179 | concat_segment = audio_segments[0] 180 | for segment in audio_segments[1:]: 181 | concat_segment = 
concat_segment.append(segment, crossfade=0) 182 | 183 | audio_bytes = io.BytesIO() 184 | concat_segment.export(audio_bytes, format=extension) 185 | audio_bytes.seek(0) 186 | 187 | st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds") 188 | st.audio(audio_bytes) 189 | 190 | output_name = ( 191 | f"{prompt_input_a.prompt.replace(' ', '_')}_" 192 | f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}" 193 | ) 194 | st.download_button( 195 | output_name, 196 | data=audio_bytes, 197 | file_name=output_name, 198 | mime=f"audio/{extension}", 199 | ) 200 | 201 | 202 | def get_prompt_inputs( 203 | key: str, 204 | include_negative_prompt: bool = False, 205 | cols: bool = False, 206 | denoising_default: float = 0.5, 207 | ) -> T.Dict[str, T.Any]: 208 | """ 209 | Compute prompt inputs from widgets. 210 | """ 211 | p: T.Dict[str, T.Any] = {} 212 | 213 | # Optionally use columns 214 | left, right = T.cast(T.Any, st.columns(2) if cols else (st, st)) 215 | 216 | visibility = "visible" if cols else "collapsed" 217 | p["prompt"] = left.text_input("Prompt", label_visibility=visibility, key=f"prompt_{key}") 218 | 219 | if include_negative_prompt: 220 | p["negative_prompt"] = right.text_input("Negative Prompt", key=f"negative_prompt_{key}") 221 | 222 | p["seed"] = T.cast( 223 | int, 224 | left.number_input( 225 | "Seed", 226 | value=42, 227 | key=f"seed_{key}", 228 | help="Integer used to generate a random result. Vary this to explore alternatives.", 229 | ), 230 | ) 231 | 232 | p["denoising"] = right.number_input( 233 | "Denoising", 234 | value=denoising_default, 235 | key=f"denoising_{key}", 236 | help="How much to modify the seed image", 237 | ) 238 | 239 | return p 240 | 241 | 242 | @st.cache_data 243 | def run_interpolation( 244 | inputs: InferenceInput, 245 | init_image: Image.Image, 246 | checkpoint: str = streamlit_util.DEFAULT_CHECKPOINT, 247 | device: str = "cuda", 248 | extension: str = "mp3", 249 | ) -> T.Tuple[Image.Image, io.BytesIO]: 250 | """ 251 | Cached function for riffusion interpolation. 252 | """ 253 | pipeline = streamlit_util.load_riffusion_checkpoint( 254 | device=device, 255 | checkpoint=checkpoint, 256 | # No trace so we can have variable width 257 | no_traced_unet=True, 258 | ) 259 | 260 | image = pipeline.riffuse( 261 | inputs, 262 | init_image=init_image, 263 | mask_image=None, 264 | ) 265 | 266 | # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained 267 | params = SpectrogramParams( 268 | min_frequency=0, 269 | max_frequency=10000, 270 | ) 271 | 272 | # Reconstruct from image to audio 273 | audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image( 274 | image=image, 275 | params=params, 276 | device=device, 277 | output_format=extension, 278 | ) 279 | 280 | return image, audio_bytes 281 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/sample_clips.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import typing as T 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pydub 7 | import streamlit as st 8 | 9 | from riffusion.spectrogram_params import SpectrogramParams 10 | from riffusion.streamlit import util as streamlit_util 11 | 12 | 13 | def render() -> None: 14 | st.subheader("📎 Sample Clips") 15 | st.write( 16 | """ 17 | Export short clips from an audio file. 
18 | """ 19 | ) 20 | 21 | with st.expander("Help", False): 22 | st.write( 23 | """ 24 | This tool simply allows uploading an audio file and randomly sampling short clips 25 | from it. It's useful for generating a large number of short clips from a single 26 | audio file. Outputs can be saved to a given directory with a given audio extension. 27 | """ 28 | ) 29 | 30 | audio_file = st.file_uploader( 31 | "Upload a file", 32 | type=streamlit_util.AUDIO_EXTENSIONS, 33 | label_visibility="collapsed", 34 | ) 35 | if not audio_file: 36 | st.info("Upload an audio file to get started") 37 | return 38 | 39 | st.audio(audio_file) 40 | 41 | segment = pydub.AudioSegment.from_file(audio_file) 42 | st.write( 43 | " \n".join( 44 | [ 45 | f"**Duration**: {segment.duration_seconds:.3f} seconds", 46 | f"**Channels**: {segment.channels}", 47 | f"**Sample rate**: {segment.frame_rate} Hz", 48 | f"**Sample width**: {segment.sample_width} bytes", 49 | ] 50 | ) 51 | ) 52 | 53 | device = streamlit_util.select_device(st.sidebar) 54 | extension = streamlit_util.select_audio_extension(st.sidebar) 55 | save_to_disk = st.sidebar.checkbox("Save to Disk", False) 56 | export_as_mono = st.sidebar.checkbox("Export as Mono", False) 57 | compute_spectrograms = st.sidebar.checkbox("Compute Spectrograms", False) 58 | 59 | row = st.columns(4) 60 | num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3)) 61 | duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000)) 62 | seed = T.cast(int, row[2].number_input("Seed", value=42)) 63 | 64 | counter = streamlit_util.StreamlitCounter() 65 | st.button("Sample Clips", type="primary", on_click=counter.increment) 66 | if counter.value == 0: 67 | return 68 | 69 | # Optionally pick an output directory 70 | if save_to_disk: 71 | output_dir = tempfile.mkdtemp(prefix="sample_clips_") 72 | output_path = Path(output_dir) 73 | output_path.mkdir(parents=True, exist_ok=True) 74 | st.info(f"Output directory: `{output_dir}`") 75 | 76 | if compute_spectrograms: 77 | images_dir = output_path / "images" 78 | images_dir.mkdir(parents=True, exist_ok=True) 79 | 80 | if seed >= 0: 81 | np.random.seed(seed) 82 | 83 | if export_as_mono and segment.channels > 1: 84 | segment = segment.set_channels(1) 85 | 86 | if save_to_disk: 87 | st.info(f"Writing {num_clips} clip(s) to `{str(output_path)}`") 88 | 89 | # TODO(hayk): Share code with riffusion.cli.sample_clips. 
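# Draw a random start offset for each clip (uniform over the valid range), slice out a
# clip of duration_ms, display it with a download button, and optionally write the clip
# (and its spectrogram image) to the output directory.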
90 | segment_duration_ms = int(segment.duration_seconds * 1000) 91 | for i in range(num_clips): 92 | clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms) 93 | clip = segment[clip_start_ms : clip_start_ms + duration_ms] 94 | 95 | clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms" 96 | 97 | st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`") 98 | 99 | streamlit_util.display_and_download_audio( 100 | clip, 101 | name=clip_name, 102 | extension=extension, 103 | ) 104 | 105 | if save_to_disk: 106 | clip_path = output_path / f"{clip_name}.{extension}" 107 | clip.export(clip_path, format=extension) 108 | 109 | if compute_spectrograms: 110 | params = SpectrogramParams() 111 | 112 | image = streamlit_util.spectrogram_image_from_audio( 113 | clip, 114 | params=params, 115 | device=device, 116 | ) 117 | 118 | st.image(image) 119 | 120 | if save_to_disk: 121 | image_path = images_dir / f"{clip_name}.jpeg" 122 | image.save(image_path) 123 | 124 | if save_to_disk: 125 | st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`") 126 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/split_audio.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | from pathlib import Path 3 | 4 | import pydub 5 | import streamlit as st 6 | 7 | from riffusion.audio_splitter import split_audio 8 | from riffusion.streamlit import util as streamlit_util 9 | from riffusion.util import audio_util 10 | 11 | 12 | def render() -> None: 13 | st.subheader("✂️ Audio Splitter") 14 | st.write( 15 | """ 16 | Split audio into individual instrument stems. 17 | """ 18 | ) 19 | 20 | with st.expander("Help", False): 21 | st.write( 22 | """ 23 | This tool allows uploading an audio file of arbitrary length and splits it into 24 | stems of vocals, drums, bass, and other. It does this using a deep network that 25 | sweeps over the audio in clips, extracts the stems, and then cross fades the clips 26 | back together to construct the full length stems. It's particularly useful in 27 | combination with audio_to_audio, for example to split and preserve vocals while 28 | modifying the rest of the track with a prompt. Or, to pull out drums to add later 29 | in a DAW. 
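The stems produced are vocals, drums, bass, guitar, piano, and other, and any subset can be recombined from the sidebar.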
30 | """ 31 | ) 32 | 33 | device = streamlit_util.select_device(st.sidebar) 34 | 35 | extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"] 36 | extension = st.sidebar.selectbox( 37 | "Output format", 38 | options=extension_options, 39 | index=extension_options.index("mp3"), 40 | ) 41 | assert extension is not None 42 | 43 | audio_file = st.file_uploader( 44 | "Upload audio", 45 | type=extension_options, 46 | label_visibility="collapsed", 47 | ) 48 | 49 | stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"] 50 | recombine = st.sidebar.multiselect( 51 | "Recombine", 52 | options=stem_options, 53 | default=[], 54 | help="Recombine these stems at the end", 55 | ) 56 | 57 | if not audio_file: 58 | st.info("Upload audio to get started") 59 | return 60 | 61 | st.write("#### Original") 62 | st.audio(audio_file) 63 | 64 | counter = streamlit_util.StreamlitCounter() 65 | st.button("Split", type="primary", on_click=counter.increment) 66 | if counter.value == 0: 67 | return 68 | 69 | segment = streamlit_util.load_audio_file(audio_file) 70 | 71 | # Split 72 | stems = split_audio_cached(segment, device=device) 73 | 74 | input_name = Path(audio_file.name).stem 75 | 76 | # Display each 77 | for name in stem_options: 78 | stem = stems[name.lower()] 79 | st.write(f"#### Stem: {name}") 80 | 81 | output_name = f"{input_name}_{name.lower()}" 82 | streamlit_util.display_and_download_audio(stem, output_name, extension=extension) 83 | 84 | if recombine: 85 | recombine_lower = [r.lower() for r in recombine] 86 | segments = [s for name, s in stems.items() if name in recombine_lower] 87 | recombined = audio_util.overlay_segments(segments) 88 | 89 | # Display 90 | st.write(f"#### Recombined: {', '.join(recombine)}") 91 | output_name = f"{input_name}_{'_'.join(recombine_lower)}" 92 | streamlit_util.display_and_download_audio(recombined, output_name, extension=extension) 93 | 94 | 95 | @st.cache 96 | def split_audio_cached( 97 | segment: pydub.AudioSegment, device: str = "cuda" 98 | ) -> T.Dict[str, pydub.AudioSegment]: 99 | return split_audio(segment, device=device) 100 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/text_to_audio.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | import streamlit as st 4 | 5 | from riffusion.spectrogram_params import SpectrogramParams 6 | from riffusion.streamlit import util as streamlit_util 7 | 8 | 9 | def render() -> None: 10 | st.subheader("🌊 Text to Audio") 11 | st.write( 12 | """ 13 | Generate audio from text prompts. 14 | """ 15 | ) 16 | 17 | with st.expander("Help", False): 18 | st.write( 19 | """ 20 | This tool runs riffusion in the simplest text to image form to generate an audio 21 | clip from a text prompt. There is no seed image or interpolation here. This mode 22 | allows more diversity and creativity than when using a seed image, but it also 23 | leads to having less control. Play with the seed to get infinite variations. 
24 | """ 25 | ) 26 | 27 | device = streamlit_util.select_device(st.sidebar) 28 | extension = streamlit_util.select_audio_extension(st.sidebar) 29 | checkpoint = streamlit_util.select_checkpoint(st.sidebar) 30 | 31 | with st.form("Inputs"): 32 | prompt = st.text_input("Prompt") 33 | negative_prompt = st.text_input("Negative prompt") 34 | 35 | row = st.columns(4) 36 | num_clips = T.cast( 37 | int, 38 | row[0].number_input( 39 | "Number of clips", 40 | value=1, 41 | min_value=1, 42 | max_value=25, 43 | help="How many outputs to generate (seed gets incremented)", 44 | ), 45 | ) 46 | starting_seed = T.cast( 47 | int, 48 | row[1].number_input( 49 | "Seed", 50 | value=42, 51 | help="Change this to generate different variations", 52 | ), 53 | ) 54 | 55 | st.form_submit_button("Riff", type="primary") 56 | 57 | with st.sidebar: 58 | num_inference_steps = T.cast(int, st.number_input("Inference steps", value=30)) 59 | width = T.cast(int, st.number_input("Width", value=512)) 60 | guidance = st.number_input( 61 | "Guidance", value=7.0, help="How much the model listens to the text prompt" 62 | ) 63 | scheduler = st.selectbox( 64 | "Scheduler", 65 | options=streamlit_util.SCHEDULER_OPTIONS, 66 | index=0, 67 | help="Which diffusion scheduler to use", 68 | ) 69 | assert scheduler is not None 70 | 71 | use_20k = st.checkbox("Use 20kHz", value=False) 72 | 73 | if not prompt: 74 | st.info("Enter a prompt") 75 | return 76 | 77 | if use_20k: 78 | params = SpectrogramParams( 79 | min_frequency=10, 80 | max_frequency=20000, 81 | sample_rate=44100, 82 | stereo=True, 83 | ) 84 | else: 85 | params = SpectrogramParams( 86 | min_frequency=0, 87 | max_frequency=10000, 88 | stereo=False, 89 | ) 90 | 91 | seed = starting_seed 92 | for i in range(1, num_clips + 1): 93 | st.write(f"#### Riff {i} / {num_clips} - Seed {seed}") 94 | 95 | image = streamlit_util.run_txt2img( 96 | prompt=prompt, 97 | num_inference_steps=num_inference_steps, 98 | guidance=guidance, 99 | negative_prompt=negative_prompt, 100 | seed=seed, 101 | width=width, 102 | height=512, 103 | checkpoint=checkpoint, 104 | device=device, 105 | scheduler=scheduler, 106 | ) 107 | st.image(image) 108 | 109 | segment = streamlit_util.audio_segment_from_spectrogram_image( 110 | image=image, 111 | params=params, 112 | device=device, 113 | ) 114 | 115 | streamlit_util.display_and_download_audio( 116 | segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension 117 | ) 118 | 119 | seed += 1 120 | -------------------------------------------------------------------------------- /riffusion/streamlit/tasks/text_to_audio_batch.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing as T 3 | from pathlib import Path 4 | 5 | import streamlit as st 6 | 7 | from riffusion.spectrogram_params import SpectrogramParams 8 | from riffusion.streamlit import util as streamlit_util 9 | 10 | # Example input json file to process in batch 11 | EXAMPLE_INPUT = """ 12 | { 13 | "params": { 14 | "checkpoint": "riffusion/riffusion-model-v1", 15 | "scheduler": "DPMSolverMultistepScheduler", 16 | "num_inference_steps": 50, 17 | "guidance": 7.0, 18 | "width": 512, 19 | }, 20 | "entries": [ 21 | { 22 | "prompt": "Church bells", 23 | "seed": 42 24 | }, 25 | { 26 | "prompt": "electronic beats", 27 | "negative_prompt": "drums", 28 | "seed": 100 29 | }, 30 | { 31 | "prompt": "classical violin concerto", 32 | "seed": 4 33 | } 34 | ] 35 | } 36 | """ 37 | 38 | 39 | def render() -> None: 40 | st.subheader("📜 Text to Audio Batch") 41 | 
st.write( 42 | """ 43 | Generate audio in batch from a JSON file of text prompts. 44 | """ 45 | ) 46 | 47 | with st.expander("Help", False): 48 | st.write( 49 | """ 50 | This tool is a batch form of text_to_audio, where the inputs are read in from a JSON 51 | file. The input file contains a global params block and a list of entries with positive 52 | and negative prompts. It's useful for automating a larger set of generations. See the 53 | example inputs below for the format of the file. 54 | """ 55 | ) 56 | 57 | device = streamlit_util.select_device(st.sidebar) 58 | 59 | # Upload a JSON file 60 | json_file = st.file_uploader( 61 | "JSON file", 62 | type=["json"], 63 | label_visibility="collapsed", 64 | ) 65 | 66 | # Handle the null case 67 | if json_file is None: 68 | st.info("Upload a JSON file containing params and prompts") 69 | with st.expander("Example inputs.json", expanded=False): 70 | st.code(EXAMPLE_INPUT) 71 | return 72 | 73 | # Read in and print it 74 | data = json.loads(json_file.read()) 75 | with st.expander("Input Data", expanded=False): 76 | st.json(data) 77 | 78 | # Params can either be a list or a single entry 79 | if isinstance(data["params"], list): 80 | param_sets = data["params"] 81 | else: 82 | param_sets = [data["params"]] 83 | 84 | entries = data["entries"] 85 | 86 | show_images = st.sidebar.checkbox("Show Images", True) 87 | num_seeds = st.sidebar.number_input( 88 | "Num Seeds", 89 | value=1, 90 | min_value=1, 91 | max_value=10, 92 | help="When > 1, increments the seed and runs multiple for each entry", 93 | ) 94 | 95 | # Optionally specify an output directory 96 | output_dir = st.sidebar.text_input("Output Directory", "") 97 | output_path: T.Optional[Path] = None 98 | if output_dir: 99 | output_path = Path(output_dir) 100 | output_path.mkdir(parents=True, exist_ok=True) 101 | 102 | # Write title cards for each param set 103 | title_cols = st.columns(len(param_sets)) 104 | for i, params in enumerate(param_sets): 105 | col = title_cols[i] 106 | 107 | if "name" not in params: 108 | params["name"] = f"params[{i}]" 109 | 110 | col.write(f"## Param Set {i}") 111 | col.json(params) 112 | 113 | for entry_i, entry in enumerate(entries): 114 | st.write("---") 115 | print(entry) 116 | prompt = entry["prompt"] 117 | negative_prompt = entry.get("negative_prompt", None) 118 | 119 | base_seed = entry.get("seed", 42) 120 | 121 | text = f"##### ({base_seed}) {prompt}" 122 | if negative_prompt: 123 | text += f" \n**Negative**: {negative_prompt}" 124 | st.write(text) 125 | 126 | for seed in range(base_seed, base_seed + num_seeds): 127 | cols = st.columns(len(param_sets)) 128 | for i, params in enumerate(param_sets): 129 | col = cols[i] 130 | col.write(params["name"]) 131 | 132 | image = streamlit_util.run_txt2img( 133 | prompt=prompt, 134 | negative_prompt=negative_prompt, 135 | seed=seed, 136 | num_inference_steps=params.get("num_inference_steps", 50), 137 | guidance=params.get("guidance", 7.0), 138 | width=params.get("width", 512), 139 | checkpoint=params.get("checkpoint", streamlit_util.DEFAULT_CHECKPOINT), 140 | scheduler=params.get("scheduler", streamlit_util.SCHEDULER_OPTIONS[0]), 141 | height=512, 142 | device=device, 143 | ) 144 | 145 | if show_images: 146 | col.image(image) 147 | 148 | # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained 149 | p_spectrogram = SpectrogramParams( 150 | min_frequency=0, 151 | max_frequency=10000, 152 | ) 153 | 154 | output_format = "mp3" 155 | audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image( 156 | 
image=image, 157 | params=p_spectrogram, 158 | device=device, 159 | output_format=output_format, 160 | ) 161 | col.audio(audio_bytes) 162 | 163 | if output_path: 164 | prompt_slug = entry["prompt"].replace(" ", "_") 165 | negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_") 166 | 167 | image_path = ( 168 | output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg" 169 | ) 170 | image.save(image_path, format="JPEG") 171 | entry["image_path"] = str(image_path) 172 | 173 | audio_path = ( 174 | output_path 175 | / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}" 176 | ) 177 | audio_path.write_bytes(audio_bytes.getbuffer()) 178 | entry["audio_path"] = str(audio_path) 179 | 180 | if output_path: 181 | output_json_path = output_path / "index.json" 182 | output_json_path.write_text(json.dumps(data, indent=4)) 183 | st.info(f"Output written to {str(output_path)}") 184 | else: 185 | st.info("Enter output directory in sidebar to save to disk") 186 | -------------------------------------------------------------------------------- /riffusion/streamlit/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Streamlit utilities (mostly cached wrappers around riffusion code). 3 | """ 4 | import io 5 | import threading 6 | import typing as T 7 | 8 | import pydub 9 | import streamlit as st 10 | import torch 11 | from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline 12 | from PIL import Image 13 | 14 | from riffusion.audio_splitter import AudioSplitter 15 | from riffusion.riffusion_pipeline import RiffusionPipeline 16 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter 17 | from riffusion.spectrogram_params import SpectrogramParams 18 | 19 | # TODO(hayk): Add URL params 20 | 21 | DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1" 22 | 23 | AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"] 24 | IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"] 25 | 26 | SCHEDULER_OPTIONS = [ 27 | "DPMSolverMultistepScheduler", 28 | "PNDMScheduler", 29 | "DDIMScheduler", 30 | "LMSDiscreteScheduler", 31 | "EulerDiscreteScheduler", 32 | "EulerAncestralDiscreteScheduler", 33 | ] 34 | 35 | 36 | @st.cache_resource 37 | def load_riffusion_checkpoint( 38 | checkpoint: str = DEFAULT_CHECKPOINT, 39 | no_traced_unet: bool = False, 40 | device: str = "cuda", 41 | ) -> RiffusionPipeline: 42 | """ 43 | Load the riffusion pipeline. 44 | """ 45 | return RiffusionPipeline.load_checkpoint( 46 | checkpoint=checkpoint, 47 | use_traced_unet=not no_traced_unet, 48 | device=device, 49 | ) 50 | 51 | 52 | @st.cache_resource 53 | def load_stable_diffusion_pipeline( 54 | checkpoint: str = DEFAULT_CHECKPOINT, 55 | device: str = "cuda", 56 | dtype: torch.dtype = torch.float16, 57 | scheduler: str = SCHEDULER_OPTIONS[0], 58 | ) -> StableDiffusionPipeline: 59 | """ 60 | Load the riffusion pipeline. 61 | 62 | TODO(hayk): Merge this into RiffusionPipeline to just load one model. 
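Falls back to float32 on CPU and MPS devices (float16 is unsupported there), bypasses the safety checker, and installs the requested scheduler on the loaded pipeline.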
63 | """ 64 | if device == "cpu" or device.lower().startswith("mps"): 65 | print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported") 66 | dtype = torch.float32 67 | 68 | pipeline = StableDiffusionPipeline.from_pretrained( 69 | checkpoint, 70 | revision="main", 71 | torch_dtype=dtype, 72 | safety_checker=lambda images, **kwargs: (images, False), 73 | ).to(device) 74 | 75 | pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config) 76 | 77 | return pipeline 78 | 79 | 80 | def get_scheduler(scheduler: str, config: T.Any) -> T.Any: 81 | """ 82 | Construct a denoising scheduler from a string. 83 | """ 84 | if scheduler == "PNDMScheduler": 85 | from diffusers import PNDMScheduler 86 | 87 | return PNDMScheduler.from_config(config) 88 | elif scheduler == "DPMSolverMultistepScheduler": 89 | from diffusers import DPMSolverMultistepScheduler 90 | 91 | return DPMSolverMultistepScheduler.from_config(config) 92 | elif scheduler == "DDIMScheduler": 93 | from diffusers import DDIMScheduler 94 | 95 | return DDIMScheduler.from_config(config) 96 | elif scheduler == "LMSDiscreteScheduler": 97 | from diffusers import LMSDiscreteScheduler 98 | 99 | return LMSDiscreteScheduler.from_config(config) 100 | elif scheduler == "EulerDiscreteScheduler": 101 | from diffusers import EulerDiscreteScheduler 102 | 103 | return EulerDiscreteScheduler.from_config(config) 104 | elif scheduler == "EulerAncestralDiscreteScheduler": 105 | from diffusers import EulerAncestralDiscreteScheduler 106 | 107 | return EulerAncestralDiscreteScheduler.from_config(config) 108 | else: 109 | raise ValueError(f"Unknown scheduler {scheduler}") 110 | 111 | 112 | @st.cache_resource 113 | def pipeline_lock() -> threading.Lock: 114 | """ 115 | Singleton lock used to prevent concurrent access to any model pipeline. 116 | """ 117 | return threading.Lock() 118 | 119 | 120 | @st.cache_resource 121 | def load_stable_diffusion_img2img_pipeline( 122 | checkpoint: str = DEFAULT_CHECKPOINT, 123 | device: str = "cuda", 124 | dtype: torch.dtype = torch.float16, 125 | scheduler: str = SCHEDULER_OPTIONS[0], 126 | ) -> StableDiffusionImg2ImgPipeline: 127 | """ 128 | Load the image to image pipeline. 129 | 130 | TODO(hayk): Merge this into RiffusionPipeline to just load one model. 131 | """ 132 | if device == "cpu" or device.lower().startswith("mps"): 133 | print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported") 134 | dtype = torch.float32 135 | 136 | pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( 137 | checkpoint, 138 | revision="main", 139 | torch_dtype=dtype, 140 | safety_checker=lambda images, **kwargs: (images, False), 141 | ).to(device) 142 | 143 | pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config) 144 | 145 | return pipeline 146 | 147 | 148 | @st.cache_data(persist=True) 149 | def run_txt2img( 150 | prompt: str, 151 | num_inference_steps: int, 152 | guidance: float, 153 | negative_prompt: str, 154 | seed: int, 155 | width: int, 156 | height: int, 157 | checkpoint: str = DEFAULT_CHECKPOINT, 158 | device: str = "cuda", 159 | scheduler: str = SCHEDULER_OPTIONS[0], 160 | ) -> Image.Image: 161 | """ 162 | Run the text to image pipeline with caching. 
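Acquires the global pipeline lock so only one generation runs at a time, and seeds a
torch.Generator for reproducible outputs (the generator is placed on CPU when the device
is MPS). Results persist across sessions via st.cache_data(persist=True). Example call,
illustrative only (the prompt text and values are arbitrary):

    image = run_txt2img(prompt="acoustic folk fiddle solo", num_inference_steps=30,
                        guidance=7.0, negative_prompt="", seed=42, width=512, height=512,
                        device="cuda")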
163 | """ 164 | with pipeline_lock(): 165 | pipeline = load_stable_diffusion_pipeline( 166 | checkpoint=checkpoint, 167 | device=device, 168 | scheduler=scheduler, 169 | ) 170 | 171 | generator_device = "cpu" if device.lower().startswith("mps") else device 172 | generator = torch.Generator(device=generator_device).manual_seed(seed) 173 | 174 | output = pipeline( 175 | prompt=prompt, 176 | num_inference_steps=num_inference_steps, 177 | guidance_scale=guidance, 178 | negative_prompt=negative_prompt or None, 179 | generator=generator, 180 | width=width, 181 | height=height, 182 | ) 183 | 184 | return output["images"][0] 185 | 186 | 187 | @st.cache_resource 188 | def spectrogram_image_converter( 189 | params: SpectrogramParams, 190 | device: str = "cuda", 191 | ) -> SpectrogramImageConverter: 192 | return SpectrogramImageConverter(params=params, device=device) 193 | 194 | 195 | @st.cache 196 | def spectrogram_image_from_audio( 197 | segment: pydub.AudioSegment, 198 | params: SpectrogramParams, 199 | device: str = "cuda", 200 | ) -> Image.Image: 201 | converter = spectrogram_image_converter(params=params, device=device) 202 | return converter.spectrogram_image_from_audio(segment) 203 | 204 | 205 | @st.cache_data 206 | def audio_segment_from_spectrogram_image( 207 | image: Image.Image, 208 | params: SpectrogramParams, 209 | device: str = "cuda", 210 | ) -> pydub.AudioSegment: 211 | converter = spectrogram_image_converter(params=params, device=device) 212 | return converter.audio_from_spectrogram_image(image) 213 | 214 | 215 | @st.cache_data 216 | def audio_bytes_from_spectrogram_image( 217 | image: Image.Image, 218 | params: SpectrogramParams, 219 | device: str = "cuda", 220 | output_format: str = "mp3", 221 | ) -> io.BytesIO: 222 | segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device) 223 | 224 | audio_bytes = io.BytesIO() 225 | segment.export(audio_bytes, format=output_format) 226 | 227 | return audio_bytes 228 | 229 | 230 | def select_device(container: T.Any = st.sidebar) -> str: 231 | """ 232 | Dropdown to select a torch device, with an intelligent default. 233 | """ 234 | default_device = "cpu" 235 | if torch.cuda.is_available(): 236 | default_device = "cuda" 237 | elif torch.backends.mps.is_available(): 238 | default_device = "mps" 239 | 240 | device_options = ["cuda", "cpu", "mps"] 241 | device = st.sidebar.selectbox( 242 | "Device", 243 | options=device_options, 244 | index=device_options.index(default_device), 245 | help="Which compute device to use. CUDA is recommended.", 246 | ) 247 | assert device is not None 248 | 249 | return device 250 | 251 | 252 | def select_audio_extension(container: T.Any = st.sidebar) -> str: 253 | """ 254 | Dropdown to select an audio extension, with an intelligent default. 255 | """ 256 | default = "mp3" if pydub.AudioSegment.ffmpeg else "wav" 257 | extension = container.selectbox( 258 | "Output format", 259 | options=AUDIO_EXTENSIONS, 260 | index=AUDIO_EXTENSIONS.index(default), 261 | ) 262 | assert extension is not None 263 | return extension 264 | 265 | 266 | def select_scheduler(container: T.Any = st.sidebar) -> str: 267 | """ 268 | Dropdown to select a scheduler. 269 | """ 270 | scheduler = st.sidebar.selectbox( 271 | "Scheduler", 272 | options=SCHEDULER_OPTIONS, 273 | index=0, 274 | help="Which diffusion scheduler to use", 275 | ) 276 | assert scheduler is not None 277 | return scheduler 278 | 279 | 280 | def select_checkpoint(container: T.Any = st.sidebar) -> str: 281 | """ 282 | Provide a custom model checkpoint. 
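Returns the checkpoint string entered in the sidebar text box, defaulting to DEFAULT_CHECKPOINT.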
283 | """ 284 | return container.text_input( 285 | "Custom Checkpoint", 286 | value=DEFAULT_CHECKPOINT, 287 | help="Provide a custom model checkpoint", 288 | ) 289 | 290 | 291 | @st.cache_data 292 | def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment: 293 | return pydub.AudioSegment.from_file(audio_file) 294 | 295 | 296 | @st.cache_resource 297 | def get_audio_splitter(device: str = "cuda"): 298 | return AudioSplitter(device=device) 299 | 300 | 301 | @st.cache_resource 302 | def load_magic_mix_pipeline( 303 | checkpoint: str = DEFAULT_CHECKPOINT, 304 | device: str = "cuda", 305 | scheduler: str = SCHEDULER_OPTIONS[0], 306 | ): 307 | pipeline = DiffusionPipeline.from_pretrained( 308 | checkpoint, 309 | custom_pipeline="magic_mix", 310 | ).to(device) 311 | 312 | pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config) 313 | 314 | return pipeline 315 | 316 | 317 | @st.cache 318 | def run_img2img_magic_mix( 319 | prompt: str, 320 | init_image: Image.Image, 321 | num_inference_steps: int, 322 | guidance_scale: float, 323 | seed: int, 324 | kmin: float, 325 | kmax: float, 326 | mix_factor: float, 327 | checkpoint: str = DEFAULT_CHECKPOINT, 328 | device: str = "cuda", 329 | scheduler: str = SCHEDULER_OPTIONS[0], 330 | ): 331 | """ 332 | Run the magic mix pipeline for img2img. 333 | """ 334 | with pipeline_lock(): 335 | pipeline = load_magic_mix_pipeline( 336 | checkpoint=checkpoint, 337 | device=device, 338 | scheduler=scheduler, 339 | ) 340 | 341 | return pipeline( 342 | init_image, 343 | prompt=prompt, 344 | kmin=kmin, 345 | kmax=kmax, 346 | mix_factor=mix_factor, 347 | seed=seed, 348 | guidance_scale=guidance_scale, 349 | steps=num_inference_steps, 350 | ) 351 | 352 | 353 | @st.cache 354 | def run_img2img( 355 | prompt: str, 356 | init_image: Image.Image, 357 | denoising_strength: float, 358 | num_inference_steps: int, 359 | guidance_scale: float, 360 | seed: int, 361 | negative_prompt: T.Optional[str] = None, 362 | checkpoint: str = DEFAULT_CHECKPOINT, 363 | device: str = "cuda", 364 | scheduler: str = SCHEDULER_OPTIONS[0], 365 | progress_callback: T.Optional[T.Callable[[float], T.Any]] = None, 366 | ) -> Image.Image: 367 | with pipeline_lock(): 368 | pipeline = load_stable_diffusion_img2img_pipeline( 369 | checkpoint=checkpoint, 370 | device=device, 371 | scheduler=scheduler, 372 | ) 373 | 374 | generator_device = "cpu" if device.lower().startswith("mps") else device 375 | generator = torch.Generator(device=generator_device).manual_seed(seed) 376 | 377 | num_expected_steps = max(int(num_inference_steps * denoising_strength), 1) 378 | 379 | def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None: 380 | if progress_callback is not None: 381 | progress_callback(step / num_expected_steps) 382 | 383 | result = pipeline( 384 | prompt=prompt, 385 | image=init_image, 386 | strength=denoising_strength, 387 | num_inference_steps=num_inference_steps, 388 | guidance_scale=guidance_scale, 389 | negative_prompt=negative_prompt or None, 390 | num_images_per_prompt=1, 391 | generator=generator, 392 | callback=callback, 393 | callback_steps=1, 394 | ) 395 | 396 | return result.images[0] 397 | 398 | 399 | class StreamlitCounter: 400 | """ 401 | Simple counter stored in streamlit session state. 
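Because Streamlit reruns the script on every widget interaction, pass increment as a
button's on_click callback and check value to tell whether the button has been pressed.
Illustrative usage sketch (the button label is arbitrary):

    counter = StreamlitCounter()
    st.button("Run", type="primary", on_click=counter.increment)
    if counter.value == 0:
        return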
402 | """ 403 | 404 | def __init__(self, key="_counter"): 405 | self.key = key 406 | if not st.session_state.get(self.key): 407 | st.session_state[self.key] = 0 408 | 409 | def increment(self): 410 | st.session_state[self.key] += 1 411 | 412 | @property 413 | def value(self): 414 | return st.session_state[self.key] 415 | 416 | 417 | def display_and_download_audio( 418 | segment: pydub.AudioSegment, 419 | name: str, 420 | extension: str = "mp3", 421 | ) -> None: 422 | """ 423 | Display the given audio segment and provide a button to download it with 424 | a proper file name, since st.audio doesn't support that. 425 | """ 426 | mime_type = f"audio/{extension}" 427 | audio_bytes = io.BytesIO() 428 | segment.export(audio_bytes, format=extension) 429 | st.audio(audio_bytes, format=mime_type) 430 | 431 | st.download_button( 432 | f"{name}.{extension}", 433 | data=audio_bytes, 434 | file_name=f"{name}.{extension}", 435 | mime=mime_type, 436 | ) 437 | -------------------------------------------------------------------------------- /riffusion/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/riffusion/util/__init__.py -------------------------------------------------------------------------------- /riffusion/util/audio_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Audio utility functions. 3 | """ 4 | 5 | import io 6 | import typing as T 7 | 8 | import numpy as np 9 | import pydub 10 | from scipy.io import wavfile 11 | 12 | 13 | def audio_from_waveform( 14 | samples: np.ndarray, sample_rate: int, normalize: bool = False 15 | ) -> pydub.AudioSegment: 16 | """ 17 | Convert a numpy array of samples of a waveform to an audio segment. 18 | 19 | Args: 20 | samples: (channels, samples) array 21 | """ 22 | # Normalize volume to fit in int16 23 | if normalize: 24 | samples *= np.iinfo(np.int16).max / np.max(np.abs(samples)) 25 | 26 | # Transpose and convert to int16 27 | samples = samples.transpose(1, 0) 28 | samples = samples.astype(np.int16) 29 | 30 | # Write to the bytes of a WAV file 31 | wav_bytes = io.BytesIO() 32 | wavfile.write(wav_bytes, sample_rate, samples) 33 | wav_bytes.seek(0) 34 | 35 | # Read into pydub 36 | return pydub.AudioSegment.from_wav(wav_bytes) 37 | 38 | 39 | def apply_filters(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment: 40 | """ 41 | Apply post-processing filters to the audio segment to compress it and 42 | keep at a -10 dBFS level. 43 | """ 44 | # TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end. 45 | # TODO(hayk): Is this going to make audio unbalanced between sequential clips? 
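# The filter chain below peak-normalizes and gain-stages the segment toward a target
# dBFS level, optionally running pydub's dynamic range compressor when compression=True.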
46 | 47 | if compression: 48 | segment = pydub.effects.normalize( 49 | segment, 50 | headroom=0.1, 51 | ) 52 | 53 | segment = segment.apply_gain(-10 - segment.dBFS) 54 | 55 | # TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU 56 | segment = pydub.effects.compress_dynamic_range( 57 | segment, 58 | threshold=-20.0, 59 | ratio=4.0, 60 | attack=5.0, 61 | release=50.0, 62 | ) 63 | 64 | desired_db = -12 65 | segment = segment.apply_gain(desired_db - segment.dBFS) 66 | 67 | segment = pydub.effects.normalize( 68 | segment, 69 | headroom=0.1, 70 | ) 71 | 72 | return segment 73 | 74 | 75 | def stitch_segments( 76 | segments: T.Sequence[pydub.AudioSegment], crossfade_s: float 77 | ) -> pydub.AudioSegment: 78 | """ 79 | Stitch together a sequence of audio segments with a crossfade between each segment. 80 | """ 81 | crossfade_ms = int(crossfade_s * 1000) 82 | combined_segment = segments[0] 83 | for segment in segments[1:]: 84 | combined_segment = combined_segment.append(segment, crossfade=crossfade_ms) 85 | return combined_segment 86 | 87 | 88 | def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment: 89 | """ 90 | Overlay a sequence of audio segments on top of each other. 91 | """ 92 | assert len(segments) > 0 93 | output: pydub.AudioSegment = None 94 | for segment in segments: 95 | if output is None: 96 | output = segment 97 | else: 98 | output = output.overlay(segment) 99 | return output 100 | -------------------------------------------------------------------------------- /riffusion/util/base64_util.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | 4 | 5 | def encode(buffer: io.BytesIO) -> str: 6 | """ 7 | Encode the given buffer as base64. 8 | """ 9 | return base64.encodebytes(buffer.getvalue()).decode("ascii") 10 | -------------------------------------------------------------------------------- /riffusion/util/fft_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | FFT tools to analyze frequency content of audio segments. This is not code for 3 | dealing with spectrogram images, but for analysis of waveforms. 4 | """ 5 | import struct 6 | import typing as T 7 | 8 | import numpy as np 9 | import plotly.graph_objects as go 10 | import pydub 11 | from scipy.fft import rfft, rfftfreq 12 | 13 | 14 | def plot_ffts( 15 | segments: T.Dict[str, pydub.AudioSegment], 16 | title: str = "FFT", 17 | min_frequency: float = 20, 18 | max_frequency: float = 20000, 19 | ) -> None: 20 | """ 21 | Plot an FFT analysis of the given audio segments. 22 | """ 23 | ffts = {name: compute_fft(seg) for name, seg in segments.items()} 24 | 25 | fig = go.Figure( 26 | data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()], 27 | layout={"title": title}, 28 | ) 29 | fig.update_xaxes( 30 | range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)], 31 | type="log", 32 | title="Frequency", 33 | ) 34 | fig.update_yaxes(title="Value") 35 | fig.show() 36 | 37 | 38 | def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]: 39 | """ 40 | Compute the FFT of the given audio segment as a mono signal. 41 | 42 | Returns: 43 | frequencies: FFT computed frequencies 44 | amplitudes: Amplitudes of each frequency 45 | """ 46 | # Convert to mono if needed. 
47 | if sound.channels > 1: 48 | sound = sound.set_channels(1) 49 | 50 | sample_rate = sound.frame_rate 51 | 52 | num_samples = int(sound.frame_count()) 53 | samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data) 54 | 55 | fft_values = rfft(samples) 56 | amplitudes = np.abs(fft_values) 57 | 58 | frequencies = rfftfreq(n=num_samples, d=1 / sample_rate) 59 | 60 | return frequencies, amplitudes 61 | -------------------------------------------------------------------------------- /riffusion/util/image_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for converting between spectrograms tensors and spectrogram images, as well as 3 | general helpers for operating on pillow images. 4 | """ 5 | import typing as T 6 | 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from riffusion.spectrogram_params import SpectrogramParams 11 | 12 | 13 | def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image: 14 | """ 15 | Compute a spectrogram image from a spectrogram magnitude array. 16 | 17 | This is the inverse of spectrogram_from_image, except for discretization error from 18 | quantizing to uint8. 19 | 20 | Args: 21 | spectrogram: (channels, frequency, time) 22 | power: A power curve to apply to the spectrogram to preserve contrast 23 | 24 | Returns: 25 | image: (frequency, time, channels) 26 | """ 27 | # Rescale to 0-1 28 | max_value = np.max(spectrogram) 29 | data = spectrogram / max_value 30 | 31 | # Apply the power curve 32 | data = np.power(data, power) 33 | 34 | # Rescale to 0-255 35 | data = data * 255 36 | 37 | # Invert 38 | data = 255 - data 39 | 40 | # Convert to uint8 41 | data = data.astype(np.uint8) 42 | 43 | # Munge channels into a PIL image 44 | if data.shape[0] == 1: 45 | # TODO(hayk): Do we want to write single channel to disk instead? 46 | image = Image.fromarray(data[0], mode="L").convert("RGB") 47 | elif data.shape[0] == 2: 48 | data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0) 49 | image = Image.fromarray(data, mode="RGB") 50 | else: 51 | raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}") 52 | 53 | # Flip Y 54 | image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) 55 | 56 | return image 57 | 58 | 59 | def spectrogram_from_image( 60 | image: Image.Image, 61 | power: float = 0.25, 62 | stereo: bool = False, 63 | max_value: float = 30e6, 64 | ) -> np.ndarray: 65 | """ 66 | Compute a spectrogram magnitude array from a spectrogram image. 67 | 68 | This is the inverse of image_from_spectrogram, except for discretization error from 69 | quantizing to uint8. 70 | 71 | Args: 72 | image: (frequency, time, channels) 73 | power: The power curve applied to the spectrogram 74 | stereo: Whether the spectrogram encodes stereo data 75 | max_value: The max value of the original spectrogram. In practice doesn't matter. 
76 | 77 | Returns: 78 | spectrogram: (channels, frequency, time) 79 | """ 80 | # Convert to RGB if single channel 81 | if image.mode in ("P", "L"): 82 | image = image.convert("RGB") 83 | 84 | # Flip Y 85 | image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM) 86 | 87 | # Munge channels into a numpy array of (channels, frequency, time) 88 | data = np.array(image).transpose(2, 0, 1) 89 | if stereo: 90 | # Take the G and B channels as done in image_from_spectrogram 91 | data = data[[1, 2], :, :] 92 | else: 93 | data = data[0:1, :, :] 94 | 95 | # Convert to floats 96 | data = data.astype(np.float32) 97 | 98 | # Invert 99 | data = 255 - data 100 | 101 | # Rescale to 0-1 102 | data = data / 255 103 | 104 | # Reverse the power curve 105 | data = np.power(data, 1 / power) 106 | 107 | # Rescale to max value 108 | data = data * max_value 109 | 110 | return data 111 | 112 | 113 | def exif_from_image(pil_image: Image.Image) -> T.Dict[str, T.Any]: 114 | """ 115 | Get the EXIF data from a PIL image as a dict. 116 | """ 117 | exif = pil_image.getexif() 118 | 119 | if exif is None or len(exif) == 0: 120 | return {} 121 | 122 | return {SpectrogramParams.ExifTags(key).name: val for key, val in exif.items()} 123 | -------------------------------------------------------------------------------- /riffusion/util/torch_util.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def check_device(device: str, backup: str = "cpu") -> str: 8 | """ 9 | Check that the device is valid and available. If not, 10 | """ 11 | cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available() 12 | mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available() 13 | 14 | if cuda_not_found or mps_not_found: 15 | warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3) 16 | return backup 17 | 18 | return device 19 | 20 | 21 | def slerp( 22 | t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995 23 | ) -> torch.Tensor: 24 | """ 25 | Helper function to spherically interpolate two arrays v1 v2. 
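t is the interpolation weight in [0, 1] between v0 and v1. Torch tensors are converted
to numpy for the computation and the result is moved back to the original device. When
the inputs are nearly parallel (|dot product| > dot_threshold), plain linear
interpolation is used to avoid numerical instability.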
26 | """ 27 | if not isinstance(v0, np.ndarray): 28 | inputs_are_torch = True 29 | input_device = v0.device 30 | v0 = v0.cpu().numpy() 31 | v1 = v1.cpu().numpy() 32 | 33 | dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) 34 | if np.abs(dot) > dot_threshold: 35 | v2 = (1 - t) * v0 + t * v1 36 | else: 37 | theta_0 = np.arccos(dot) 38 | sin_theta_0 = np.sin(theta_0) 39 | theta_t = theta_0 * t 40 | sin_theta_t = np.sin(theta_t) 41 | s0 = np.sin(theta_0 - theta_t) / sin_theta_0 42 | s1 = sin_theta_t / sin_theta_0 43 | v2 = s0 * v0 + s1 * v1 44 | 45 | if inputs_are_torch: 46 | v2 = torch.from_numpy(v2).to(input_device) 47 | 48 | return v2 49 | -------------------------------------------------------------------------------- /seed_images/agile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/agile.png -------------------------------------------------------------------------------- /seed_images/marim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/marim.png -------------------------------------------------------------------------------- /seed_images/mask_beat_lines_80.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_beat_lines_80.png -------------------------------------------------------------------------------- /seed_images/mask_gradient_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_gradient_dark.png -------------------------------------------------------------------------------- /seed_images/mask_gradient_top_70.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_gradient_top_70.png -------------------------------------------------------------------------------- /seed_images/mask_gradient_top_fifth_75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_gradient_top_fifth_75.png -------------------------------------------------------------------------------- /seed_images/mask_top_third_75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_top_third_75.png -------------------------------------------------------------------------------- /seed_images/mask_top_third_95.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/mask_top_third_95.png -------------------------------------------------------------------------------- /seed_images/motorway.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/motorway.png -------------------------------------------------------------------------------- /seed_images/og_beat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/og_beat.png -------------------------------------------------------------------------------- /seed_images/vibes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/seed_images/vibes.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | # Load the version from file 6 | __version__ = Path("VERSION").read_text().strip() 7 | 8 | # Load packages 9 | packages = Path("requirements.txt").read_text().splitlines() 10 | 11 | setup( 12 | name="riffusion", 13 | packages=find_packages(exclude=[]), 14 | version=__version__, 15 | license="MIT", 16 | description="Riffusion - Stable diffusion for real-time music generation", 17 | author="Hayk Martiros", 18 | author_email="hayk.mart@gmail.com", 19 | long_description_content_type="text/markdown", 20 | url="https://github.com/riffusion/riffusion", 21 | keywords=[ 22 | "artificial intelligence", 23 | "audio generation", 24 | "music", 25 | "diffusion", 26 | "riffusion", 27 | "deep learning", 28 | "transformers", 29 | ], 30 | install_requires=packages, 31 | package_data={ 32 | "riffusion": ["py.typed"], 33 | }, 34 | classifiers=[ 35 | "Development Status :: 4 - Beta", 36 | "Intended Audience :: Developers", 37 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 38 | "License :: OSI Approved :: MIT License", 39 | "Programming Language :: Python :: 3.9", 40 | ], 41 | ) 42 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/__init__.py -------------------------------------------------------------------------------- /test/audio_to_image_test.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from riffusion.cli import audio_to_image 7 | from riffusion.spectrogram_params import SpectrogramParams 8 | 9 | from .test_case import TestCase 10 | 11 | 12 | class AudioToImageTest(TestCase): 13 | """ 14 | Test riffusion.cli audio-to-image 15 | """ 16 | 17 | @classmethod 18 | def default_params(cls) -> T.Dict: 19 | return dict( 20 | step_size_ms=10, 21 | num_frequencies=512, 22 | # TODO(hayk): Change these to [20, 20000] once a model is updated 23 | min_frequency=0, 24 | max_frequency=10000, 25 | stereo=False, 26 | device=cls.DEVICE, 27 | ) 28 | 29 | def test_audio_to_image(self) -> None: 30 | """ 31 | Test audio-to-image with default params. 32 | """ 33 | params = self.default_params() 34 | self.helper_test_with_params(params) 35 | 36 | def test_stereo(self) -> None: 37 | """ 38 | Test audio-to-image with stereo=True. 
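Stereo spectrogram images store the two audio channels in the image's G and B channels,
with R left empty, which helper_test_with_params checks by asserting the R channel is all zero.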
39 | """ 40 | params = self.default_params() 41 | params["stereo"] = True 42 | self.helper_test_with_params(params) 43 | 44 | def helper_test_with_params(self, params: T.Dict) -> None: 45 | audio_path = ( 46 | self.TEST_DATA_PATH 47 | / "tired_traveler" 48 | / "clips" 49 | / "clip_2_start_103694_ms_duration_5678_ms.wav" 50 | ) 51 | output_dir = self.get_tmp_dir("audio_to_image_") 52 | 53 | if params["stereo"]: 54 | stem = f"{audio_path.stem}_stereo" 55 | else: 56 | stem = audio_path.stem 57 | 58 | image_path = output_dir / f"{stem}.png" 59 | 60 | audio_to_image(audio=str(audio_path), image=str(image_path), **params) 61 | 62 | # Check that the image exists 63 | self.assertTrue(image_path.exists()) 64 | 65 | pil_image = Image.open(image_path) 66 | 67 | # Check the image mode 68 | self.assertEqual(pil_image.mode, "RGB") 69 | 70 | # Check the image dimensions 71 | duration_ms = 5678 72 | self.assertTrue(str(duration_ms) in audio_path.name) 73 | expected_image_width = round(duration_ms / params["step_size_ms"]) 74 | self.assertEqual(pil_image.width, expected_image_width) 75 | self.assertEqual(pil_image.height, params["num_frequencies"]) 76 | 77 | # Get channels as numpy arrays 78 | channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))] 79 | self.assertEqual(len(channels), 3) 80 | 81 | if params["stereo"]: 82 | # Check that the first channel is zero 83 | self.assertTrue(np.all(channels[0] == 0)) 84 | else: 85 | # Check that all channels are the same 86 | self.assertTrue(np.all(channels[0] == channels[1])) 87 | self.assertTrue(np.all(channels[0] == channels[2])) 88 | 89 | # Check that the image has exif data 90 | exif = pil_image.getexif() 91 | self.assertIsNotNone(exif) 92 | params_from_exif = SpectrogramParams.from_exif(exif) 93 | expected_params = SpectrogramParams( 94 | stereo=params["stereo"], 95 | step_size_ms=params["step_size_ms"], 96 | num_frequencies=params["num_frequencies"], 97 | max_frequency=params["max_frequency"], 98 | ) 99 | self.assertTrue(params_from_exif == expected_params) 100 | -------------------------------------------------------------------------------- /test/image_to_audio_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pydub 4 | 5 | from riffusion.cli import image_to_audio 6 | 7 | from .test_case import TestCase 8 | 9 | 10 | class ImageToAudioTest(TestCase): 11 | """ 12 | Test riffusion.cli image-to-audio 13 | """ 14 | 15 | def test_image_to_audio_mono(self) -> None: 16 | self.helper_image_to_audio( 17 | song_dir=self.TEST_DATA_PATH / "tired_traveler", 18 | clip_name="clip_2_start_103694_ms_duration_5678_ms", 19 | stereo=False, 20 | ) 21 | 22 | def test_image_to_audio_stereo(self) -> None: 23 | self.helper_image_to_audio( 24 | song_dir=self.TEST_DATA_PATH / "tired_traveler", 25 | clip_name="clip_2_start_103694_ms_duration_5678_ms", 26 | stereo=True, 27 | ) 28 | 29 | def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None: 30 | if stereo: 31 | image_stem = clip_name + "_stereo" 32 | else: 33 | image_stem = clip_name 34 | 35 | image_path = song_dir / "images" / f"{image_stem}.png" 36 | output_dir = self.get_tmp_dir("image_to_audio_") 37 | audio_path = output_dir / f"{image_path.stem}.wav" 38 | 39 | image_to_audio( 40 | image=str(image_path), 41 | audio=str(audio_path), 42 | device=self.DEVICE, 43 | ) 44 | 45 | # Check that the audio exists 46 | self.assertTrue(audio_path.exists()) 47 | 48 | # Load the reconstructed audio 
and the original clip 49 | segment = pydub.AudioSegment.from_file(str(audio_path)) 50 | expected_segment = pydub.AudioSegment.from_file( 51 | str(song_dir / "clips" / f"{clip_name}.wav") 52 | ) 53 | 54 | # Check sample rate 55 | self.assertEqual(segment.frame_rate, expected_segment.frame_rate) 56 | 57 | # Check duration 58 | actual_duration_ms = round(segment.duration_seconds * 1000) 59 | expected_duration_ms = round(expected_segment.duration_seconds * 1000) 60 | self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10) 61 | 62 | # Check the number of channels 63 | self.assertEqual(expected_segment.channels, 2) 64 | if stereo: 65 | self.assertEqual(segment.channels, 2) 66 | else: 67 | self.assertEqual(segment.channels, 1) 68 | 69 | 70 | if __name__ == "__main__": 71 | TestCase.main() 72 | -------------------------------------------------------------------------------- /test/image_util_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydub 3 | 4 | from riffusion.spectrogram_converter import SpectrogramConverter 5 | from riffusion.spectrogram_params import SpectrogramParams 6 | from riffusion.util import image_util 7 | 8 | from .test_case import TestCase 9 | 10 | 11 | class ImageUtilTest(TestCase): 12 | """ 13 | Test riffusion.util.image_util 14 | """ 15 | 16 | def test_spectrogram_to_image_round_trip(self) -> None: 17 | audio_path = ( 18 | self.TEST_DATA_PATH 19 | / "tired_traveler" 20 | / "clips" 21 | / "clip_2_start_103694_ms_duration_5678_ms.wav" 22 | ) 23 | 24 | # Load up the audio file 25 | segment = pydub.AudioSegment.from_file(audio_path) 26 | 27 | # Convert to mono 28 | segment = segment.set_channels(1) 29 | 30 | # Compute a spectrogram with default params 31 | params = SpectrogramParams(sample_rate=segment.frame_rate) 32 | converter = SpectrogramConverter(params=params, device=self.DEVICE) 33 | spectrogram = converter.spectrogram_from_audio(segment) 34 | 35 | # Compute the image from the spectrogram 36 | image = image_util.image_from_spectrogram( 37 | spectrogram=spectrogram, 38 | power=params.power_for_image, 39 | ) 40 | 41 | # Save the max value 42 | max_value = np.max(spectrogram) 43 | 44 | # Compute the spectrogram from the image 45 | spectrogram_reversed = image_util.spectrogram_from_image( 46 | image=image, 47 | max_value=max_value, 48 | power=params.power_for_image, 49 | stereo=params.stereo, 50 | ) 51 | 52 | # Check the shapes 53 | self.assertEqual(spectrogram.shape, spectrogram_reversed.shape) 54 | 55 | # Check the max values 56 | self.assertEqual(np.max(spectrogram), np.max(spectrogram_reversed)) 57 | 58 | # Check the median values 59 | self.assertTrue( 60 | np.allclose(np.median(spectrogram), np.median(spectrogram_reversed), rtol=0.05) 61 | ) 62 | 63 | # Make sure all values are somewhat similar, but allow for discretization error 64 | # TODO(hayk): Investigate error more closely 65 | self.assertTrue(np.allclose(spectrogram, spectrogram_reversed, rtol=0.15)) 66 | -------------------------------------------------------------------------------- /test/print_exif_test.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | 4 | from riffusion.cli import print_exif 5 | 6 | from .test_case import TestCase 7 | 8 | 9 | class PrintExifTest(TestCase): 10 | """ 11 | Test riffusion.cli print-exif 12 | """ 13 | 14 | def test_print_exif(self) -> None: 15 | """ 16 | Test print-exif. 
17 | """ 18 | image_path = ( 19 | self.TEST_DATA_PATH 20 | / "tired_traveler" 21 | / "images" 22 | / "clip_2_start_103694_ms_duration_5678_ms.png" 23 | ) 24 | 25 | # Redirect stdout 26 | stdout = io.StringIO() 27 | with contextlib.redirect_stdout(stdout): 28 | print_exif(image=str(image_path)) 29 | 30 | # Check that a couple of values are printed 31 | self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue()) 32 | self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue()) 33 | -------------------------------------------------------------------------------- /test/sample_clips_test.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | import pydub 4 | 5 | from riffusion.cli import sample_clips 6 | 7 | from .test_case import TestCase 8 | 9 | 10 | class SampleClipsTest(TestCase): 11 | """ 12 | Test riffusion.cli sample-clips 13 | """ 14 | 15 | @staticmethod 16 | def default_params() -> T.Dict: 17 | return dict( 18 | num_clips=3, 19 | duration_ms=5678, 20 | mono=False, 21 | extension="wav", 22 | seed=42, 23 | ) 24 | 25 | def test_sample_clips(self) -> None: 26 | """ 27 | Test sample-clips with default params. 28 | """ 29 | params = self.default_params() 30 | self.helper_test_with_params(params) 31 | 32 | def test_mono(self) -> None: 33 | """ 34 | Test sample-clips with mono=True. 35 | """ 36 | params = self.default_params() 37 | params["mono"] = True 38 | params["num_clips"] = 1 39 | self.helper_test_with_params(params) 40 | 41 | def test_mp3(self) -> None: 42 | """ 43 | Test sample-clips with extension=mp3. 44 | """ 45 | if pydub.AudioSegment.converter is None: 46 | self.skipTest("skipping, ffmpeg not found") 47 | 48 | params = self.default_params() 49 | params["extension"] = "mp3" 50 | params["num_clips"] = 1 51 | self.helper_test_with_params(params) 52 | 53 | def helper_test_with_params(self, params: T.Dict) -> None: 54 | """ 55 | Test sample-clips with the given params. 
56 | """ 57 | audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3" 58 | output_dir = self.get_tmp_dir("sample_clips_") 59 | 60 | sample_clips( 61 | audio=str(audio_path), 62 | output_dir=str(output_dir), 63 | **params, 64 | ) 65 | 66 | # For each file in output dir 67 | counter = 0 68 | for clip_path in output_dir.iterdir(): 69 | # Check that it has the right extension 70 | self.assertEqual(clip_path.suffix, f".{params['extension']}") 71 | 72 | # Check that it has the right duration 73 | segment = pydub.AudioSegment.from_file(clip_path) 74 | self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"]) 75 | 76 | # Check that it has the right number of channels 77 | if params["mono"]: 78 | self.assertEqual(segment.channels, 1) 79 | else: 80 | self.assertEqual(segment.channels, 2) 81 | 82 | counter += 1 83 | 84 | self.assertEqual(counter, params["num_clips"]) 85 | 86 | 87 | if __name__ == "__main__": 88 | TestCase.main() 89 | -------------------------------------------------------------------------------- /test/spectrogram_converter_test.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing as T 3 | 4 | import pydub 5 | 6 | from riffusion.spectrogram_converter import SpectrogramConverter 7 | from riffusion.spectrogram_params import SpectrogramParams 8 | from riffusion.util import fft_util 9 | 10 | from .test_case import TestCase 11 | 12 | 13 | class SpectrogramConverterTest(TestCase): 14 | """ 15 | Test going from audio to spectrogram to audio, without converting to 16 | an image, to check quality loss of the reconstruction. 17 | 18 | This test allows comparing multiple sets of spectrogram params by listening to output audio 19 | and by plotting their FFTs. 20 | """ 21 | 22 | # TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces. 
23 | 24 | def test_round_trip(self) -> None: 25 | audio_path = ( 26 | self.TEST_DATA_PATH 27 | / "tired_traveler" 28 | / "clips" 29 | / "clip_2_start_103694_ms_duration_5678_ms.wav" 30 | ) 31 | output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_") 32 | 33 | # Load up the audio file 34 | segment = pydub.AudioSegment.from_file(audio_path) 35 | 36 | # Convert to mono if desired 37 | use_stereo = False 38 | if use_stereo: 39 | assert segment.channels == 2 40 | else: 41 | segment = segment.set_channels(1) 42 | 43 | # Define named sets of parameters 44 | param_sets: T.Dict[str, SpectrogramParams] = {} 45 | 46 | param_sets["default"] = SpectrogramParams( 47 | sample_rate=segment.frame_rate, 48 | stereo=use_stereo, 49 | step_size_ms=10, 50 | min_frequency=20, 51 | max_frequency=20000, 52 | num_frequencies=512, 53 | ) 54 | 55 | if self.DEBUG: 56 | param_sets["freq_0_to_10k"] = dataclasses.replace( 57 | param_sets["default"], 58 | min_frequency=0, 59 | max_frequency=10000, 60 | ) 61 | 62 | segments: T.Dict[str, pydub.AudioSegment] = { 63 | "original": segment, 64 | } 65 | for name, params in param_sets.items(): 66 | converter = SpectrogramConverter(params=params, device=self.DEVICE) 67 | spectrogram = converter.spectrogram_from_audio(segment) 68 | segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True) 69 | 70 | # Save segments to disk 71 | for name, segment in segments.items(): 72 | audio_out = output_dir / f"{name}.wav" 73 | segment.export(audio_out, format="wav") 74 | print(f"Saved {audio_out}") 75 | 76 | # Check params 77 | self.assertEqual(segments["default"].channels, 2 if use_stereo else 1) 78 | self.assertEqual(segments["original"].channels, segments["default"].channels) 79 | self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate) 80 | self.assertEqual(segments["original"].sample_width, segments["default"].sample_width) 81 | 82 | # TODO(hayk): Test something more rigorous about the quality of the reconstruction. 83 | 84 | # If debugging, load up a browser tab plotting the FFTs 85 | if self.DEBUG: 86 | fft_util.plot_ffts(segments) 87 | -------------------------------------------------------------------------------- /test/spectrogram_image_converter_test.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing as T 3 | 4 | import pydub 5 | from PIL import Image 6 | 7 | from riffusion.spectrogram_image_converter import SpectrogramImageConverter 8 | from riffusion.spectrogram_params import SpectrogramParams 9 | from riffusion.util import fft_util 10 | 11 | from .test_case import TestCase 12 | 13 | 14 | class SpectrogramImageConverterTest(TestCase): 15 | """ 16 | Test going from audio to spectrogram images to audio, testing the quality loss of the 17 | end-to-end pipeline. 18 | 19 | This test allows comparing multiple sets of spectrogram params by listening to output audio 20 | and by plotting their FFTs. 21 | 22 | See spectrogram_converter_test.py for a similar test that does not convert to images. 
23 | """ 24 | 25 | def test_round_trip(self) -> None: 26 | audio_path = ( 27 | self.TEST_DATA_PATH 28 | / "tired_traveler" 29 | / "clips" 30 | / "clip_2_start_103694_ms_duration_5678_ms.wav" 31 | ) 32 | output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_") 33 | 34 | # Load up the audio file 35 | segment = pydub.AudioSegment.from_file(audio_path) 36 | 37 | # Convert to mono if desired 38 | use_stereo = False 39 | if use_stereo: 40 | assert segment.channels == 2 41 | else: 42 | segment = segment.set_channels(1) 43 | 44 | # Define named sets of parameters 45 | param_sets: T.Dict[str, SpectrogramParams] = {} 46 | 47 | param_sets["default"] = SpectrogramParams( 48 | sample_rate=segment.frame_rate, 49 | stereo=use_stereo, 50 | step_size_ms=10, 51 | min_frequency=20, 52 | max_frequency=20000, 53 | num_frequencies=512, 54 | ) 55 | 56 | if self.DEBUG: 57 | param_sets["freq_0_to_10k"] = dataclasses.replace( 58 | param_sets["default"], 59 | min_frequency=0, 60 | max_frequency=10000, 61 | ) 62 | 63 | segments: T.Dict[str, pydub.AudioSegment] = { 64 | "original": segment, 65 | } 66 | images: T.Dict[str, Image.Image] = {} 67 | for name, params in param_sets.items(): 68 | converter = SpectrogramImageConverter(params=params, device=self.DEVICE) 69 | images[name] = converter.spectrogram_image_from_audio(segment) 70 | segments[name] = converter.audio_from_spectrogram_image( 71 | image=images[name], 72 | apply_filters=True, 73 | ) 74 | 75 | # Save images to disk 76 | for name, image in images.items(): 77 | image_out = output_dir / f"{name}.png" 78 | image.save(image_out, exif=image.getexif(), format="PNG") 79 | print(f"Saved {image_out}") 80 | 81 | # Save segments to disk 82 | for name, segment in segments.items(): 83 | audio_out = output_dir / f"{name}.wav" 84 | segment.export(audio_out, format="wav") 85 | print(f"Saved {audio_out}") 86 | 87 | # Check params 88 | self.assertEqual(segments["default"].channels, 2 if use_stereo else 1) 89 | self.assertEqual(segments["original"].channels, segments["default"].channels) 90 | self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate) 91 | self.assertEqual(segments["original"].sample_width, segments["default"].sample_width) 92 | 93 | # TODO(hayk): Test something more rigorous about the quality of the reconstruction. 94 | 95 | # If debugging, load up a browser tab plotting the FFTs 96 | if self.DEBUG: 97 | fft_util.plot_ffts(segments) 98 | -------------------------------------------------------------------------------- /test/test_case.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import typing as T 5 | import unittest 6 | import warnings 7 | from pathlib import Path 8 | 9 | 10 | class TestCase(unittest.TestCase): 11 | """ 12 | Base class for tests. 13 | """ 14 | 15 | # Where checked-in test data is stored 16 | TEST_DATA_PATH = Path(__file__).resolve().parent / "test_data" 17 | 18 | # Whether to run tests in debug mode (e.g. don't clean up temporary directories, show plots) 19 | DEBUG = bool(os.environ.get("RIFFUSION_TEST_DEBUG")) 20 | 21 | # Which torch device to use for tests 22 | DEVICE = os.environ.get("RIFFUSION_TEST_DEVICE", "cuda") 23 | 24 | @staticmethod 25 | def main(*args: T.Any, **kwargs: T.Any) -> None: 26 | """ 27 | Run the tests. 
28 | """ 29 | unittest.main(*args, **kwargs) 30 | 31 | @classmethod 32 | def setUpClass(cls): 33 | warnings.filterwarnings("ignore", category=ResourceWarning) 34 | 35 | def get_tmp_dir(self, prefix: str) -> Path: 36 | """ 37 | Create a temporary directory. 38 | """ 39 | tmp_dir = tempfile.mkdtemp(prefix=prefix) 40 | 41 | # Clean up the temporary directory if not debugging 42 | if not self.DEBUG: 43 | self.addCleanup(lambda: shutil.rmtree(tmp_dir, ignore_errors=True)) 44 | 45 | dir_path = Path(tmp_dir) 46 | assert dir_path.is_dir() 47 | 48 | return dir_path 49 | -------------------------------------------------------------------------------- /test/test_data/README.md: -------------------------------------------------------------------------------- 1 | # Test Data 2 | 3 | ### tired_traveler 4 | 5 | * Song: Tired traveler on the way to home 6 | * Artist: Andrew Codeman 7 | * Source: https://freemusicarchive.org/ 8 | -------------------------------------------------------------------------------- /test/test_data/tired_traveler/clips/clip_0_start_15795_ms_duration_5678_ms.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/clips/clip_0_start_15795_ms_duration_5678_ms.wav -------------------------------------------------------------------------------- /test/test_data/tired_traveler/clips/clip_1_start_860_ms_duration_5678_ms.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/clips/clip_1_start_860_ms_duration_5678_ms.wav -------------------------------------------------------------------------------- /test/test_data/tired_traveler/clips/clip_2_start_103694_ms_duration_5678_ms.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/clips/clip_2_start_103694_ms_duration_5678_ms.wav -------------------------------------------------------------------------------- /test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms.png -------------------------------------------------------------------------------- /test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms_stereo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/images/clip_2_start_103694_ms_duration_5678_ms_stereo.png -------------------------------------------------------------------------------- /test/test_data/tired_traveler/tired_traveler.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riffusion/riffusion-hobby/94c29abbdc1de60b3e03131715ea8c3da12bd933/test/test_data/tired_traveler/tired_traveler.mp3 --------------------------------------------------------------------------------