├── .github └── workflows │ └── build.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ddsp ├── __init__.py ├── colab │ ├── README.md │ ├── __init__.py │ ├── colab_utils.py │ ├── demos │ │ ├── README.md │ │ ├── Train_VST.ipynb │ │ ├── pitch_detection.ipynb │ │ ├── timbre_transfer.ipynb │ │ └── train_autoencoder.ipynb │ └── tutorials │ │ ├── 0_processor.ipynb │ │ ├── 1_synths_and_effects.ipynb │ │ ├── 2_processor_group.ipynb │ │ ├── 3_training.ipynb │ │ ├── 4_core_functions.ipynb │ │ └── README.md ├── core.py ├── core_test.py ├── dags.py ├── dags_test.py ├── effects.py ├── effects_test.py ├── losses.py ├── losses_test.py ├── processors.py ├── processors_test.py ├── spectral_ops.py ├── spectral_ops_test.py ├── synths.py ├── synths_test.py ├── test_util.py ├── training │ ├── README.md │ ├── __init__.py │ ├── cloud.py │ ├── cloud_test.py │ ├── data.py │ ├── data_preparation │ │ ├── README.md │ │ ├── __init__.py │ │ ├── ddsp_generate_synthetic_dataset.py │ │ ├── ddsp_prepare_tfrecord.py │ │ ├── prepare_tfrecord_lib.py │ │ ├── prepare_tfrecord_lib_test.py │ │ └── synthetic_data.py │ ├── ddsp_export.py │ ├── ddsp_run.py │ ├── decoders.py │ ├── decoders_test.py │ ├── docker │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── __init__.py │ │ ├── config_hypertune.yaml │ │ ├── config_multiple_vms.yaml │ │ ├── config_single_vm.yaml │ │ ├── ddsp_ai_platform.py │ │ ├── task.py │ │ └── task_test.py │ ├── encoders.py │ ├── eval_util.py │ ├── evaluators.py │ ├── gin │ │ ├── __init__.py │ │ ├── datasets │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── base.gin │ │ │ ├── nsynth.gin │ │ │ ├── tfrecord.gin │ │ │ └── urmp │ │ │ │ ├── README.md │ │ │ │ ├── all.gin │ │ │ │ ├── all_midi.gin │ │ │ │ ├── base.gin │ │ │ │ └── midi_base.gin │ │ ├── eval │ │ │ ├── __init__.py │ │ │ ├── basic.gin │ │ │ ├── basic_f0_ld.gin │ │ │ ├── basic_f0_ld_twm.gin │ │ │ ├── heuristic.gin │ │ │ ├── heuristic_power.gin │ │ │ └── midi_ae.gin │ │ ├── models │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── ae.gin │ │ │ ├── midiae │ │ │ │ ├── README.md │ │ │ │ ├── midiae.gin │ │ │ │ ├── mixins │ │ │ │ │ ├── README.md │ │ │ │ │ ├── _.gin │ │ │ │ │ ├── hmm_prior.gin │ │ │ │ │ ├── midi_encoder.gin │ │ │ │ │ └── recon_lossgroup.gin │ │ │ │ └── z_midiae.gin │ │ │ ├── solo_instrument.gin │ │ │ └── vst │ │ │ │ ├── __init__.py │ │ │ │ ├── vst.gin │ │ │ │ ├── vst_32k.gin │ │ │ │ └── vst_48k.gin │ │ ├── optimization │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── base.gin │ │ │ └── base_tpu.gin │ │ └── papers │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── iclr2020 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── nsynth_ae.gin │ │ │ ├── solo_instrument.gin │ │ │ └── tiny_instrument.gin │ │ │ └── icml2020 │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── finetune_dataset.gin │ │ │ ├── finetune_model.gin │ │ │ ├── pretrain_dataset.gin │ │ │ └── pretrain_model.gin │ ├── heuristics.py │ ├── heuristics_test.py │ ├── inference.py │ ├── metrics.py │ ├── metrics_test.py │ ├── models │ │ ├── __init__.py │ │ ├── autoencoder.py │ │ ├── autoencoder_test.py │ │ ├── inverse_synthesis.py │ │ ├── midi_autoencoder.py │ │ └── model.py │ ├── nn.py │ ├── nn_test.py │ ├── plotting.py │ ├── postprocessing.py │ ├── preprocessing.py │ ├── preprocessing_test.py │ ├── summaries.py │ ├── train_util.py │ └── trainers.py └── version.py ├── pylintrc ├── setup.cfg ├── setup.py ├── update_gin_config.py └── update_pip.sh /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: [push] 
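# Installs dependencies, runs pytest and pylint, and reports the result as a commit status (used for copybara integration; see the final step below).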
4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: '3.8' 14 | - name: Install dependencies 15 | # TODO(jesseengel): Remove `legacy-resolver` when pip dependency resolution 16 | # is no longer broken. Currently stalls (backtracking), only on GitHub, not 17 | # locally. 18 | run: | 19 | sudo apt update 20 | sudo apt-get -y install ffmpeg 21 | pip install -U pip 22 | pip install -e .[data_preparation,test] --use-deprecated=legacy-resolver 23 | - name: Test with pytest 24 | run: pytest 25 | - name: Lint with pylint 26 | run: pylint ddsp 27 | # The step below just reports the success or failure of tests as a "commit status". 28 | # This is needed for copybara integration. 29 | - name: Report success or failure as github status 30 | if: always() 31 | shell: bash 32 | run: | 33 | status="${{ job.status }}" 34 | lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') 35 | curl -sS --request POST \ 36 | --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ 37 | --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ 38 | --header 'content-type: application/json' \ 39 | --data '{ 40 | "state": "'$lowercase_status'", 41 | "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", 42 | "description": "'$status'", 43 | "context": "github-actions/build" 44 | }' 45 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 4 | DSP can be subtle to get completely right, so we particularly appreciate the 5 | contributions of those with expertise in signal processing to help fix any 6 | mistakes we may have made 😄. 7 | 8 | # Versioning 9 | 10 | We'll do our best to keep the version updated. This repo contains two code bases, 11 | which makes versioning a bit tricky: the core code base `ddsp/`, and a more 12 | experimental training code base `ddsp/training/` that is used for active 13 | research. We thus adopt the following scheme for incrementing the version: 14 | 15 | `vMajor.Minor.Revision` 16 | 17 | * Major: Breaking change in `ddsp/` 18 | * Minor: New feature in `ddsp/`, breaking change in `training/` 19 | * Revision: New feature in `training/`, minor bug fix anywhere 20 | 21 | ## Code Design Goals 22 | As much as we can, we would like the DDSP library to be approachable, 23 | well-tested, well-documented, and full of useful examples. Thus, PRs that add 24 | new functionality should be accompanied by ample documentation and tests to 25 | help newcomers understand a typical use case, and guard against silent failures 26 | from breaking changes in the future. Please follow the existing doc/testing 27 | style when you can. 28 | 29 | To ensure a consistent style, new code should follow [Google's Python Style Guide](https://google.github.io/styleguide/pyguide.html) 30 | and will need to pass a Google-style linter before acceptance. While this can 31 | add a little work up front, and occasionally make things more verbose, it helps 32 | reduce mental overhead and makes the code more readable. 33 | 34 | ## Code reviews 35 | 36 | All submissions, including submissions by project members, require review. We 37 | use GitHub pull requests for this purpose.
Consult 38 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 39 | information on using pull requests. 40 | 41 | Please be sure to test your code by running `pytest` and `pylint` before 42 | submitting a pull request for review. Note that code cannot be merged until 43 | these tests pass on [GitHub Actions](https://github.com/magenta/ddsp/actions?query=workflow%3Abuild). 44 | 45 | 46 | ## Getting Started 47 | 48 | If you're looking for a way to contribute, but not sure where to start, you 49 | could: 50 | 51 | * Add some documentation to an existing function. 52 | * Add a missing test to improve coverage. 53 | * Add type hints to functions in a new file. 54 | * Add a new Colab tutorial or demo, covering a typical use case or showing something cool. 55 | * Respond to a bug or feature request in the GitHub [Issues](https://github.com/magenta/ddsp/issues). 56 | * Add a new signal `Processor` and corresponding test. 57 | 58 | ## Contributor License Agreement 59 | 60 | Contributions to this project must be accompanied by a Contributor License 61 | Agreement. You (or your employer) retain the copyright to your contribution; 62 | this simply gives us permission to use and redistribute your contributions as 63 | part of the project. Head over to <https://cla.developers.google.com/> to see 64 | your current agreements on file or to sign a new one. 65 | 66 | You generally only need to submit a CLA once, so if you've already submitted one 67 | (even if it was for a different project), you probably don't need to do it 68 | again. 69 | 70 | 71 | ## Community Guidelines 72 | 73 | This project follows 74 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 75 | -------------------------------------------------------------------------------- /ddsp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base module for the differentiable digital signal processing library.""" 16 | 17 | # Module imports. 18 | from ddsp import core 19 | from ddsp import dags 20 | from ddsp import effects 21 | from ddsp import losses 22 | from ddsp import processors 23 | from ddsp import spectral_ops 24 | from ddsp import synths 25 | 26 | # Version number. 27 | from ddsp.version import __version__ 28 | -------------------------------------------------------------------------------- /ddsp/colab/README.md: -------------------------------------------------------------------------------- 1 | # Colab Notebooks 2 | 3 | Interactive notebooks to demonstrate DDSP. 4 | 5 | * [demos](./demos/): Self-contained demonstrations for training models and showing them in action. 6 | 7 | * [tutorials](./tutorials/): Interactive walkthroughs of the DDSP functions and APIs.
8 | 9 | -------------------------------------------------------------------------------- /ddsp/colab/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/colab/demos/README.md: -------------------------------------------------------------------------------- 1 | # Demos 2 | 3 | Here are Colab notebooks for demonstrating neat things you can do with DDSP. 4 | 5 | * [timbre_transfer](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/timbre_transfer.ipynb): 6 | Convert audio between sound sources with pretrained models. Try turning your voice into a violin, or scratching your laptop and seeing how it sounds as a flute :). Pick from a selection of pretrained models or upload your own that you can train with the `train_autoencoder` demo. 7 | 8 | * [train_autoencoder](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/train_autoencoder.ipynb): 9 | Takes you through all the steps to convert audio files into a dataset and train your own DDSP autoencoder model. You can transfer data and models to/from Google Drive, and download a .zip file of your trained model to be used with the `timbre_transfer` demo. 10 | 11 | * [pitch_detection](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/pitch_detection.ipynb): 12 | Demonstration of self-supervised pitch detection models from the [2020 ICML Workshop paper](https://openreview.net/forum?id=RlVTYWhsky7). 13 | 14 | * [Train_VST](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/Train_VST.ipynb): 15 | Simplified training Colab for the real-time audio plugin (WIP). 16 | -------------------------------------------------------------------------------- /ddsp/colab/tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | The best place to start is the step-by-step tutorials for all the major library components. 4 | 5 | * [0_processor](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/0_processor.ipynb): 6 | Introduction to the Processor class. 7 | 8 | * [1_synths_and_effects](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/1_synths_and_effects.ipynb): 9 | Example usage of processors. 10 | 11 | * [2_processor_group](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/2_processor_group.ipynb): 12 | Stringing processors together in a ProcessorGroup. 13 | 14 | * [3_training](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/3_training.ipynb): 15 | Example of training on a single sound.
16 | 17 | * [4_core_functions](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/4_core_functions.ipynb): 18 | Extensive examples for most of the core DDSP functions. 19 | -------------------------------------------------------------------------------- /ddsp/dags.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library of functions and layers of Directed Acyclic Graphs. 16 | 17 | DAGLayer exists as an alternative to manually specifying the forward pass in 18 | python. The advantage is that a variety of configurations can be 19 | programmatically specified via external dependency injection, such as with the 20 | `gin` library. 21 | """ 22 | 23 | from typing import Dict, Sequence, Tuple, Text, TypeVar 24 | 25 | from absl import logging 26 | from ddsp import core 27 | import gin 28 | import tensorflow.compat.v2 as tf 29 | 30 | tfkl = tf.keras.layers 31 | 32 | # Define Types. 33 | TensorDict = Dict[Text, tf.Tensor] 34 | KeyOrModule = TypeVar('KeyOrModule', Text, tf.Module) 35 | Node = Tuple[KeyOrModule, Sequence[Text], Sequence[Text]] 36 | DAG = Sequence[Node] 37 | 38 | # Helper Functions for DAGs --------------------------------------------------- 39 | filter_by_value = lambda d, cond: dict(filter(lambda e: cond(e[1]), d.items())) 40 | is_module = lambda v: isinstance(v, tf.Module) 41 | 42 | # Duck typing. 43 | is_loss = lambda v: hasattr(v, 'get_losses_dict') 44 | is_processor = lambda v: hasattr(v, 'get_signal') and hasattr(v, 'get_controls') 45 | 46 | 47 | def split_keras_kwargs(kwargs): 48 | """Strip keras specific kwargs.""" 49 | keras_kwargs = {} 50 | for key in ['training', 'mask', 'name']: 51 | if kwargs.get(key) is not None: 52 | keras_kwargs[key] = kwargs.pop(key) 53 | return keras_kwargs, kwargs 54 | 55 | 56 | # DAG and ProcessorGroup Classes ----------------------------------------------- 57 | @gin.register 58 | class DAGLayer(tfkl.Layer): 59 | """String modules together.""" 60 | 61 | def __init__(self, dag: DAG, **kwarg_modules): 62 | """Constructor. 63 | 64 | Args: 65 | dag: A directed acyclic graph in the form of a list of nodes. Each node 66 | has the form 67 | 68 | ['module', ['input_key', ...], ['output_key', ...]] 69 | 70 | 'module': Module instance or string name of module. For example, 71 | 'encoder' would access the attribute `dag_layer.encoder`. 72 | 'input_key': List of strings, nested keys in dictionary of dag outputs. 73 | For example, 'inputs/f0_hz' would access `outputs['inputs']['f0_hz']`. 74 | Inputs to the dag are wrapped in an `inputs` dict as shown in the 75 | example. This list is ordered and has one key per module input 76 | argument. Each node's outputs are prefixed by their module name. 77 | 'output_key': List of strings, keys for each return value of the module.
78 | For example, ['amps', 'freqs'] would have the module return a dict 79 | {'module_name': {'amps': return_value_0, 'freqs': return_value_1}}. 80 | If the module returns a dictionary, the keys of the dictionary will be 81 | used and these values (if provided) will be ignored. 82 | 83 | The graph is read sequentially and must be topologically sorted. This 84 | means that all inputs for a module must already be generated by earlier 85 | modules (or in the input dictionary). 86 | **kwarg_modules: A series of modules to add to DAGLayer. Each kwarg that 87 | is a tf.Module will be added as a property of the layer, so that it will 88 | be accessible as `dag_layer.kwarg`. Also, other keras kwargs such as 89 | 'name' are split off before adding modules. 90 | """ 91 | keras_kwargs, kwarg_modules = split_keras_kwargs(kwarg_modules) 92 | super().__init__(**keras_kwargs) 93 | 94 | # Create properties/submodules from other kwargs. 95 | modules = filter_by_value(kwarg_modules, is_module) 96 | 97 | # Remove modules from the dag, make properties of dag_layer. 98 | dag, dag_modules = self.format_dag(dag) 99 | # DAG is now just strings. 100 | self.dag = dag 101 | modules.update(dag_modules) 102 | 103 | # Make as properties of DAGLayer to keep track of variables in checkpoints. 104 | self.module_names = list(modules.keys()) 105 | for module_name, module in modules.items(): 106 | setattr(self, module_name, module) 107 | 108 | @property 109 | def modules(self): 110 | """Module getter.""" 111 | return [getattr(self, name) for name in self.module_names] 112 | 113 | @staticmethod 114 | def format_dag(dag): 115 | """Remove modules from dag, and replace with module names.""" 116 | modules = {} 117 | dag = list(dag) # Make mutable in case it's a tuple. 118 | for i, node in enumerate(dag): 119 | node = list(node) # Make mutable in case it's a tuple. 120 | module = node[0] 121 | if is_module(module): 122 | # Strip module from the dag. 123 | modules[module.name] = module 124 | # Replace with module name. 125 | node[0] = module.name 126 | dag[i] = node 127 | return dag, modules 128 | 129 | def call(self, inputs: TensorDict, **kwargs) -> TensorDict: 130 | """Run dag for an input dictionary.""" 131 | return self.run_dag(inputs, **kwargs) 132 | 133 | @gin.configurable(allowlist=['verbose']) # For debugging. 134 | def run_dag(self, 135 | inputs: TensorDict, 136 | verbose: bool = False, 137 | **kwargs) -> TensorDict: 138 | """Connects and runs submodules of dag. 139 | 140 | Args: 141 | inputs: A dictionary of input tensors fed to the dag. 142 | verbose: Print out dag routing when running. 143 | **kwargs: Other kwargs to pass to submodules, such as keras kwargs. 144 | 145 | Returns: 146 | A nested dictionary of all the output tensors. 147 | """ 148 | # Initialize the outputs with inputs to the dag. 149 | outputs = {'inputs': inputs} 150 | # TODO(jesseengel): Remove this cluttering of the base namespace. Only there 151 | # for backwards compatibility. 152 | outputs.update(inputs) 153 | 154 | # Run through the DAG nodes in sequential order. 155 | for node in self.dag: 156 | # The first element of the node can be either a module or module_key. 157 | module_key, input_keys = node[0], node[1] 158 | module = getattr(self, module_key) 159 | # Optionally specify output keys if module does not return dict. 160 | output_keys = node[2] if len(node) > 2 else None 161 | 162 | # Get the inputs to the node.
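      # Keys are '/'-delimited paths into the nested outputs dictionary, e.g. 'inputs/f0_hz' resolves to outputs['inputs']['f0_hz'].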
163 | inputs = [core.nested_lookup(key, outputs) for key in input_keys] 164 | 165 | if verbose: 166 | shape = lambda d: tf.nest.map_structure(lambda x: list(x.shape), d) 167 | logging.info('Input to Module: %s\nKeys: %s\nIn: %s\n', 168 | module_key, input_keys, shape(inputs)) 169 | 170 | # Duck typing to avoid dealing with multiple inheritance of Group modules. 171 | if is_processor(module): 172 | # Processor modules. 173 | module_outputs = module(*inputs, return_outputs_dict=True, **kwargs) 174 | elif is_loss(module): 175 | # Loss modules. 176 | module_outputs = module.get_losses_dict(*inputs, **kwargs) 177 | else: 178 | # Network modules. 179 | module_outputs = module(*inputs, **kwargs) 180 | 181 | if not isinstance(module_outputs, dict): 182 | module_outputs = core.to_dict(module_outputs, output_keys) 183 | 184 | if verbose: 185 | logging.info('Output from Module: %s\nOut: %s\n', 186 | module_key, shape(module_outputs)) 187 | 188 | # Add module outputs to the dictionary. 189 | outputs[module_key] = module_outputs 190 | 191 | # Alias final module output as dag output. 192 | # 'out' is a reserved key for final dag output. 193 | outputs['out'] = module_outputs 194 | 195 | return outputs 196 | -------------------------------------------------------------------------------- /ddsp/dags_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ddsp.dags.py.""" 16 | 17 | from absl.testing import parameterized 18 | from ddsp import dags 19 | import gin 20 | import tensorflow as tf 21 | 22 | # Make dense layers configurable for this test. 23 | gin.external_configurable(tf.keras.layers.Dense, 'tf.keras.layers.Dense') 24 | 25 | 26 | @gin.configurable 27 | class ConfigurableDAGLayer(dags.DAGLayer): 28 | """Configurable wrapper DAGLayer encapsulated for this test.""" 29 | pass 30 | 31 | 32 | class DAGLayerTest(parameterized.TestCase, tf.test.TestCase): 33 | 34 | def setUp(self): 35 | """Create some dummy input data for the chain.""" 36 | super().setUp() 37 | # Create inputs. 
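    # Dummy batch for a toy encoder -> bottleneck -> decoder DAG; the gin configs below wire the same graph in two equivalent ways (modules passed as kwargs vs. modules embedded directly in the dag).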
38 | self.n_batch = 4 39 | self.x_dims = 5 40 | self.z_dims = 2 41 | self.x = tf.ones([self.n_batch, self.x_dims]) 42 | self.inputs = {'test_data': self.x} 43 | self.gin_config_kwarg_modules = f""" 44 | import ddsp 45 | 46 | ### Modules 47 | ConfigurableDAGLayer.dag = [ 48 | ('encoder', ['inputs/test_data'], ['z']), 49 | ('bottleneck', ['encoder/z'], ['z_bottleneck']), 50 | ('decoder', ['bottleneck/z_bottleneck'], ['reconstruction']), 51 | ] 52 | ConfigurableDAGLayer.encoder = @encoder/layers.Dense() 53 | encoder/layers.Dense.units = {self.x_dims} 54 | 55 | ConfigurableDAGLayer.bottleneck = @bottleneck/layers.Dense() 56 | bottleneck/layers.Dense.units = {self.z_dims} 57 | 58 | ConfigurableDAGLayer.decoder = @decoder/layers.Dense() 59 | decoder/layers.Dense.units = {self.x_dims} 60 | """ 61 | self.gin_config_dag_modules = f""" 62 | import ddsp 63 | 64 | ### Modules 65 | ConfigurableDAGLayer.dag = [ 66 | (@encoder/layers.Dense(), ['inputs/test_data'], ['z']), 67 | (@bottleneck/layers.Dense(), ['encoder/z'], ['z_bottleneck']), 68 | (@decoder/layers.Dense(), ['bottleneck/z_bottleneck'], ['reconstruction']), 69 | ] 70 | encoder/layers.Dense.name = 'encoder' 71 | encoder/layers.Dense.units = {self.x_dims} 72 | 73 | bottleneck/layers.Dense.name = 'bottleneck' 74 | bottleneck/layers.Dense.units = {self.z_dims} 75 | 76 | decoder/layers.Dense.name = 'decoder' 77 | decoder/layers.Dense.units = {self.x_dims} 78 | """ 79 | 80 | @parameterized.named_parameters( 81 | ('kwarg_modules', True), 82 | ('dag_modules', False), 83 | ) 84 | def test_build_layer(self, kwarg_modules): 85 | """Tests if layer builds properly and produces outputs of correct shape.""" 86 | gin_config = (self.gin_config_kwarg_modules if kwarg_modules else 87 | self.gin_config_dag_modules) 88 | with gin.unlock_config(): 89 | gin.clear_config() 90 | gin.parse_config(gin_config) 91 | 92 | dag_layer = ConfigurableDAGLayer() 93 | outputs = dag_layer(self.inputs) 94 | self.assertIsInstance(outputs, dict) 95 | 96 | z = outputs['bottleneck']['z_bottleneck'] 97 | x_rec = outputs['decoder']['reconstruction'] 98 | x_rec2 = outputs['out']['reconstruction'] 99 | 100 | # Confirm that layer generates correctly sized tensors. 101 | self.assertEqual(outputs['test_data'].shape, self.x.shape) 102 | self.assertEqual(outputs['inputs']['test_data'].shape, self.x.shape) 103 | self.assertEqual(x_rec.shape, self.x.shape) 104 | self.assertEqual(z.shape[-1], self.z_dims) 105 | self.assertAllClose(x_rec, x_rec2) 106 | 107 | # Confirm that variables are inherited by DAGLayer. 108 | self.assertLen(dag_layer.trainable_variables, 6) # 3 weights, 3 biases. 109 | 110 | if __name__ == '__main__': 111 | tf.test.main() 112 | -------------------------------------------------------------------------------- /ddsp/effects_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ddsp.effects.""" 16 | 17 | from absl.testing import parameterized 18 | from ddsp import effects 19 | import tensorflow.compat.v2 as tf 20 | 21 | 22 | class ReverbTest(parameterized.TestCase, tf.test.TestCase): 23 | 24 | def setUp(self): 25 | """Creates some test specific attributes.""" 26 | super().setUp() 27 | self.reverb_class = effects.Reverb 28 | self.audio = tf.zeros((3, 16000)) 29 | self.construct_args = {'reverb_length': 100} 30 | self.call_args = {'ir': tf.zeros((3, 100, 1))} 31 | self.controls_keys = ['audio', 'ir'] 32 | 33 | @parameterized.named_parameters( 34 | ('trainable', True), 35 | ('not_trainable', False), 36 | ) 37 | def test_output_shape_and_variables_are_correct(self, trainable): 38 | reverb = self.reverb_class(trainable=trainable, **self.construct_args) 39 | if trainable: 40 | output = reverb(self.audio) 41 | else: 42 | output = reverb(self.audio, **self.call_args) 43 | 44 | self.assertListEqual(list(self.audio.shape), output.shape.as_list()) 45 | self.assertEqual(reverb.trainable, trainable) 46 | self.assertEmpty(reverb.non_trainable_variables) 47 | assert_variables = self.assertNotEmpty if trainable else self.assertEmpty 48 | assert_variables(reverb.trainable_variables) 49 | 50 | def test_non_trainable_raises_value_error(self): 51 | reverb = self.reverb_class(trainable=False, **self.construct_args) 52 | with self.assertRaises(ValueError): 53 | _ = reverb(self.audio) 54 | 55 | @parameterized.named_parameters( 56 | ('trainable', True), 57 | ('not_trainable', False), 58 | ) 59 | def test_get_controls_returns_correct_keys(self, trainable): 60 | reverb = self.reverb_class(trainable=trainable, **self.construct_args) 61 | reverb.build(self.audio.shape) 62 | if trainable: 63 | controls = reverb.get_controls(self.audio) 64 | else: 65 | controls = reverb.get_controls(self.audio, **self.call_args) 66 | 67 | self.assertListEqual(list(controls.keys()), self.controls_keys) 68 | 69 | 70 | class ExpDecayReverbTest(ReverbTest): 71 | 72 | def setUp(self): 73 | """Creates some test specific attributes.""" 74 | super().setUp() 75 | self.reverb_class = effects.ExpDecayReverb 76 | self.audio = tf.zeros((3, 16000)) 77 | self.construct_args = {'reverb_length': 100} 78 | self.call_args = {'gain': tf.zeros((3, 1)), 79 | 'decay': tf.zeros((3, 1))} 80 | 81 | 82 | class FilteredNoiseReverbTest(ReverbTest): 83 | 84 | def setUp(self): 85 | """Creates some test specific attributes.""" 86 | super().setUp() 87 | self.reverb_class = effects.FilteredNoiseReverb 88 | self.audio = tf.zeros((3, 16000)) 89 | self.construct_args = {'reverb_length': 100, 90 | 'n_frames': 10, 91 | 'n_filter_banks': 20} 92 | self.call_args = {'magnitudes': tf.zeros((3, 10, 20))} 93 | 94 | 95 | class FIRFilterTest(tf.test.TestCase): 96 | 97 | def test_output_shape_is_correct(self): 98 | processor = effects.FIRFilter() 99 | 100 | audio = tf.zeros((3, 16000)) 101 | magnitudes = tf.zeros((3, 100, 30)) 102 | output = processor(audio, magnitudes) 103 | 104 | self.assertListEqual([3, 16000], output.shape.as_list()) 105 | 106 | 107 | class ModDelayTest(tf.test.TestCase): 108 | 109 | def test_output_shape_is_correct(self): 110 | processor = effects.ModDelay() 111 | 112 | audio = tf.zeros((3, 16000)) 113 | gain = tf.zeros((3, 16000, 1)) 114 | phase = tf.zeros((3, 16000, 1)) 115 | output = processor(audio, gain, phase) 116 | 117 | self.assertListEqual([3, 16000], output.shape.as_list()) 118 | 119 | 120 | if __name__ == '__main__': 121 | tf.test.main() 122 | 
-------------------------------------------------------------------------------- /ddsp/losses_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ddsp.losses.""" 16 | 17 | from ddsp import core 18 | from ddsp import losses 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | 23 | class LossGroupTest(tf.test.TestCase): 24 | 25 | def setUp(self): 26 | """Create some dummy input data for the chain.""" 27 | super().setUp() 28 | 29 | # Create a network output dictionary. 30 | self.nn_outputs = { 31 | 'audio': tf.ones((3, 8000), dtype=tf.float32), 32 | 'audio_synth': tf.ones((3, 8000), dtype=tf.float32), 33 | 'magnitudes': tf.ones((3, 200, 2), dtype=tf.float32), 34 | 'f0_hz': 200 + tf.ones((3, 200, 1), dtype=tf.float32), 35 | } 36 | 37 | # Create Losses. 38 | spectral_loss = losses.SpectralLoss() 39 | crepe_loss = losses.PretrainedCREPEEmbeddingLoss(name='crepe_loss') 40 | 41 | # Create DAG for testing. 42 | self.dag = [ 43 | (spectral_loss, ['audio', 'audio_synth']), 44 | (crepe_loss, ['audio', 'audio_synth']), 45 | ] 46 | self.expected_outputs = [ 47 | 'spectral_loss', 48 | 'crepe_loss' 49 | ] 50 | 51 | def _check_tensor_outputs(self, strings_to_check, outputs): 52 | for tensor_string in strings_to_check: 53 | tensor = core.nested_lookup(tensor_string, outputs) 54 | self.assertIsInstance(tensor, (np.ndarray, tf.Tensor)) 55 | 56 | def test_dag_construction(self): 57 | """Tests if DAG is built properly and runs.
58 | """ 59 | loss_group = losses.LossGroup(dag=self.dag) 60 | print('!!!!!!!!!!!', loss_group.dag, loss_group.loss_names, self.dag) 61 | loss_outputs = loss_group(self.nn_outputs) 62 | self.assertIsInstance(loss_outputs, dict) 63 | self._check_tensor_outputs(self.expected_outputs, loss_outputs) 64 | 65 | 66 | class SpectralLossTest(tf.test.TestCase): 67 | 68 | def test_output_shape_is_correct(self): 69 | """Test correct shape with all losses active.""" 70 | loss_obj = losses.SpectralLoss( 71 | mag_weight=1.0, 72 | delta_time_weight=1.0, 73 | delta_freq_weight=1.0, 74 | cumsum_freq_weight=1.0, 75 | logmag_weight=1.0, 76 | loudness_weight=1.0, 77 | ) 78 | 79 | input_audio = tf.ones((3, 8000), dtype=tf.float32) 80 | target_audio = tf.ones((3, 8000), dtype=tf.float32) 81 | 82 | loss = loss_obj(input_audio, target_audio) 83 | 84 | self.assertListEqual([], loss.shape.as_list()) 85 | self.assertTrue(np.isfinite(loss)) 86 | 87 | 88 | 89 | 90 | class PretrainedCREPEEmbeddingLossTest(tf.test.TestCase): 91 | 92 | def test_output_shape_is_correct(self): 93 | loss_obj = losses.PretrainedCREPEEmbeddingLoss() 94 | 95 | input_audio = tf.ones((3, 16000), dtype=tf.float32) 96 | target_audio = tf.ones((3, 16000), dtype=tf.float32) 97 | 98 | loss = loss_obj(input_audio, target_audio) 99 | 100 | self.assertListEqual([], loss.shape.as_list()) 101 | self.assertTrue(np.isfinite(loss)) 102 | 103 | 104 | if __name__ == '__main__': 105 | tf.test.main() 106 | -------------------------------------------------------------------------------- /ddsp/processors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ddsp.processors.""" 16 | 17 | from absl.testing import parameterized 18 | from ddsp import core 19 | from ddsp import effects 20 | from ddsp import processors 21 | from ddsp import synths 22 | import numpy as np 23 | import tensorflow.compat.v2 as tf 24 | 25 | 26 | class ProcessorGroupTest(parameterized.TestCase, tf.test.TestCase): 27 | 28 | def setUp(self): 29 | """Create some dummy input data for the chain.""" 30 | super().setUp() 31 | # Create inputs. 32 | self.n_batch = 4 33 | self.n_frames = 1000 34 | self.n_time = 64000 35 | rand_signal = lambda ch: np.random.randn(self.n_batch, self.n_frames, ch) 36 | self.nn_outputs = { 37 | 'amps': rand_signal(1), 38 | 'harmonic_distribution': rand_signal(99), 39 | 'magnitudes': rand_signal(256), 40 | 'f0_hz': 200 + rand_signal(1), 41 | 'target_audio': np.random.randn(self.n_batch, self.n_time) 42 | } 43 | 44 | # Create Processors. 45 | harmonic = synths.Harmonic(name='harmonic') 46 | noise = synths.FilteredNoise(name='noise') 47 | add = processors.Add(name='add') 48 | reverb = effects.Reverb(trainable=True, name='reverb') 49 | 50 | # Create DAG for testing. 
51 | self.dag = [ 52 | (harmonic, ['amps', 'harmonic_distribution', 'f0_hz']), 53 | (noise, ['magnitudes']), 54 | (add, ['noise/signal', 'harmonic/signal']), 55 | (reverb, ['add/signal']), 56 | ] 57 | self.expected_outputs = [ 58 | 'amps', 59 | 'harmonic_distribution', 60 | 'magnitudes', 61 | 'f0_hz', 62 | 'target_audio', 63 | 'harmonic/signal', 64 | 'harmonic/controls/amplitudes', 65 | 'harmonic/controls/harmonic_distribution', 66 | 'harmonic/controls/f0_hz', 67 | 'noise/signal', 68 | 'noise/controls/magnitudes', 69 | 'add/signal', 70 | 'reverb/signal', 71 | 'reverb/controls/ir', 72 | 'out/signal', 73 | ] 74 | 75 | def _check_tensor_outputs(self, strings_to_check, outputs): 76 | for tensor_string in strings_to_check: 77 | tensor = core.nested_lookup(tensor_string, outputs) 78 | self.assertIsInstance(tensor, (np.ndarray, tf.Tensor)) 79 | 80 | def test_dag_construction(self): 81 | """Tests if DAG is built properly and runs. 82 | """ 83 | processor_group = processors.ProcessorGroup(dag=self.dag, 84 | name='processor_group') 85 | outputs = processor_group.get_controls(self.nn_outputs) 86 | self.assertIsInstance(outputs, dict) 87 | self._check_tensor_outputs(self.expected_outputs, outputs) 88 | 89 | 90 | class AddTest(tf.test.TestCase): 91 | 92 | def test_output_is_correct(self): 93 | processor = processors.Add(name='add') 94 | x = tf.zeros((2, 3), dtype=tf.float32) + 1.0 95 | y = tf.zeros((2, 3), dtype=tf.float32) + 2.0 96 | 97 | output = processor(x, y) 98 | 99 | expected = np.zeros((2, 3), dtype=np.float32) + 3.0 100 | self.assertAllEqual(expected, output) 101 | 102 | 103 | class MixTest(tf.test.TestCase): 104 | 105 | def test_output_shape_is_correct(self): 106 | processor = processors.Mix(name='mix') 107 | x1 = np.zeros((2, 100, 3), dtype=np.float32) + 1.0 108 | x2 = np.zeros((2, 100, 3), dtype=np.float32) + 2.0 109 | mix_level = np.zeros( 110 | (2, 100, 1), dtype=np.float32) + 0.1 # will be passed to sigmoid 111 | 112 | output = processor(x1, x2, mix_level) 113 | 114 | self.assertListEqual([2, 100, 3], output.shape.as_list()) 115 | 116 | 117 | if __name__ == '__main__': 118 | tf.test.main() 119 | -------------------------------------------------------------------------------- /ddsp/synths_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ddsp.synths.""" 16 | 17 | from ddsp import core 18 | from ddsp import synths 19 | import numpy as np 20 | import tensorflow.compat.v2 as tf 21 | 22 | 23 | class HarmonicTest(tf.test.TestCase): 24 | 25 | def test_output_shape_is_correct(self): 26 | synthesizer = synths.Harmonic( 27 | n_samples=64000, 28 | sample_rate=16000, 29 | scale_fn=None, 30 | normalize_below_nyquist=True) 31 | batch_size = 3 32 | num_frames = 1000 33 | amp = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 1.0 34 | harmonic_distribution = tf.zeros( 35 | (batch_size, num_frames, 16), dtype=tf.float32) + 1.0 / 16 36 | f0_hz = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 16000 37 | 38 | output = synthesizer(amp, harmonic_distribution, f0_hz) 39 | 40 | self.assertAllEqual([batch_size, 64000], output.shape.as_list()) 41 | 42 | 43 | class FilteredNoiseTest(tf.test.TestCase): 44 | 45 | def test_output_shape_is_correct(self): 46 | synthesizer = synths.FilteredNoise(n_samples=16000) 47 | filter_bank_magnitudes = tf.zeros((3, 16000, 100), dtype=tf.float32) + 3.0 48 | output = synthesizer(filter_bank_magnitudes) 49 | 50 | self.assertAllEqual([3, 16000], output.shape.as_list()) 51 | 52 | 53 | class WavetableTest(tf.test.TestCase): 54 | 55 | def test_output_shape_is_correct(self): 56 | synthesizer = synths.Wavetable( 57 | n_samples=64000, 58 | sample_rate=16000, 59 | scale_fn=None) 60 | batch_size = 3 61 | num_frames = 1000 62 | n_wavetable = 1024 63 | amp = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 1.0 64 | wavetables = tf.zeros( 65 | (batch_size, num_frames, n_wavetable), dtype=tf.float32) 66 | f0_hz = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 440 67 | 68 | output = synthesizer(amp, wavetables, f0_hz) 69 | 70 | self.assertAllEqual([batch_size, 64000], output.shape.as_list()) 71 | 72 | 73 | class SinusoidalTest(tf.test.TestCase): 74 | 75 | def test_output_shape_is_correct(self): 76 | synthesizer = synths.Sinusoidal(n_samples=32000, sample_rate=16000) 77 | batch_size = 3 78 | num_frames = 1000 79 | n_partials = 10 80 | amps = tf.zeros((batch_size, num_frames, n_partials), 81 | dtype=tf.float32) 82 | freqs = tf.zeros((batch_size, num_frames, n_partials), 83 | dtype=tf.float32) 84 | 85 | output = synthesizer(amps, freqs) 86 | 87 | self.assertAllEqual([batch_size, 32000], output.shape.as_list()) 88 | 89 | def test_frequencies_controls_are_bounded(self): 90 | depth = 10 91 | def freq_scale_fn(x): 92 | return core.frequencies_sigmoid(x, depth=depth, hz_min=0.0, hz_max=8000.0) 93 | 94 | synthesizer = synths.Sinusoidal( 95 | n_samples=32000, sample_rate=16000, freq_scale_fn=freq_scale_fn) 96 | batch_size = 3 97 | num_frames = 10 98 | n_partials = 100 99 | amps = tf.zeros((batch_size, num_frames, n_partials), dtype=tf.float32) 100 | freqs = tf.linspace(-100.0, 100.0, n_partials) 101 | freqs = tf.tile(freqs[tf.newaxis, tf.newaxis, :, tf.newaxis], 102 | [batch_size, num_frames, 1, depth]) 103 | 104 | controls = synthesizer.get_controls(amps, freqs) 105 | freqs = controls['frequencies'] 106 | lt_nyquist = (freqs <= 8000.0) 107 | gt_zero = (freqs >= 0.0) 108 | both_conditions = np.logical_and(lt_nyquist, gt_zero) 109 | 110 | self.assertTrue(np.all(both_conditions)) 111 | 112 | 113 | if __name__ == '__main__': 114 | tf.test.main() 115 | -------------------------------------------------------------------------------- /ddsp/test_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library of helper functions for testing.""" 16 | 17 | import numpy as np 18 | 19 | 20 | def gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec): 21 | x = np.linspace(0, audio_len_sec, int(audio_len_sec * sample_rate)) 22 | audio_sin = amp * (np.sin(2 * np.pi * frequency * x)) 23 | return audio_sin 24 | 25 | 26 | def gen_np_batched_sinusoids(frequency, amp, sample_rate, audio_len_sec, 27 | batch_size): 28 | batch_sinusoids = [ 29 | gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec) 30 | for _ in range(batch_size) 31 | ] 32 | return np.array(batch_sinusoids) 33 | 34 | -------------------------------------------------------------------------------- /ddsp/training/README.md: -------------------------------------------------------------------------------- 1 | # DDSP Training 2 | 3 | 4 | This directory contains the code for training models using DDSP modules. 5 | The currently supported models are variants of an audio autoencoder. 6 | 7 |
8 | *DDSP logo*
10 | 11 | # Disclaimer 12 | *Unlike the base `ddsp/` library, this folder is actively modified for new 13 | experiments and has a higher chance of making breaking changes in the future.* 14 | 15 | _Functions and classes marked **EXPERIMENTAL** in their doc string are under active development and very likely to change. They should not be expected to be maintained in their current state._ 16 | 17 | ## Modules 18 | 19 | The DDSP training libraries are separated into several modules: 20 | 21 | * [data](./data.py): 22 | DataProvider objects provide tf.data.Dataset. 23 | * [models](./models.py): 24 | Model objects define the full forward pass and losses. 25 | * [preprocessing](./preprocessing.py): 26 | Preprocessor objects format and scale model inputs. 27 | * [encoders](./encoders.py): 28 | Layers to turn preprocessor outputs into latents. 29 | * [decoders](./decoders.py): 30 | Layers to turn latents into ddsp processor inputs. 31 | * [nn](./nn.py): 32 | Network functions and layers. 33 | * [inference](./inference.py): 34 | Model wrappers for efficient inference and the ability to store as 35 | SavedModels. 36 | 37 | 38 | The main training file is `ddsp_run.py` and its helper libraries: 39 | 40 | * [ddsp_run](./ddsp_run.py): 41 | Main file for launching training, evaluation, and sampling runs. 42 | * [train_util](./train_util.py): 43 | Training loop and helper functions. 44 | * [trainers](./trainers.py): 45 | Training step defined by helper objects that bind the strategy, optimizer, and model. 46 | * [eval_util](./eval_util.py): 47 | Evaluation and sampling loop. 48 | * [evaluators](./evaluators.py): 49 | Evaluator objects responsible for computing metrics and summaries. 50 | * [metrics](./metrics.py): 51 | Metrics for evaluation. 52 | * [summaries](./summaries.py): 53 | Summaries for tensorboard images and audio of samples. 54 | 55 | While the modules in the `ddsp/` base directory can be used to train models 56 | with `tf.compat.v1` or `tf.compat.v2` this directory only uses `tf.compat.v2`. 57 | 58 | ## Quickstart 59 | 60 | The [pip installation](../../README.md#installation) includes several scripts that can be called directly from 61 | the command line. 62 | 63 | Hyperparameters are configured via gin, and `ddsp_run.py` must be given three 64 | `--gin_file` flags, one from `gin/models`, one from `gin/datasets`, and one from `gin/eval`. The 65 | files in `gin/papers` include both the dataset, model, and evaluation files for reproducing experiments from a specific paper. 66 | 67 | By default, the program searches for gin files in the installed `ddsp/training/gin` location, but additional search paths can be added with `--gin_search_path` 68 | flags. Individual parameters can also be set with multiple `--gin_param` flags. 69 | 70 | This example below streams a version of the NSynth dataset from GCS. 
If not running on GCP, it is much faster to first download the dataset with 72 | [tensorflow_datasets](https://www.tensorflow.org/datasets), and add the flag 73 | `--gin_param="NSynthTfds.data_dir='/path/to/tfds/dir'"`: 74 | 75 | ### Train 76 | ```bash 77 | ddsp_run \ 78 | --mode=train \ 79 | --save_dir=/tmp/$USER-ddsp-0 \ 80 | --gin_file=papers/iclr2020/nsynth_ae.gin \ 81 | --gin_param="batch_size=16" \ 82 | --alsologtostderr 83 | ``` 84 | 85 | ### Evaluate 86 | ```bash 87 | ddsp_run \ 88 | --mode=eval \ 89 | --save_dir=/tmp/$USER-ddsp-0 \ 90 | --gin_file=datasets/nsynth.gin \ 91 | --gin_file=eval/basic_f0_ld.gin \ 92 | --alsologtostderr 93 | ``` 94 | 95 | ### Sample 96 | ```bash 97 | ddsp_run \ 98 | --mode=sample \ 99 | --save_dir=/tmp/$USER-ddsp-0 \ 100 | --gin_file=datasets/nsynth.gin \ 101 | --gin_file=eval/basic_f0_ld.gin \ 102 | --alsologtostderr 103 | ``` 104 | 105 | When training, all gin parameters in the 106 | [operative configuration](https://github.com/google/gin-config/blob/master/docs/index.md#retrieving-operative-parameter-values) 107 | will be saved to the `${MODEL_DIR}/operative_config-0.gin` file, which is then loaded for evaluation, sampling, or further training. The operative config is also visible as a text summary in TensorBoard. See 108 | [this doc](https://github.com/google/gin-config/blob/master/docs/index.md#saving-gins-operative-config-to-a-file-and-tensorboard) 109 | for more details. 110 | 111 | ### Backwards compatibility 112 | 113 | 114 | For backwards compatibility, we keep track of changes in function signatures in `update_gin_config.py`, which can be used to update old operative configs to work with the current library. 115 | 116 | 117 | ### Using Cloud TPU 118 | 119 | To use a [Cloud TPU](https://cloud.google.com/tpu/) for any of the above commands, there are a few minor changes. 120 | 121 | First, your model directory will need to be accessible to the TPU. This means it will need to be located in a [GCS bucket with proper permissions](https://cloud.google.com/tpu/docs/storage-buckets). 122 | 123 | Second, you will need to add the following flag: 124 | 125 | ``` 126 | --tpu=grpc://<TPU internal IP address>:8470 \ 127 | ``` 128 | 129 | The TPU internal IP address can be found in the Cloud Console. 130 | 131 | 132 | ## Training a model on your own data 133 | ### Prepare dataset 134 | Make a TFRecord dataset out of a folder of .wav or .mp3 files: 135 | 136 | ```bash 137 | ddsp_prepare_tfrecord \ 138 | --input_audio_filepatterns=/path/to/wavs/*wav \ 139 | --output_tfrecord_path=/path/to/dataset_name.tfrecord \ 140 | --num_shards=10 \ 141 | --alsologtostderr 142 | ``` 143 | 144 | ### Train 145 | ```bash 146 | ddsp_run \ 147 | --mode=train \ 148 | --save_dir=/tmp/$USER-ddsp-0 \ 149 | --gin_file=models/solo_instrument.gin \ 150 | --gin_file=datasets/tfrecord.gin \ 151 | --gin_file=eval/basic_f0_ld.gin \ 152 | --gin_param="TFRecordProvider.file_pattern='/path/to/dataset_name.tfrecord*'" \ 153 | --gin_param="batch_size=16" \ 154 | --alsologtostderr 155 | ``` 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /ddsp/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Training code for DDSP models.""" 16 | 17 | from ddsp.training import cloud 18 | from ddsp.training import data 19 | from ddsp.training import decoders 20 | from ddsp.training import encoders 21 | from ddsp.training import eval_util 22 | from ddsp.training import evaluators 23 | from ddsp.training import inference 24 | from ddsp.training import metrics 25 | from ddsp.training import models 26 | from ddsp.training import nn 27 | from ddsp.training import plotting 28 | from ddsp.training import postprocessing 29 | from ddsp.training import preprocessing 30 | from ddsp.training import summaries 31 | from ddsp.training import train_util 32 | from ddsp.training import trainers 33 | -------------------------------------------------------------------------------- /ddsp/training/cloud.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library of functions for training on Google Cloud AI-Platform.""" 16 | 17 | import os 18 | import re 19 | 20 | from absl import logging 21 | from google.cloud import storage 22 | import hypertune 23 | 24 | 25 | def download_from_gstorage(gstorage_path, local_path): 26 | """Downloads a file from the bucket. 27 | 28 | Args: 29 | gstorage_path: Path to the file inside the bucket that needs to be 30 | downloaded. Format: gs://bucket-name/path/to/file.txt 31 | local_path: Local path where downloaded file should be stored. 32 | """ 33 | gstorage_path = gstorage_path[len('gs://'):]  # Remove the 'gs://' prefix (str.strip would also eat bucket-name characters). 34 | bucket_name = gstorage_path.split('/')[0] 35 | blob_name = os.path.relpath(gstorage_path, bucket_name) 36 | 37 | storage_client = storage.Client() 38 | 39 | bucket = storage_client.bucket(bucket_name) 40 | blob = bucket.blob(blob_name) 41 | 42 | blob.download_to_filename(local_path) 43 | logging.info( 44 | 'Downloaded file. Source: %s, Destination: %s', 45 | gstorage_path, local_path) 46 | 47 | 48 | def make_file_paths_local(paths, local_directory): 49 | """Makes sure that given files are locally available. 50 | 51 | If a Cloud Storage path is provided, downloads the file and returns the new 52 | path relative to local_directory. If a local path is provided, it is 53 | returned with no modification. 54 | 55 | Args: 56 | paths: Single path or a list of paths. 57 | local_directory: Local path to the directory where downloaded files will be 58 | stored.
59 | 60 | Returns: 61 | Single local path or a list of local paths. 62 | """ 63 | if isinstance(paths, str): 64 | if re.match('gs://*', paths): 65 | local_name = os.path.basename(paths) 66 | download_from_gstorage(paths, os.path.join(local_directory, local_name)) 67 | return local_name 68 | else: 69 | return paths 70 | else: 71 | local_paths = [] 72 | for path in paths: 73 | if re.match('gs://*', path): 74 | local_name = os.path.basename(path) 75 | download_from_gstorage(path, os.path.join(local_directory, local_name)) 76 | local_paths.append(local_name) 77 | else: 78 | local_paths.append(path) 79 | return local_paths 80 | 81 | 82 | def report_metric_to_hypertune(metric_value, step, tag='Loss'): 83 | """Use hypertune to report metrics for hyperparameter tuning.""" 84 | hpt = hypertune.HyperTune() 85 | hpt.report_hyperparameter_tuning_metric( 86 | hyperparameter_metric_tag=tag, 87 | metric_value=metric_value, 88 | global_step=step) 89 | -------------------------------------------------------------------------------- /ddsp/training/cloud_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for ddsp.training.cloud.""" 16 | 17 | from unittest import mock 18 | 19 | from ddsp.training import cloud 20 | import tensorflow.compat.v2 as tf 21 | 22 | 23 | class MakeFilePathsLocalTest(tf.test.TestCase): 24 | 25 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True) 26 | def test_single_path_handling(self, download_from_gstorage_function): 27 | """Tests that function returns a single value if given a single value.""" 28 | path = cloud.make_file_paths_local( 29 | 'gs://bucket-name/bucket/dir/some_file.gin', 30 | 'gin/search/path') 31 | download_from_gstorage_function.assert_called_once() 32 | self.assertEqual(path, 'some_file.gin') 33 | 34 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True) 35 | def test_single_local_path_handling(self, download_from_gstorage_function): 36 | """Tests that function does nothing if given a local file path.""" 37 | path = cloud.make_file_paths_local( 38 | 'local_file.gin', 39 | 'gin/search/path') 40 | download_from_gstorage_function.assert_not_called() 41 | self.assertEqual(path, 'local_file.gin') 42 | 43 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True) 44 | def test_single_path_in_list_handling(self, download_from_gstorage_function): 45 | """Tests that function returns a single-element list if given one.""" 46 | path = cloud.make_file_paths_local( 47 | ['gs://bucket-name/bucket/dir/some_file.gin'], 48 | 'gin/search/path') 49 | download_from_gstorage_function.assert_called_once() 50 | self.assertNotIsInstance(path, str) 51 | self.assertListEqual(path, ['some_file.gin']) 52 | 53 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True) 54 | def test_more_paths_in_list_handling(self, download_from_gstorage_function): 55 | """Tests that function handles both local and gstorage paths in one list.""" 56 | paths = cloud.make_file_paths_local( 57 | ['gs://bucket-name/bucket/dir/first_file.gin', 58 | 'local_file.gin', 59 | 'gs://bucket-name/bucket/dir/second_file.gin'], 60 | 'gin/search/path') 61 | self.assertEqual(download_from_gstorage_function.call_count, 2) 62 | download_from_gstorage_function.assert_has_calls( 63 | [mock.call('gs://bucket-name/bucket/dir/first_file.gin', mock.ANY), 64 | mock.call('gs://bucket-name/bucket/dir/second_file.gin', mock.ANY)]) 65 | self.assertListEqual( 66 | paths, 67 | ['first_file.gin', 'local_file.gin', 'second_file.gin']) 68 | 69 | 70 | if __name__ == '__main__': 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- /ddsp/training/data_preparation/README.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | 3 | Scripts and libraries to prepare datasets for training. For more example usage, check out the [train_autoencoder](../../colab/demos/train_autoencoder.ipynb) demo. 4 | 5 | 6 | ## Making a TFRecord dataset from your own sounds 7 | 8 | For experiments from the original [ICLR 2020 paper](https://openreview.net/forum?id=B1x1ma4tDr), we need to do some preprocessing on the raw audio to get it into the correct format for training. This involves turning the full audio into short examples (4 seconds by default, but adjustable with flags), inferring the fundamental frequency (or "pitch") with [CREPE](http://github.com/marl/crepe), and computing the loudness. These features will then be stored in a sharded [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file for easier loading.
Depending on the amount of input audio, this process usually takes a few minutes. 9 | 10 | ``` 11 | ddsp_prepare_tfrecord \ 12 | --input_audio_filepatterns=/path/to/wavs/*wav,/path/to/mp3s/*mp3 \ 13 | --output_tfrecord_path=/path/to/output.tfrecord \ 14 | --num_shards=10 \ 15 | --alsologtostderr 16 | ``` 17 | 18 | ## Making a TFRecord dataset from synthetic data 19 | 20 | For experiments from the [ICML 2020 workshop paper](https://goo.gl/magenta/ddsp-inv) on performing pitch detection with inverse synthesis, we need to create a synthetic dataset to use for supervision. You must also specify the data generation function `generate_examples.generate_fn`, either with a `--gin_param` flag or in a `--gin_file`. 21 | 22 | ``` 23 | ddsp_generate_synthetic_dataset \ 24 | --output_tfrecord_path=/path/to/output.tfrecord \ 25 | --num_shards=1000 \ 26 | --gin_param="generate_examples.generate_fn = @generate_notes_v2" \ 27 | --num_examples=10000000 \ 28 | --alsologtostderr 29 | ``` 30 | 31 | -------------------------------------------------------------------------------- /ddsp/training/data_preparation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Data preparation code for DDSP models.""" 16 | 17 | from ddsp.training.data_preparation import prepare_tfrecord_lib 18 | -------------------------------------------------------------------------------- /ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Apache Beam pipeline for computing a TFRecord dataset of synthetic examples. 16 | 17 | Example usage: 18 | ============== 19 | ddsp_generate_synthetic_dataset \ 20 | --output_tfrecord_path=/tmp/synthetic_data.tfrecord \ 21 | --num_shards=1 \ 22 | --gin_param="generate_examples.generate_fn = @generate_notes_v2" \ 23 | --num_examples=100 \ 24 | --alsologtostderr 25 | 26 | For the ICML workshop paper, we created 10,000,000 examples in 1000 shards.
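A generate_fn can be any gin-registered callable that returns a dict of batched
float tensors. A minimal sketch (the name `generate_sines` is hypothetical; the
real generators live in synthetic_data.py):

  @gin.register
  def generate_sines(n_batch=1, n_timesteps=1000):
    return {
        'f0_hz': tf.random.uniform([n_batch, n_timesteps], 100.0, 800.0),
        'amplitudes': tf.random.uniform([n_batch, n_timesteps], 0.0, 1.0),
    }

Each returned value must support batch indexing and `.numpy()`, since
GenerateExampleFn below extracts single examples with
`{k: v[0].numpy() for k, v in batch.items()}`.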
27 | """ 28 | 29 | from absl import app 30 | from absl import flags 31 | import apache_beam as beam 32 | from ddsp.training.data_preparation import synthetic_data # pylint:disable=unused-import 33 | import gin 34 | import numpy as np 35 | import pkg_resources 36 | import tensorflow.compat.v2 as tf 37 | 38 | 39 | GIN_PATH = pkg_resources.resource_filename(__name__, '../gin') 40 | 41 | FLAGS = flags.FLAGS 42 | 43 | flags.DEFINE_string( 44 | 'output_tfrecord_path', None, 45 | 'The prefix path to the output TFRecord. Shard numbers will be added to ' 46 | 'actual path(s).') 47 | flags.DEFINE_integer( 48 | 'num_shards', None, 49 | 'The number of shards to use for the TFRecord. If None, this number will ' 50 | 'be determined automatically.') 51 | flags.DEFINE_integer( 52 | 'num_examples', 1000000, 53 | 'The total number of synthetic examples to generate.') 54 | flags.DEFINE_integer( 55 | 'random_seed', 42, 56 | 'Random seed to use for deterministic generation.') 57 | flags.DEFINE_list( 58 | 'pipeline_options', '--runner=DirectRunner', 59 | 'A comma-separated list of command line arguments to be used as options ' 60 | 'for the Beam Pipeline.') 61 | 62 | # Gin config flags. 63 | flags.DEFINE_multi_string('gin_search_path', [], 64 | 'Additional gin file search paths.') 65 | flags.DEFINE_multi_string('gin_file', [], 'List of paths to the config files.') 66 | flags.DEFINE_multi_string('gin_param', [], 67 | 'Newline separated list of Gin parameter bindings.') 68 | 69 | 70 | class GenerateExampleFn(beam.DoFn): 71 | """Gin-configurable wrapper to generate synthetic examples.""" 72 | 73 | def __init__(self, gin_str, **kwargs): 74 | super().__init__(**kwargs) 75 | self._gin_str = gin_str 76 | 77 | def start_bundle(self): 78 | with gin.unlock_config(): 79 | gin.parse_config(self._gin_str) 80 | 81 | @gin.configurable('generate_examples') 82 | def process(self, seed, generate_fn=gin.REQUIRED): 83 | np.random.seed(seed) 84 | batch = generate_fn(n_batch=1) 85 | beam.metrics.Metrics.counter('GenerateExampleFn', 'generated').inc() 86 | yield {k: v[0].numpy() for k, v in batch.items()} 87 | 88 | 89 | def _float_dict_to_tfexample(float_dict): 90 | """Convert dictionary of float arrays to tf.train.Example proto.""" 91 | return tf.train.Example( 92 | features=tf.train.Features( 93 | feature={ 94 | k: tf.train.Feature( 95 | float_list=tf.train.FloatList(value=v.flatten())) 96 | for k, v in float_dict.items() 97 | } 98 | )) 99 | 100 | 101 | def run(): 102 | """Run the beam pipeline to create synthetic dataset.""" 103 | pipeline_options = beam.options.pipeline_options.PipelineOptions( 104 | FLAGS.pipeline_options) 105 | with beam.Pipeline(options=pipeline_options) as pipeline: 106 | for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path: 107 | gin.add_config_file_search_path(gin_search_path) 108 | gin.parse_config_files_and_bindings( 109 | FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True) 110 | 111 | np.random.seed(FLAGS.random_seed) 112 | _ = ( 113 | pipeline 114 | | beam.Create(np.random.randint(2**32, size=FLAGS.num_examples)) 115 | | beam.ParDo(GenerateExampleFn(gin.config_str())) 116 | | beam.Reshuffle() 117 | | beam.Map(_float_dict_to_tfexample) 118 | | beam.io.tfrecordio.WriteToTFRecord( 119 | FLAGS.output_tfrecord_path, 120 | num_shards=FLAGS.num_shards, 121 | coder=beam.coders.ProtoCoder(tf.train.Example)) 122 | ) 123 | 124 | 125 | def main(unused_argv): 126 | """From command line.""" 127 | run() 128 | 129 | 130 | def console_entry_point(): 131 | """From pip installed script.""" 132 | app.run(main) 133 | 
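# A sketch of running this pipeline on Cloud Dataflow instead of the default
# DirectRunner (the bucket and project names below are illustrative, not part
# of this repo):
#
#   ddsp_generate_synthetic_dataset \
#     --output_tfrecord_path=gs://my-bucket/synthetic_data.tfrecord \
#     --num_shards=1000 \
#     --num_examples=10000000 \
#     --gin_param="generate_examples.generate_fn = @generate_notes_v2" \
#     --pipeline_options=--runner=DataflowRunner,--project=my-project,--region=us-central1,--temp_location=gs://my-bucket/tmp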
134 | 135 | if __name__ == '__main__': 136 | console_entry_point() 137 | -------------------------------------------------------------------------------- /ddsp/training/data_preparation/ddsp_prepare_tfrecord.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Create a TFRecord dataset from audio files. 16 | 17 | Usage: 18 | ==================== 19 | ddsp_prepare_tfrecord \ 20 | --input_audio_filepatterns=/path/to/wavs/*wav,/path/to/mp3s/*mp3 \ 21 | --output_tfrecord_path=/path/to/output.tfrecord \ 22 | --num_shards=10 \ 23 | --alsologtostderr 24 | 25 | """ 26 | 27 | from absl import app 28 | from absl import flags 29 | from ddsp.training.data_preparation.prepare_tfrecord_lib import prepare_tfrecord 30 | import tensorflow.compat.v2 as tf 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_list('input_audio_filepatterns', [], 35 | 'List of filepatterns to glob for input audio files.') 36 | flags.DEFINE_string( 37 | 'output_tfrecord_path', None, 38 | 'The prefix path to the output TFRecord. Shard numbers will be added to ' 39 | 'actual path(s).') 40 | flags.DEFINE_integer( 41 | 'num_shards', None, 42 | 'The number of shards to use for the TFRecord. If None, this number will ' 43 | 'be determined automatically.') 44 | flags.DEFINE_integer('sample_rate', 16000, 45 | 'The sample rate to use for the audio.') 46 | flags.DEFINE_integer( 47 | 'frame_rate', 250, 48 | 'The frame rate to use for f0 and loudness features. If set to 0, ' 49 | 'these features will not be computed.') 50 | flags.DEFINE_float( 51 | 'example_secs', 4, 52 | 'The length of each example in seconds. Input audio will be split to this ' 53 | 'length using a sliding window. If 0, each full piece of audio will be ' 54 | 'used as an example.') 55 | flags.DEFINE_float( 56 | 'hop_secs', 1, 57 | 'The hop size between example start points (in seconds), when splitting ' 58 | 'audio into constant-length examples.') 59 | flags.DEFINE_float( 60 | 'eval_split_fraction', 0.0, 61 | 'Fraction of the dataset to reserve for eval split. If set to 0, no eval ' 62 | 'split is created.' 63 | ) 64 | flags.DEFINE_float( 65 | 'chunk_secs', 20.0, 66 | 'Chunk size in seconds used to split the input audio files. These ' 67 | 'non-overlapping chunks are partitioned into train and eval sets if ' 68 | 'eval_split_fraction > 0. This is used to split large audio files into ' 69 | 'manageable chunks for better parallelization and to enable ' 70 | 'non-overlapping train/eval splits.') 71 | flags.DEFINE_boolean( 72 | 'center', False, 73 | 'Add padding to audio such that frame timestamps are centered. 
Increases ' 74 | 'number of frames by one.') 75 | flags.DEFINE_boolean( 76 | 'viterbi', True, 77 | 'Use viterbi decoding of pitch.') 78 | flags.DEFINE_list( 79 | 'pipeline_options', '--runner=DirectRunner', 80 | 'A comma-separated list of command line arguments to be used as options ' 81 | 'for the Beam Pipeline.') 82 | 83 | 84 | def run(): 85 | input_audio_paths = [] 86 | for filepattern in FLAGS.input_audio_filepatterns: 87 | input_audio_paths.extend(tf.io.gfile.glob(filepattern)) 88 | 89 | prepare_tfrecord( 90 | input_audio_paths, 91 | FLAGS.output_tfrecord_path, 92 | num_shards=FLAGS.num_shards, 93 | sample_rate=FLAGS.sample_rate, 94 | frame_rate=FLAGS.frame_rate, 95 | example_secs=FLAGS.example_secs, 96 | hop_secs=FLAGS.hop_secs, 97 | eval_split_fraction=FLAGS.eval_split_fraction, 98 | chunk_secs=FLAGS.chunk_secs, 99 | center=FLAGS.center, 100 | viterbi=FLAGS.viterbi, 101 | pipeline_options=FLAGS.pipeline_options) 102 | 103 | 104 | def main(unused_argv): 105 | """From command line.""" 106 | run() 107 | 108 | 109 | def console_entry_point(): 110 | """From pip installed script.""" 111 | app.run(main) 112 | 113 | 114 | if __name__ == '__main__': 115 | console_entry_point() 116 | -------------------------------------------------------------------------------- /ddsp/training/data_preparation/prepare_tfrecord_lib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ddsp.training.data_preparation.prepare_tfrecord_lib.""" 16 | 17 | import os 18 | import sys 19 | 20 | from absl import flags 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | from ddsp import spectral_ops 24 | from ddsp.training.data_preparation import prepare_tfrecord_lib 25 | import numpy as np 26 | import scipy.io.wavfile 27 | import tensorflow.compat.v2 as tf 28 | 29 | CREPE_SAMPLE_RATE = spectral_ops.CREPE_SAMPLE_RATE 30 | 31 | 32 | class PrepareTFRecordBeamTest(parameterized.TestCase): 33 | 34 | def get_tempdir(self): 35 | try: 36 | flags.FLAGS.test_tmpdir 37 | except flags.UnparsedFlagAccessError: 38 | # Need to initialize flags when running `pytest`. 39 | flags.FLAGS(sys.argv) 40 | return self.create_tempdir().full_path 41 | 42 | def setUp(self): 43 | super().setUp() 44 | self.test_dir = self.get_tempdir() 45 | 46 | # Write test wav file. 
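# (Half a second of random int16 noise at 22.05 kHz.)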
47 |     self.wav_sr = 22050 48 |     self.wav_secs = 0.5 49 |     self.wav_path = os.path.join(self.test_dir, 'test.wav') 50 |     scipy.io.wavfile.write( 51 |         self.wav_path, 52 |         self.wav_sr, 53 |         np.random.randint( 54 |             np.iinfo(np.int16).min, np.iinfo(np.int16).max, 55 |             size=int(self.wav_sr * self.wav_secs), dtype=np.int16)) 56 | 57 |   def parse_tfrecord(self, path): 58 |     return [tf.train.Example.FromString(record.numpy()) for record in 59 |             tf.data.TFRecordDataset(os.path.join(self.test_dir, path))] 60 | 61 |   def validate_outputs(self, expected_num_examples, expected_feature_lengths): 62 |     all_examples = ( 63 |         self.parse_tfrecord('output.tfrecord-00000-of-00002') + 64 |         self.parse_tfrecord('output.tfrecord-00001-of-00002')) 65 | 66 |     self.assertLen(all_examples, expected_num_examples) 67 |     for ex in all_examples: 68 |       self.assertCountEqual(expected_feature_lengths, ex.features.feature) 69 | 70 |       for feat, expected_len in expected_feature_lengths.items(): 71 |         arr = ex.features.feature[feat].float_list.value 72 |         try: 73 |           self.assertLen(arr, expected_len) 74 |         except AssertionError as e: 75 |           raise AssertionError('feature: %s' % feat) from e 76 |         self.assertFalse(any(np.isinf(arr))) 77 | 78 |   def get_expected_length(self, input_length, frame_rate, center=False): 79 |     sample_rate = 16000  # Features at CREPE_SAMPLE_RATE. 80 |     frame_size = 1024  # Unused for this calculation. 81 |     hop_size = sample_rate // frame_rate 82 |     padding = 'center' if center else 'same' 83 |     n_frames, _ = spectral_ops.get_framed_lengths( 84 |         input_length, frame_size, hop_size, padding) 85 |     return n_frames 86 | 87 |   @staticmethod 88 |   def get_n_per_chunk(chunk_length, example_secs, hop_secs): 89 |     """Convenience function to calculate the number of examples from striding.""" 90 |     n = (chunk_length - example_secs) / hop_secs 91 |     # Deal with limited float precision that causes (.3 / .1) = 2.9999.... 92 |     return int(np.floor(np.round(n, decimals=3))) + 1 93 | 94 |   @parameterized.named_parameters( 95 |       ('chunk_and_split', 0.3, 0.2), 96 |       ('no_chunk', None, 0.2), 97 |       ('no_split', 0.3, None), 98 |       ('no_chunk_no_split', None, None), 99 |   ) 100 |   def test_prepare_tfrecord(self, chunk_secs, example_secs): 101 |     sample_rate = 16000 102 |     frame_rate = 250 103 |     hop_secs = 0.1 104 | 105 |     # Calculate expected batch size. 106 |     if example_secs: 107 |       length = chunk_secs if chunk_secs else self.wav_secs 108 |       n_per_chunk = self.get_n_per_chunk(length, example_secs, hop_secs) 109 |     else: 110 |       n_per_chunk = 1 111 | 112 |     n_chunks = int(np.ceil(self.wav_secs / chunk_secs)) if chunk_secs else 1 113 |     expected_n_batch = n_per_chunk * n_chunks 114 |     print('n_chunks, n_per_chunk, chunk_secs, example_secs', 115 |           n_chunks, n_per_chunk, chunk_secs, example_secs) 116 | 117 |     # Calculate expected lengths. 118 |     if example_secs: 119 |       length = example_secs 120 |     elif chunk_secs: 121 |       length = chunk_secs 122 |     else: 123 |       length = self.wav_secs 124 | 125 |     expected_n_t = int(length * sample_rate) 126 |     expected_n_frames = self.get_expected_length(expected_n_t, frame_rate) 127 | 128 |     # Make the actual records.
129 | prepare_tfrecord_lib.prepare_tfrecord( 130 | [self.wav_path], 131 | os.path.join(self.test_dir, 'output.tfrecord'), 132 | num_shards=2, 133 | sample_rate=sample_rate, 134 | frame_rate=frame_rate, 135 | example_secs=example_secs, 136 | hop_secs=hop_secs, 137 | chunk_secs=chunk_secs, 138 | center=False) 139 | 140 | self.validate_outputs( 141 | expected_n_batch, 142 | { 143 | 'audio': expected_n_t, 144 | 'audio_16k': expected_n_t, 145 | 'f0_hz': expected_n_frames, 146 | 'f0_confidence': expected_n_frames, 147 | 'loudness_db': expected_n_frames, 148 | }) 149 | 150 | @parameterized.named_parameters(('no_center', False), ('center', True)) 151 | def test_centering(self, center): 152 | frame_rate = 250 153 | sample_rate = 16000 154 | example_secs = 0.3 155 | hop_secs = 0.1 156 | n_batch = self.get_n_per_chunk(self.wav_secs, example_secs, hop_secs) 157 | prepare_tfrecord_lib.prepare_tfrecord( 158 | [self.wav_path], 159 | os.path.join(self.test_dir, 'output.tfrecord'), 160 | num_shards=2, 161 | sample_rate=sample_rate, 162 | frame_rate=frame_rate, 163 | example_secs=example_secs, 164 | hop_secs=hop_secs, 165 | center=center, 166 | chunk_secs=None) 167 | 168 | n_t = int(example_secs * sample_rate) 169 | n_frames = self.get_expected_length(n_t, frame_rate, center) 170 | n_expected_frames = 76 if center else 75 # (250 * 0.3) [+1]. 171 | self.assertEqual(n_frames, n_expected_frames) 172 | self.validate_outputs( 173 | n_batch, { 174 | 'audio': n_t, 175 | 'audio_16k': n_t, 176 | 'f0_hz': n_frames, 177 | 'f0_confidence': n_frames, 178 | 'loudness_db': n_frames, 179 | }) 180 | 181 | @parameterized.named_parameters( 182 | ('16kHz', 16000), 183 | ('32kHz', 32000), 184 | ('48kHz', 48000)) 185 | def test_sample_rate(self, sample_rate): 186 | frame_rate = 250 187 | example_secs = 0.3 188 | hop_secs = 0.1 189 | center = True 190 | n_batch = self.get_n_per_chunk(self.wav_secs, example_secs, hop_secs) 191 | prepare_tfrecord_lib.prepare_tfrecord( 192 | [self.wav_path], 193 | os.path.join(self.test_dir, 'output.tfrecord'), 194 | num_shards=2, 195 | sample_rate=sample_rate, 196 | frame_rate=frame_rate, 197 | example_secs=example_secs, 198 | hop_secs=hop_secs, 199 | center=center, 200 | chunk_secs=None) 201 | 202 | n_t = int(example_secs * sample_rate) 203 | n_t_16k = int(example_secs * CREPE_SAMPLE_RATE) 204 | n_frames = self.get_expected_length(n_t_16k, frame_rate, center) 205 | n_expected_frames = 76 # (250 * 0.3) + 1. 206 | self.assertEqual(n_frames, n_expected_frames) 207 | self.validate_outputs( 208 | n_batch, { 209 | 'audio': n_t, 210 | 'audio_16k': n_t_16k, 211 | 'f0_hz': n_frames, 212 | 'f0_confidence': n_frames, 213 | 'loudness_db': n_frames, 214 | }) 215 | 216 | @parameterized.named_parameters(('16kHz', 16000), ('44.1kHz', 44100), 217 | ('48kHz', 48000)) 218 | def test_audio_only(self, sample_rate): 219 | prepare_tfrecord_lib.prepare_tfrecord( 220 | [self.wav_path], 221 | os.path.join(self.test_dir, 'output.tfrecord'), 222 | num_shards=2, 223 | sample_rate=sample_rate, 224 | frame_rate=None, 225 | example_secs=None, 226 | chunk_secs=None) 227 | 228 | self.validate_outputs( 229 | 1, { 230 | 'audio': int(self.wav_secs * sample_rate), 231 | 'audio_16k': int(self.wav_secs * CREPE_SAMPLE_RATE), 232 | }) 233 | 234 | 235 | if __name__ == '__main__': 236 | absltest.main() 237 | -------------------------------------------------------------------------------- /ddsp/training/ddsp_run.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Train, evaluate, or sample (from) a ddsp model. 16 | 17 | Usage: 18 | ================================================================================ 19 | For training, you need to specify --gin_file for both the model and the dataset. 20 | You can optionally specify additional params with --gin_param. 21 | The pip package installs a `ddsp_run` script that can be called directly. 22 | ================================================================================ 23 | ddsp_run \ 24 | --mode=train \ 25 | --alsologtostderr \ 26 | --save_dir=/tmp/$USER-ddsp-0 \ 27 | --gin_file=models/ae.gin \ 28 | --gin_file=datasets/nsynth.gin \ 29 | --gin_param=batch_size=16 30 | 31 | 32 | ================================================================================ 33 | For evaluation and sampling, only the dataset file is required. 34 | ================================================================================ 35 | ddsp_run \ 36 | --mode=eval \ 37 | --alsologtostderr \ 38 | --save_dir=/tmp/$USER-ddsp-0 \ 39 | --gin_file=datasets/nsynth.gin 40 | 41 | ddsp_run \ 42 | --mode=sample \ 43 | --alsologtostderr \ 44 | --save_dir=/tmp/$USER-ddsp-0 \ 45 | --gin_file=datasets/nsynth.gin 46 | 47 | 48 | ================================================================================ 49 | The directory `gin/papers/` stores configs that give the specific models and 50 | datasets used for a paper's experiments, so they require only one gin file to train. 51 | ================================================================================ 52 | ddsp_run \ 53 | --mode=train \ 54 | --alsologtostderr \ 55 | --save_dir=/tmp/$USER-ddsp-0 \ 56 | --gin_file=papers/iclr2020/nsynth_ae.gin 57 | 58 | 59 | """ 60 | 61 | import os 62 | import time 63 | 64 | from absl import app 65 | from absl import flags 66 | from absl import logging 67 | from ddsp.training import cloud 68 | from ddsp.training import eval_util 69 | from ddsp.training import models 70 | from ddsp.training import train_util 71 | from ddsp.training import trainers 72 | import gin 73 | import pkg_resources 74 | import tensorflow as tf 75 | 76 | gfile = tf.io.gfile 77 | FLAGS = flags.FLAGS 78 | 79 | # Program flags. 80 | flags.DEFINE_enum('mode', 'train', ['train', 'eval', 'sample'], 81 |                   'Whether to train, evaluate, or sample from the model.') 82 | flags.DEFINE_string('save_dir', '/tmp/ddsp', 83 |                     'Path where checkpoints and summary events will be saved ' 84 |                     'during training and evaluation.') 85 | flags.DEFINE_string('restore_dir', '', 86 |                     'Path from which checkpoints will be restored before ' 87 |                     'training. Can be different from the save_dir.') 88 | flags.DEFINE_string('tpu', '', 'Address of the TPU. No TPU if left blank.') 89 | flags.DEFINE_string('cluster_config', '', 90 |                     'Worker-specific JSON string for multiworker setup. 
' 91 | 'For more information see train_util.get_strategy().') 92 | flags.DEFINE_boolean('allow_memory_growth', False, 93 | 'Whether to grow the GPU memory usage as is needed by the ' 94 | 'process. Prevents crashes on GPUs with smaller memory.') 95 | flags.DEFINE_boolean('hypertune', False, 96 | 'Enable metric reporting for hyperparameter tuning, such ' 97 | 'as on Google Cloud AI-Platform.') 98 | flags.DEFINE_float('early_stop_loss_value', None, 99 | 'Stops training early when the `total_loss` reaches below ' 100 | 'this value during training.') 101 | 102 | # Gin config flags. 103 | flags.DEFINE_multi_string('gin_search_path', [], 104 | 'Additional gin file search paths.') 105 | flags.DEFINE_multi_string('gin_file', [], 106 | 'List of paths to the config files. If file ' 107 | 'in gstorage bucket specify whole gstorage path: ' 108 | 'gs://bucket-name/dir/in/bucket/file.gin.') 109 | flags.DEFINE_multi_string('gin_param', [], 110 | 'Newline separated list of Gin parameter bindings.') 111 | 112 | # Evaluation/sampling specific flags. 113 | flags.DEFINE_boolean('run_once', False, 'Whether evaluation will run once.') 114 | flags.DEFINE_integer('initial_delay_secs', None, 115 | 'Time to wait before evaluation starts') 116 | 117 | GIN_PATH = pkg_resources.resource_filename(__name__, 'gin') 118 | 119 | 120 | def delay_start(): 121 | """Optionally delay the start of the run.""" 122 | delay_time = FLAGS.initial_delay_secs 123 | if delay_time: 124 | logging.info('Waiting for %i second(s)', delay_time) 125 | time.sleep(delay_time) 126 | 127 | 128 | def parse_gin(restore_dir): 129 | """Parse gin config from --gin_file, --gin_param, and the model directory.""" 130 | # Enable parsing gin files on Google Cloud. 131 | gin.config.register_file_reader(tf.io.gfile.GFile, tf.io.gfile.exists) 132 | # Add user folders to the gin search path. 133 | for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path: 134 | gin.add_config_file_search_path(gin_search_path) 135 | 136 | # Parse gin configs, later calls override earlier ones. 137 | with gin.unlock_config(): 138 | # Optimization defaults. 139 | use_tpu = bool(FLAGS.tpu) 140 | opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin' 141 | gin.parse_config_file(os.path.join('optimization', opt_default)) 142 | eval_default = 'eval/basic.gin' 143 | gin.parse_config_file(eval_default) 144 | 145 | # Load operative_config if it exists (model has already trained). 146 | try: 147 | operative_config = train_util.get_latest_operative_config(restore_dir) 148 | logging.info('Using operative config: %s', operative_config) 149 | operative_config = cloud.make_file_paths_local(operative_config, GIN_PATH) 150 | gin.parse_config_file(operative_config, skip_unknown=True) 151 | except FileNotFoundError: 152 | logging.info('Operative config not found in %s', restore_dir) 153 | 154 | # User gin config and user hyperparameters from flags. 155 | gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH) 156 | gin.parse_config_files_and_bindings( 157 | gin_file, FLAGS.gin_param, skip_unknown=True) 158 | 159 | 160 | def allow_memory_growth(): 161 | """Sets the GPUs to grow the memory usage as is needed by the process.""" 162 | gpus = tf.config.experimental.list_physical_devices('GPU') 163 | if gpus: 164 | try: 165 | # Currently, memory growth needs to be the same across GPUs. 166 | for gpu in gpus: 167 | tf.config.experimental.set_memory_growth(gpu, True) 168 | except RuntimeError as e: 169 | # Memory growth must be set before GPUs have been initialized. 
170 | print(e) 171 | 172 | 173 | def main(unused_argv): 174 | """Parse gin config and run ddsp training, evaluation, or sampling.""" 175 | restore_dir = os.path.expanduser(FLAGS.restore_dir) 176 | save_dir = os.path.expanduser(FLAGS.save_dir) 177 | # If no separate restore directory is given, use the save directory. 178 | restore_dir = save_dir if not restore_dir else restore_dir 179 | logging.info('Restore Dir: %s', restore_dir) 180 | logging.info('Save Dir: %s', save_dir) 181 | 182 | gfile.makedirs(restore_dir) # Only makes dirs if they don't exist. 183 | parse_gin(restore_dir) 184 | logging.info('Operative Gin Config:\n%s', gin.config.config_str()) 185 | 186 | if FLAGS.allow_memory_growth: 187 | allow_memory_growth() 188 | 189 | # Training. 190 | if FLAGS.mode == 'train': 191 | strategy = train_util.get_strategy(tpu=FLAGS.tpu, 192 | cluster_config=FLAGS.cluster_config) 193 | with strategy.scope(): 194 | model = models.get_model() 195 | trainer = trainers.get_trainer_class()(model, strategy) 196 | 197 | train_util.train(data_provider=gin.REQUIRED, 198 | trainer=trainer, 199 | save_dir=save_dir, 200 | restore_dir=restore_dir, 201 | early_stop_loss_value=FLAGS.early_stop_loss_value, 202 | report_loss_to_hypertune=FLAGS.hypertune) 203 | 204 | # Evaluation. 205 | elif FLAGS.mode == 'eval': 206 | model = models.get_model() 207 | delay_start() 208 | eval_util.evaluate(data_provider=gin.REQUIRED, 209 | model=model, 210 | save_dir=save_dir, 211 | restore_dir=restore_dir, 212 | run_once=FLAGS.run_once) 213 | 214 | # Sampling. 215 | elif FLAGS.mode == 'sample': 216 | model = models.get_model() 217 | delay_start() 218 | eval_util.sample(data_provider=gin.REQUIRED, 219 | model=model, 220 | save_dir=save_dir, 221 | restore_dir=restore_dir, 222 | run_once=FLAGS.run_once) 223 | 224 | 225 | def console_entry_point(): 226 | """From pip installed script.""" 227 | app.run(main) 228 | 229 | 230 | if __name__ == '__main__': 231 | console_entry_point() 232 | -------------------------------------------------------------------------------- /ddsp/training/decoders_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for ddsp.training.decoders.""" 16 | 17 | import functools 18 | 19 | from absl.testing import parameterized 20 | from ddsp.training import decoders 21 | import numpy as np 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | class DilatedConvDecoderTest(parameterized.TestCase, tf.test.TestCase): 26 | 27 | def setUp(self): 28 | """Create some common default values for decoder.""" 29 | super().setUp() 30 | # For decoder. 31 | self.ch = 4 32 | self.layers_per_stack = 3 33 | self.stacks = 2 34 | self.output_splits = (('amps', 1), ('harmonic_distribution', 10), 35 | ('noise_magnitudes', 10)) 36 | 37 | # For audio features and conditioning. 
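# (time_steps = int(100 frames/sec * 0.2 sec) = 20.)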
38 |     self.frame_rate = 100 39 |     self.length_in_sec = 0.20 40 |     self.time_steps = int(self.frame_rate * self.length_in_sec) 41 | 42 |   def _gen_dummy_conditioning(self): 43 |     """Generate dummy scaled f0 and ld conditioning.""" 44 |     conditioning = {} 45 |     # Generate dummy `f0_hz` with batch and channel dims. 46 |     f0_hz_dummy = np.repeat(1.0, 47 |                             self.length_in_sec * self.frame_rate)[np.newaxis, :, 48 |                                                                   np.newaxis] 49 |     conditioning['f0_scaled'] = f0_hz_dummy  # Testing correct shapes only. 50 |     # Generate dummy `loudness_db` with batch and channel dims. 51 |     loudness_db_dummy = np.repeat(1.0, self.length_in_sec * 52 |                                   self.frame_rate)[np.newaxis, :, np.newaxis] 53 |     conditioning[ 54 |         'ld_scaled'] = loudness_db_dummy  # Testing correct shapes only. 55 |     return conditioning 56 | 57 |   def test_correct_output_splits_and_shapes_dilated_conv_decoder(self): 58 |     decoder = decoders.DilatedConvDecoder( 59 |         ch=self.ch, 60 |         layers_per_stack=self.layers_per_stack, 61 |         stacks=self.stacks, 62 |         conditioning_keys=None, 63 |         output_splits=self.output_splits) 64 | 65 |     conditioning = self._gen_dummy_conditioning() 66 |     output = decoder(conditioning) 67 |     for output_name, output_dim in self.output_splits: 68 |       dummy_output = np.zeros((1, self.time_steps, output_dim)) 69 |       self.assertShapeEqual(dummy_output, output[output_name]) 70 | 71 | 72 | class RnnFcDecoderTest(parameterized.TestCase, tf.test.TestCase): 73 | 74 |   def setUp(self): 75 |     """Create some common default values for decoder.""" 76 |     super().setUp() 77 |     # For decoder. 78 |     self.input_keys = ('pw_scaled', 'f0_scaled') 79 |     self.output_splits = (('amps', 1), ('harmonic_distribution', 10)) 80 |     self.n_batch = 2 81 |     self.n_t = 4 82 |     self.inputs = { 83 |         'pw_scaled': tf.ones([self.n_batch, self.n_t, 1]), 84 |         'f0_scaled': tf.ones([self.n_batch, self.n_t, 1]), 85 |     } 86 |     self.rnn_ch = 3 87 |     self.get_decoder = functools.partial( 88 |         decoders.RnnFcDecoder, 89 |         rnn_channels=self.rnn_ch, 90 |         ch=2, 91 |         layers_per_stack=1, 92 |         input_keys=self.input_keys, 93 |         output_splits=self.output_splits, 94 |         rnn_type='gru', 95 |     ) 96 | 97 |   @parameterized.named_parameters( 98 |       ('stateful', False), 99 |       ('stateless', True), 100 |   ) 101 |   def test_correct_outputs(self, stateless=False): 102 |     decoder = self.get_decoder(stateless=stateless) 103 | 104 |     # Add state. 105 |     inputs = self.inputs 106 |     if stateless: 107 |       inputs['state'] = tf.ones([self.n_batch, self.rnn_ch]) 108 | 109 |     # Run through the network. 110 |     outputs = decoder(inputs) 111 | 112 |     # Check normal outputs. 113 |     for name, dim in self.output_splits: 114 |       dummy_output = np.zeros((self.n_batch, self.n_t, dim)) 115 |       self.assertShapeEqual(dummy_output, outputs[name]) 116 | 117 |     # Check the explicit state. 118 |     if stateless: 119 |       self.assertShapeEqual(inputs['state'], outputs['state']) 120 |       self.assertNotAllEqual(inputs['state'], outputs['state']) 121 | 122 | 123 | if __name__ == '__main__': 124 |   tf.test.main() 125 | -------------------------------------------------------------------------------- /ddsp/training/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Uses one of the AI Platform base images. 2 | # You can try using different images; however, only this one has been tested. 3 | FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-2 4 | 5 | # Installs sndfile library for reading and writing audio files.
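# (libsndfile is the system library behind the Python `soundfile` package.)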
6 | RUN apt-get update && \ 7 |     apt-get install --no-install-recommends -y libsndfile-dev 8 | 9 | # Upgrades TensorFlow and TensorFlow Probability. 10 | # A newer version of TensorFlow is needed for multi-VM training. 11 | RUN pip install --upgrade pip && \ 12 |     pip install --upgrade tensorflow tensorflow-probability 13 | 14 | # Installs the cloudml-hypertune package needed for hyperparameter tuning. 15 | RUN pip install cloudml-hypertune 16 | 17 | WORKDIR /root 18 | # Installs Magenta DDSP from GitHub. 19 | RUN wget https://github.com/magenta/ddsp/archive/main.zip && \ 20 |     unzip main.zip && \ 21 |     cd ddsp-main && \ 22 |     python setup.py install 23 | 24 | # Copies running script. 25 | COPY task.py task.py 26 | 27 | ENTRYPOINT ["python", "task.py"] 28 | -------------------------------------------------------------------------------- /ddsp/training/docker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Docker code.""" 16 | 17 | from ddsp.training.docker import task 18 | -------------------------------------------------------------------------------- /ddsp/training/docker/config_hypertune.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 |   scaleTier: CUSTOM 3 |   masterType: n1-highmem-64 4 |   masterConfig: 5 |     acceleratorConfig: 6 |       count: 4 7 |       type: NVIDIA_TESLA_T4 8 |   useChiefInTfConfig: True 9 |   hyperparameters: 10 |     goal: MINIMIZE 11 |     hyperparameterMetricTag: "Loss" 12 |     maxTrials: 4 13 |     maxParallelTrials: 4 14 |     enableTrialEarlyStopping: True 15 |     params: 16 |     - parameterName: learning_rate 17 |       type: DISCRETE 18 |       discreteValues: 19 |       - 0.0001 20 |       - 3e-4 21 |       - 0.001 22 |       - 0.01 23 | -------------------------------------------------------------------------------- /ddsp/training/docker/config_multiple_vms.yaml: -------------------------------------------------------------------------------- 1 | # Recommended batch_size: 128 2 | # Recommended learning_rate: 0.001 3 | # Recommended number of steps: 15000 4 | # 5 | trainingInput: 6 |   scaleTier: CUSTOM 7 |   # Configures a chief worker with 2 T4 GPUs 8 |   masterType: n1-highcpu-32 9 |   masterConfig: 10 |     acceleratorConfig: 11 |       count: 2 12 |       type: NVIDIA_TESLA_T4 13 |   # Configures 3 workers, each with 2 T4 GPUs 14 |   workerCount: 3 15 |   workerType: n1-highcpu-32 16 |   workerConfig: 17 |     acceleratorConfig: 18 |       count: 2 19 |       type: NVIDIA_TESLA_T4 20 |   # Makes AI Platform naming compatible with TensorFlow naming 21 |   useChiefInTfConfig: True 22 | -------------------------------------------------------------------------------- /ddsp/training/docker/config_single_vm.yaml: -------------------------------------------------------------------------------- 1 | # Recommended batch_size: 16 2 | # Recommended learning_rate: 0.0001 3 | # Recommended number of steps: 40000 4 | trainingInput: 5 |   scaleTier: CUSTOM 6 |   # 
Configures a single worker with 1 NVIDIA T4 GPU 7 |   masterType: n1-highcpu-16 8 |   masterConfig: 9 |     acceleratorConfig: 10 |       count: 1 11 |       type: NVIDIA_TESLA_T4 12 |   useChiefInTfConfig: True 13 | -------------------------------------------------------------------------------- /ddsp/training/docker/task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Script for running a containerized training on Google Cloud AI Platform.""" 16 | 17 | import json 18 | import os 19 | import subprocess 20 | 21 | from absl import app 22 | from absl import flags 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | flags.DEFINE_string('save_dir', None, 27 |                     'Path where checkpoints and summary events will be saved ' 28 |                     'during training and evaluation.') 29 | flags.DEFINE_string('restore_dir', '', 30 |                     'Path from which checkpoints will be restored before ' 31 |                     'training. Can be different from the save_dir.') 32 | flags.DEFINE_string('file_pattern', None, 'Data file pattern') 33 | 34 | flags.DEFINE_integer('batch_size', 32, 'Batch size') 35 | flags.DEFINE_float('learning_rate', 0.003, 'Learning rate') 36 | flags.DEFINE_integer('num_steps', 30000, 'Number of training steps') 37 | flags.DEFINE_float('early_stop_loss_value', 0.0, 38 |                    'Early stopping. Training stops when the total_loss ' 39 |                    'drops below this value.') 40 | 41 | flags.DEFINE_integer('steps_per_summary', 300, 'Steps per summary') 42 | flags.DEFINE_integer('steps_per_save', 300, 'Steps per save') 43 | flags.DEFINE_boolean('hypertune', False, 44 |                      'Enable metric reporting for hyperparameter tuning.') 45 | 46 | 47 | flags.DEFINE_multi_string('gin_search_path', [], 48 |                           'Additional gin file search paths. These must be ' 49 |                           'paths inside the Docker container; any necessary ' 50 |                           'gin configs should be added at the Docker image ' 51 |                           'building stage.') 52 | flags.DEFINE_multi_string('gin_file', [], 53 |                           'List of paths to the config files. If the file is ' 54 |                           'in a gstorage bucket, specify the whole gstorage path: ' 55 |                           'gs://bucket-name/dir/in/bucket/file.gin. If the path ' 56 |                           'is local, remember to copy the file into the Docker ' 57 |                           'container at the image building stage.') 58 | flags.DEFINE_multi_string('gin_param', [], 59 |                           'Newline separated list of Gin parameter bindings.') 60 | 61 | 62 | def get_worker_behavior_info(save_dir): 63 |   """Infers worker behavior from the environment. 64 | 65 |   Checks whether the TF_CONFIG environment variable is set 66 |   and infers the cluster configuration and save_dir 67 |   from it. 68 | 69 |   Args: 70 |     save_dir: Save directory given by the user. 71 | 72 |   Returns: 73 |     cluster_config: Inferred cluster configuration. 74 |     save_dir: Inferred save directory. 
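  Example (a sketch; the TF_CONFIG value is illustrative):

    os.environ['TF_CONFIG'] = (
        '{"cluster": {"chief": ["chief:2222"], "worker": ["worker0:2221"]},'
        ' "task": {"type": "worker", "index": 0}}')
    cluster_config, save_dir = get_worker_behavior_info('gs://bucket/model')
    # cluster_config == os.environ['TF_CONFIG'] and save_dir == '', so only
    # the chief task keeps a save_dir and writes checkpoints and summaries.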
75 |   """ 76 |   if 'TF_CONFIG' in os.environ: 77 |     cluster_config = os.environ['TF_CONFIG'] 78 |     cluster_config_dict = json.loads(cluster_config) 79 |     if ('cluster' not in cluster_config_dict.keys() or 80 |         'task' not in cluster_config_dict or 81 |         len(cluster_config_dict['cluster']) <= 1): 82 |       cluster_config = '' 83 |     elif cluster_config_dict['task']['type'] != 'chief': 84 |       save_dir = '' 85 |   else: 86 |     cluster_config = '' 87 | 88 |   return cluster_config, save_dir 89 | 90 | 91 | def parse_list_params(list_of_params, param_name): 92 |   return [f'--{param_name}={param}' for param in list_of_params] 93 | 94 | 95 | def main(unused_argv): 96 |   restore_dir = FLAGS.save_dir if not FLAGS.restore_dir else FLAGS.restore_dir 97 | 98 |   cluster_config, save_dir = get_worker_behavior_info(FLAGS.save_dir) 99 |   gin_search_path = parse_list_params(FLAGS.gin_search_path, 'gin_search_path') 100 |   gin_file = parse_list_params(FLAGS.gin_file, 'gin_file') 101 |   gin_param = parse_list_params(FLAGS.gin_param, 'gin_param') 102 | 103 |   ddsp_run_command = ( 104 |       ['ddsp_run', 105 |        '--mode=train', 106 |        '--alsologtostderr', 107 |        '--gin_file=models/solo_instrument.gin', 108 |        '--gin_file=datasets/tfrecord.gin', 109 |        f'--cluster_config={cluster_config}', 110 |        f'--save_dir={save_dir}', 111 |        f'--restore_dir={restore_dir}', 112 |        f'--hypertune={FLAGS.hypertune}', 113 |        f'--early_stop_loss_value={FLAGS.early_stop_loss_value}', 114 |        f'--gin_param=batch_size={FLAGS.batch_size}', 115 |        f'--gin_param=learning_rate={FLAGS.learning_rate}', 116 |        f'--gin_param=TFRecordProvider.file_pattern=\'{FLAGS.file_pattern}\'', 117 |        f'--gin_param=train_util.train.num_steps={FLAGS.num_steps}', 118 |        f'--gin_param=train_util.train.steps_per_save={FLAGS.steps_per_save}', 119 |        ('--gin_param=train_util.train.steps_per_summary=' 120 |         f'{FLAGS.steps_per_summary}')] 121 |       + gin_search_path + gin_file + gin_param) 122 | 123 |   subprocess.run(args=ddsp_run_command, check=True) 124 | 125 | if __name__ == '__main__': 126 |   flags.mark_flag_as_required('file_pattern') 127 |   flags.mark_flag_as_required('save_dir') 128 |   app.run(main) 129 | -------------------------------------------------------------------------------- /ddsp/training/docker/task_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for task.py.""" 16 | 17 | import os 18 | from unittest import mock 19 | 20 | from ddsp.training.docker import task 21 | import tensorflow.compat.v2 as tf 22 | 23 | 24 | class GetWorkerBehaviorInfoTest(tf.test.TestCase): 25 | 26 |   def test_no_tf_config(self): 27 |     """Tests behavior when there is no TF_CONFIG set.""" 28 |     cluster_config, save_dir = task.get_worker_behavior_info('some/dir/') 29 |     self.assertEqual(save_dir, 'some/dir/') 30 |     self.assertEqual(cluster_config, '') 31 | 32 |   def test_incomplete_tf_config(self): 33 |     """Tests behavior when TF_CONFIG is set but incomplete.""" 34 |     with mock.patch.dict(os.environ, {'TF_CONFIG': '{"cluster": {}}'}): 35 |       cluster_config, save_dir = task.get_worker_behavior_info('some/dir/') 36 |       self.assertEqual(save_dir, 'some/dir/') 37 |       self.assertEqual(cluster_config, '') 38 | 39 |     with mock.patch.dict(os.environ, {'TF_CONFIG': '{"task": {}}'}): 40 |       cluster_config, save_dir = task.get_worker_behavior_info('some/dir/') 41 |       self.assertEqual(save_dir, 'some/dir/') 42 |       self.assertEqual(cluster_config, '') 43 | 44 |   @mock.patch.dict( 45 |       os.environ, 46 |       {'TF_CONFIG': ('{"cluster": {"worker": ["worker0.example.com:2221"]},' 47 |                      '"task": {"type": "worker", "index": 0}}')}) 48 |   def test_single_worker(self): 49 |     """Tests behavior when cluster has only one worker.""" 50 |     cluster_config, save_dir = task.get_worker_behavior_info('some/dir/') 51 |     self.assertEqual(save_dir, 'some/dir/') 52 |     self.assertEqual(cluster_config, '') 53 | 54 |   @mock.patch.dict( 55 |       os.environ, 56 |       {'TF_CONFIG': ('{"cluster": {"worker": ["worker0.example.com:2221"],' 57 |                      '"chief": ["chief.example.com:2222"]},' 58 |                      '"task": {"type": "chief", "index": 0}}')}) 59 |   def test_multi_worker_as_chief(self): 60 |     """Tests multi-worker behavior when task type chief is set in TF_CONFIG.""" 61 |     cluster_config, save_dir = task.get_worker_behavior_info('some/dir/') 62 |     self.assertEqual(save_dir, 'some/dir/') 63 |     self.assertEqual( 64 |         cluster_config, 65 |         ('{"cluster": {"worker": ["worker0.example.com:2221"],' 66 |          '"chief": ["chief.example.com:2222"]},' 67 |          '"task": {"type": "chief", "index": 0}}')) 68 | 69 |   @mock.patch.dict( 70 |       os.environ, 71 |       {'TF_CONFIG': ('{"cluster": {"worker": ["worker0.example.com:2221"],' 72 |                      '"chief": ["chief.example.com:2222"]},' 73 |                      '"task": {"type": "worker", "index": 0}}')}) 74 |   def test_multi_worker_as_worker(self): 75 |     """Tests multi-worker behavior when task type worker is set in TF_CONFIG.""" 76 |     cluster_config, save_dir = task.get_worker_behavior_info('some/dir/') 77 |     self.assertEqual(save_dir, '') 78 |     self.assertEqual( 79 |         cluster_config, 80 |         ('{"cluster": {"worker": ["worker0.example.com:2221"],' 81 |          '"chief": ["chief.example.com:2222"]},' 82 |          '"task": {"type": "worker", "index": 0}}')) 83 | 84 | 85 | if __name__ == '__main__': 86 |   tf.test.main() 87 | -------------------------------------------------------------------------------- /ddsp/training/evaluators.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library of evaluator implementations for use in eval_util.""" 16 | import ddsp 17 | from ddsp.training import heuristics 18 | from ddsp.training import metrics 19 | from ddsp.training import summaries 20 | import gin 21 | import numpy as np 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | class BaseEvaluator(object): 26 |   """Base class for evaluators.""" 27 | 28 |   def __init__(self, sample_rate, frame_rate): 29 |     self._sample_rate = sample_rate 30 |     self._frame_rate = frame_rate 31 | 32 |   def set_rates(self, sample_rate, frame_rate): 33 |     """Sets sample and frame rates, not known in gin initialization.""" 34 |     self._sample_rate = sample_rate 35 |     self._frame_rate = frame_rate 36 | 37 |   def evaluate(self, batch, outputs, losses): 38 |     """Computes metrics.""" 39 |     raise NotImplementedError() 40 | 41 |   def sample(self, batch, outputs, step): 42 |     """Computes and logs samples.""" 43 |     raise NotImplementedError() 44 | 45 |   def flush(self, step): 46 |     """Logs metrics.""" 47 |     raise NotImplementedError() 48 | 49 | 50 | 51 | @gin.register 52 | class BasicEvaluator(BaseEvaluator): 53 |   """Computes audio samples and losses.""" 54 | 55 |   def __init__(self, sample_rate, frame_rate): 56 |     super().__init__(sample_rate, frame_rate) 57 |     self._avg_losses = {} 58 | 59 |   def evaluate(self, batch, outputs, losses): 60 |     del outputs  # Unused. 61 |     if not self._avg_losses: 62 |       self._avg_losses = { 63 |           name: tf.keras.metrics.Mean(name=name, dtype=tf.float32) 64 |           for name in list(losses.keys()) 65 |       } 66 |     # Loss. 67 |     for k, v in losses.items(): 68 |       self._avg_losses[k].update_state(v) 69 | 70 |   def sample(self, batch, outputs, step): 71 |     audio = batch['audio'] 72 |     audio_gen = outputs['audio_gen'] 73 | 74 |     audio_gen = np.array(audio_gen) 75 | 76 |     # Add audio. 77 |     summaries.audio_summary( 78 |         audio_gen, step, self._sample_rate, name='audio_generated') 79 |     summaries.audio_summary( 80 |         audio, step, self._sample_rate, name='audio_original') 81 | 82 |     # Add plots. 83 |     summaries.waveform_summary(audio, audio_gen, step) 84 |     summaries.spectrogram_summary(audio, audio_gen, step) 85 | 86 |   def flush(self, step): 87 |     latest_losses = {} 88 |     for k, metric in self._avg_losses.items(): 89 |       latest_losses[k] = metric.result() 90 |       tf.summary.scalar('losses/{}'.format(k), metric.result(), step=step) 91 |       metric.reset_states() 92 | 93 | 94 | @gin.register 95 | class F0LdEvaluator(BaseEvaluator): 96 |   """Computes F0 and loudness metrics.""" 97 | 98 |   def __init__(self, sample_rate, frame_rate, run_f0_crepe=True): 99 |     super().__init__(sample_rate, frame_rate) 100 |     self._loudness_metrics = metrics.LoudnessMetrics( 101 |         sample_rate=sample_rate, frame_rate=frame_rate) 102 |     self._f0_metrics = metrics.F0Metrics( 103 |         sample_rate=sample_rate, frame_rate=frame_rate) 104 |     self._run_f0_crepe = run_f0_crepe 105 |     if self._run_f0_crepe: 106 |       self._f0_crepe_metrics = metrics.F0CrepeMetrics( 107 |           sample_rate=sample_rate, frame_rate=frame_rate) 108 | 109 |   def evaluate(self, batch, outputs, losses): 110 |     del losses  # Unused. 
111 | audio_gen = outputs['audio_gen'] 112 | self._loudness_metrics.update_state(batch, audio_gen) 113 | 114 | if 'f0_hz' in outputs and 'f0_hz' in batch: 115 | self._f0_metrics.update_state(batch, outputs['f0_hz']) 116 | elif self._run_f0_crepe: 117 | self._f0_crepe_metrics.update_state(batch, audio_gen) 118 | 119 | def sample(self, batch, outputs, step): 120 | if 'f0_hz' in outputs and 'f0_hz' in batch: 121 | summaries.f0_summary(batch['f0_hz'], outputs['f0_hz'], step, 122 | name='f0_harmonic') 123 | 124 | def flush(self, step): 125 | self._loudness_metrics.flush(step) 126 | self._f0_metrics.flush(step) 127 | if self._run_f0_crepe: 128 | self._f0_crepe_metrics.flush(step) 129 | 130 | 131 | @gin.register 132 | class TWMEvaluator(BaseEvaluator): 133 | """Evaluates F0s created with TWM heuristic.""" 134 | 135 | def __init__(self, 136 | sample_rate, 137 | frame_rate, 138 | processor_name='sinusoidal', 139 | noisy=False): 140 | super().__init__(sample_rate, frame_rate) 141 | self._noisy = noisy 142 | self._processor_name = processor_name 143 | self._f0_twm_metrics = metrics.F0Metrics( 144 | sample_rate=sample_rate, frame_rate=frame_rate, name='f0_twm') 145 | 146 | def _compute_twm_f0(self, outputs): 147 | """Computes F0 from sinusoids using TWM heuristic.""" 148 | processor_controls = outputs[self._processor_name]['controls'] 149 | freqs = processor_controls['frequencies'] 150 | amps = processor_controls['amplitudes'] 151 | if self._noisy: 152 | noise_ratios = processor_controls['noise_ratios'] 153 | amps = amps * (1.0 - noise_ratios) 154 | twm = ddsp.losses.TWMLoss() 155 | # Treat all freqs as candidate f0s. 156 | return twm.predict_f0(freqs, freqs, amps) 157 | 158 | def evaluate(self, batch, outputs, losses): 159 | del losses # Unused. 160 | twm_f0 = self._compute_twm_f0(outputs) 161 | self._f0_twm_metrics.update_state(batch, twm_f0) 162 | 163 | def sample(self, batch, outputs, step): 164 | twm_f0 = self._compute_twm_f0(outputs) 165 | summaries.f0_summary(batch['f0_hz'], twm_f0, step, name='f0_twm') 166 | 167 | def flush(self, step): 168 | self._f0_twm_metrics.flush(step) 169 | 170 | 171 | @gin.register 172 | class MidiAutoencoderEvaluator(BaseEvaluator): 173 | """Metrics for MIDI Autoencoder.""" 174 | 175 | def __init__(self, sample_rate, frame_rate, db_key='loudness_db', 176 | f0_key='f0_hz'): 177 | super().__init__(sample_rate, frame_rate) 178 | self._midi_metrics = metrics.MidiMetrics( 179 | frames_per_second=frame_rate, tag='learned') 180 | self._db_key = db_key 181 | self._f0_key = f0_key 182 | 183 | def evaluate(self, batch, outputs, losses): 184 | del losses # Unused. 
185 | self._midi_metrics.update_state(outputs, outputs['pianoroll']) 186 | 187 | def sample(self, batch, outputs, step): 188 | audio = batch['audio'] 189 | summaries.audio_summary( 190 | audio, step, self._sample_rate, name='audio_original') 191 | 192 | audio_keys = ['midi_audio', 'synth_audio', 'midi_audio2', 'synth_audio2'] 193 | for k in audio_keys: 194 | if k in outputs and outputs[k] is not None: 195 | summaries.audio_summary(outputs[k], step, self._sample_rate, name=k) 196 | summaries.spectrogram_summary(audio, outputs[k], step, tag=k) 197 | summaries.waveform_summary(audio, outputs[k], step, name=k) 198 | 199 | summaries.f0_summary( 200 | batch[self._f0_key], outputs[f'{self._f0_key}_pred'], 201 | step, name='f0_hz_rec') 202 | 203 | summaries.pianoroll_summary(outputs, step, 'pianoroll', 204 | self._frame_rate, 'pianoroll') 205 | summaries.midiae_f0_summary(batch[self._f0_key], outputs, step) 206 | ld_rec = f'{self._db_key}_rec' 207 | if ld_rec in outputs: 208 | summaries.midiae_ld_summary(batch[self._db_key], outputs, step, 209 | self._db_key) 210 | 211 | summaries.midiae_sp_summary(outputs, step) 212 | 213 | def flush(self, step): 214 | self._midi_metrics.flush(step) 215 | 216 | 217 | @gin.register 218 | class MidiHeuristicEvaluator(BaseEvaluator): 219 | """Metrics for MIDI heuristic.""" 220 | 221 | def __init__(self, sample_rate, frame_rate): 222 | super().__init__(sample_rate, frame_rate) 223 | self._midi_metrics = metrics.MidiMetrics( 224 | tag='heuristic', frames_per_second=frame_rate) 225 | 226 | def _compute_heuristic_notes(self, outputs): 227 | return heuristics.segment_notes_batch( 228 | binarize_f=heuristics.midi_heuristic, 229 | pick_f0_f=heuristics.mean_f0, 230 | pick_amps_f=heuristics.median_amps, 231 | controls_batch=outputs) 232 | 233 | def evaluate(self, batch, outputs, losses): 234 | del losses # Unused. 235 | notes = self._compute_heuristic_notes(outputs) 236 | self._midi_metrics.update_state(outputs, notes) 237 | 238 | def sample(self, batch, outputs, step): 239 | notes = self._compute_heuristic_notes(outputs) 240 | outputs['heuristic_notes'] = notes 241 | summaries.midi_summary(outputs, step, 'heuristic', self._frame_rate, 242 | 'heuristic_notes') 243 | summaries.pianoroll_summary(outputs, step, 'heuristic', 244 | self._frame_rate, 'heuristic_notes') 245 | 246 | def flush(self, step): 247 | self._midi_metrics.flush(step) 248 | 249 | 250 | -------------------------------------------------------------------------------- /ddsp/training/gin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Gin Configs: Datasets 2 | 3 | This directory contains gin configs for datasets. 
Each file contains datasets to use for training, evaluation, and sampling. 4 | Each training run with `ddsp_run.py` must be supplied both a model file and a dataset file, each with the `--gin_file` flag. 5 | Evaluation and sampling jobs only require a dataset file. 6 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/base.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | import ddsp.training 3 | 4 | # Evaluate 5 | evaluate.batch_size = 32 6 | evaluate.num_batches = 5  # Depends on dataset size. 7 | 8 | # Sample 9 | sample.batch_size = 16 10 | sample.num_batches = 1 11 | sample.ckpt_delay_secs = 300  # 5 minutes 12 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/nsynth.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'datasets/base.gin' 3 | 4 | # Dataset (Different scopes to pass different args to different instances) 5 | train.data_provider = @train_data/data.NSynthTfds() 6 | train_data/NSynthTfds.split = 'train' 7 | 8 | evaluate.data_provider = @eval_data/data.NSynthTfds() 9 | eval_data/data.NSynthTfds.split = 'valid' 10 | 11 | sample.data_provider = @sample_data/data.NSynthTfds() 12 | sample_data/data.NSynthTfds.split = 'test' 13 | 14 | # Evaluate 15 | evaluate.num_batches = 50  # Full test set ~17000 samples 16 | 17 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/tfrecord.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'datasets/base.gin' 3 | 4 | # Make dataset with ddsp/training/data_preparation/ddsp_prepare_tfrecord.py 5 | # --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'" 6 | 7 | # Dataset 8 | train.data_provider = @data.TFRecordProvider() 9 | evaluate.data_provider = @data.TFRecordProvider() 10 | sample.data_provider = @data.TFRecordProvider() 11 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/urmp/README.md: -------------------------------------------------------------------------------- 1 | # URMP gin configs 2 | 3 | These gin configs are for the [URMP dataset](http://www2.ece.rochester.edu/projects/air/projects/URMP.html). 4 | 5 | To use, add the base path as a param with `--gin_param=base_path='path/to/urmp_f0_interop/'`. 6 | 7 | To select a specific instrument, set `--gin_param=instrument_key='sax'`. 
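A full training command might look like the following (a sketch; the model gin file and save path are illustrative):

```
ddsp_run \
  --mode=train \
  --alsologtostderr \
  --save_dir=/tmp/$USER-ddsp-urmp \
  --gin_file=models/ae.gin \
  --gin_file=datasets/urmp/all.gin \
  --gin_param=base_path='path/to/urmp_f0_interop/'
```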
8 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/urmp/all.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'datasets/urmp/base.gin' 3 | 4 | # All URMP instruments 5 | Urmp.instrument_key = 'all' 6 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/urmp/all_midi.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'datasets/urmp/midi_base.gin' 3 | 4 | # All URMP instruments 5 | UrmpMidi.instrument_key = 'all' 6 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/urmp/base.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'datasets/base.gin' 3 | 4 | train.data_provider = @train_data/data.Urmp() 5 | train_data/data.Urmp.split = 'train' 6 | 7 | evaluate.data_provider = @test_data/data.Urmp() 8 | sample.data_provider = @test_data/data.Urmp() 9 | test_data/data.Urmp.split = 'test' 10 | -------------------------------------------------------------------------------- /ddsp/training/gin/datasets/urmp/midi_base.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'datasets/base.gin' 3 | 4 | train.data_provider = @train_data/data.UrmpMidi() 5 | train_data/data.UrmpMidi.split = 'train' 6 | 7 | evaluate.data_provider = @test_data/data.UrmpMidi() 8 | sample.data_provider = @test_data/data.UrmpMidi() 9 | eval_discrete.data_provider = @test_data/data.UrmpMidi() 10 | test_data/data.UrmpMidi.split = 'test' 11 | -------------------------------------------------------------------------------- /ddsp/training/gin/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/eval/basic.gin: -------------------------------------------------------------------------------- 1 | evaluators = [ 2 | @BasicEvaluator, 3 | ] 4 | 5 | evaluate.evaluator_classes = %evaluators 6 | sample.evaluator_classes = %evaluators 7 | -------------------------------------------------------------------------------- /ddsp/training/gin/eval/basic_f0_ld.gin: -------------------------------------------------------------------------------- 1 | evaluators = [ 2 | @BasicEvaluator, 3 | @F0LdEvaluator 4 | ] 5 | 6 | evaluate.evaluator_classes = %evaluators 7 | sample.evaluator_classes = %evaluators 8 | -------------------------------------------------------------------------------- /ddsp/training/gin/eval/basic_f0_ld_twm.gin: -------------------------------------------------------------------------------- 1 | evaluators = [ 2 | @BasicEvaluator, 3 | @F0LdEvaluator, 4 | @TWMEvaluator 5 | ] 6 | 7 | evaluate.evaluator_classes = %evaluators 8 | sample.evaluator_classes = %evaluators 9 | -------------------------------------------------------------------------------- /ddsp/training/gin/eval/heuristic.gin: -------------------------------------------------------------------------------- 1 | segment_notes_batch.binarize_f = @heuristics.midi_heuristic 2 | segment_notes_batch.pick_f0_f = @heuristics.mean_f0 3 | segment_notes_batch.pick_amps_f = @heuristics.median_amps 4 | -------------------------------------------------------------------------------- /ddsp/training/gin/eval/heuristic_power.gin: -------------------------------------------------------------------------------- 1 | segment_notes_batch.binarize_f = @heuristics.midi_heuristic_power 2 | segment_notes_batch.pick_f0_f = @heuristics.mean_f0 3 | segment_notes_batch.pick_amps_f = @heuristics.median_amps 4 | -------------------------------------------------------------------------------- /ddsp/training/gin/eval/midi_ae.gin: -------------------------------------------------------------------------------- 1 | evaluators = [ 2 | @BasicEvaluator, 3 | @MidiAutoencoderEvaluator, 4 | ] 5 | 6 | evaluate.evaluator_classes = %evaluators 7 | sample.evaluator_classes = %evaluators 8 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/README.md: -------------------------------------------------------------------------------- 1 | # Gin Configs: Models 2 | 3 | This directory contains gin configs for model architectures (including ddsp modules). 4 | Each training run with `ddsp_run.py` must be supplied both a model file and a dataset file, each with the `--gin_file` flag. 5 | Evaluation and sampling jobs only require a dataset file. 6 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/ae.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # Autoencoder that decodes from (loudness, f0, z). 3 | # z = encoder(audio) 4 | 5 | import ddsp 6 | import ddsp.training 7 | 8 | # ===== 9 | # Model 10 | # ===== 11 | get_model.model = @models.Autoencoder() 12 | 13 | # Preprocessor 14 | Autoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor() 15 | F0LoudnessPreprocessor.time_steps = 1000 16 | 17 | # Encoder 18 | Autoencoder.encoder = @encoders.MfccTimeDistributedRnnEncoder() 19 | MfccTimeDistributedRnnEncoder.rnn_channels = 512 20 | MfccTimeDistributedRnnEncoder.rnn_type = 'gru' 21 | MfccTimeDistributedRnnEncoder.z_dims = 16 22 | MfccTimeDistributedRnnEncoder.z_time_steps = 125 23 | 24 | # Decoder 25 | Autoencoder.decoder = @decoders.RnnFcDecoder() 26 | RnnFcDecoder.rnn_channels = 512 27 | RnnFcDecoder.rnn_type = 'gru' 28 | RnnFcDecoder.ch = 512 29 | RnnFcDecoder.layers_per_stack = 3 30 | RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled', 'z') 31 | RnnFcDecoder.output_splits = (('amps', 1), 32 | ('harmonic_distribution', 100), 33 | ('noise_magnitudes', 65)) 34 | 35 | # Losses 36 | Autoencoder.losses = [ 37 | @losses.SpectralLoss(), 38 | ] 39 | SpectralLoss.loss_type = 'L1' 40 | SpectralLoss.mag_weight = 1.0 41 | SpectralLoss.logmag_weight = 1.0 42 | 43 | # ============== 44 | # ProcessorGroup 45 | # ============== 46 | 47 | Autoencoder.processor_group = @processors.ProcessorGroup() 48 | 49 | ProcessorGroup.dag = [ 50 | (@synths.Harmonic(), 51 | ['amps', 'harmonic_distribution', 'f0_hz']), 52 | (@synths.FilteredNoise(), 53 | ['noise_magnitudes']), 54 | (@processors.Add(), 55 | ['filtered_noise/signal', 'harmonic/signal']), 56 | ] 57 | 58 | # Harmonic Synthesizer 59 | Harmonic.name = 'harmonic' 60 | Harmonic.n_samples = 64000 61 | Harmonic.sample_rate = 16000 62 | Harmonic.normalize_below_nyquist = True 63 | Harmonic.scale_fn = @core.exp_sigmoid 64 | 65 | # Filtered Noise Synthesizer 66 | FilteredNoise.name = 'filtered_noise' 67 | FilteredNoise.n_samples = 64000 68 | FilteredNoise.window_size = 0 69 | FilteredNoise.scale_fn = @core.exp_sigmoid 70 | 71 | # Add 72 | processors.Add.name = 'add' 73 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/README.md: -------------------------------------------------------------------------------- 1 | Configs for the MidiAutoencoder family of models (decoding from MIDI to audio). 
2 | 3 | Configs that go: 4 | 5 | * Synth Params -> MIDI -> Synth Params 6 | 7 | Instead of: 8 | 9 | * LD/DB -> MIDI -> LD/DB 10 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/midiae.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # MidiAutoencoder (base class) 3 | 4 | import ddsp 5 | import ddsp.training 6 | 7 | include 'eval/midi_ae.gin' 8 | include 'models/midiae/mixins/recon_lossgroup.gin' 9 | 10 | # Training 11 | train.num_steps = 500000 12 | 13 | 14 | # ============================================================================== 15 | # Model 16 | # ============================================================================== 17 | get_model.model = @models.MidiAutoencoder() 18 | 19 | 20 | # Preprocessor 21 | MidiAutoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor() 22 | F0LoudnessPreprocessor.time_steps = 1000 23 | 24 | 25 | # Synthcoder 26 | MidiAutoencoder.synthcoder = @decoders.DilatedConvDecoder() 27 | DilatedConvDecoder.ch = 128 28 | DilatedConvDecoder.layers_per_stack = 9 29 | DilatedConvDecoder.norm_type = 'layer' 30 | DilatedConvDecoder.input_keys = ('ld_scaled', 'f0_scaled') 31 | DilatedConvDecoder.stacks = 2 32 | DilatedConvDecoder.conditioning_keys = None 33 | DilatedConvDecoder.precondition_stack = None 34 | DilatedConvDecoder.output_splits = (('amplitudes', 1), 35 | ('harmonic_distribution', 60), 36 | ('magnitudes', 65)) 37 | 38 | 39 | # Stop Gradient 40 | MidiAutoencoder.sg_before_midiae = True 41 | 42 | 43 | # MIDI Encoder 44 | MidiAutoencoder.midi_encoder = None 45 | 46 | 47 | # MIDI Decoder 48 | MidiAutoencoder.midi_decoder = @decoders.MidiToHarmonicDecoder() 49 | MidiToHarmonicDecoder.f0_residual = True 50 | MidiToHarmonicDecoder.norm = True 51 | MidiToHarmonicDecoder.output_splits = (('f0_midi', 1), 52 | ('amplitudes', 1), 53 | ('harmonic_distribution', 60), 54 | ('magnitudes', 65)) 55 | MidiToHarmonicDecoder.net = @dec/nn.DilatedConvStack() 56 | dec/DilatedConvStack.ch = 128 57 | dec/DilatedConvStack.layers_per_stack = 5 58 | dec/DilatedConvStack.stacks = 4 59 | dec/DilatedConvStack.norm_type = 'layer' 60 | dec/DilatedConvStack.conditional = False 61 | 62 | 63 | # ============================================================================== 64 | # Losses 65 | # ============================================================================== 66 | MidiAutoencoder.reconstruction_losses = @recon_lossgroup/losses.LossGroup() 67 | 68 | 69 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/mixins/README.md: -------------------------------------------------------------------------------- 1 | # Mixins 2 | 3 | These gin files have effects independent of each other, so multiple of them 4 | can be included to mix in various configurations. 5 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/mixins/_.gin: -------------------------------------------------------------------------------- 1 | # Empty config so that we can do hyper sweeps over mixin files, with none being 2 | # an option.
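To make the sweep idea concrete, here is a hedged sketch (not repo code) of iterating over mixins with the standard gin-config API; the mixin list and the elided training step are assumptions:

```python
# Sketch: each mixin is self-contained, so a hyper sweep just varies which
# mixin file is parsed next to the base config; '_.gin' is the "no mixin" case.
import gin
import ddsp.training

for mixin in ['_.gin', 'hmm_prior.gin', 'midi_encoder.gin']:
    gin.clear_config()  # reset gin state between sweep points
    gin.parse_config_files_and_bindings(
        ['models/midiae/midiae.gin', f'models/midiae/mixins/{mixin}'],
        bindings=[])
    # ...build the model via ddsp.training.models.get_model() and train...
```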
3 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/mixins/hmm_prior.gin: -------------------------------------------------------------------------------- 1 | # HMM Prior 2 | MidiAutoencoder.hmm_prior = @losses.HmmPrior() 3 | HmmPrior.avg_length = 200 4 | HmmPrior.midi_std = 0.5 5 | HmmPrior.n_timesteps = 1000 6 | HmmPrior.n_pitches = 128 7 | HmmPrior.weight = 5.0 8 | 9 | MidiAutoencoder.hmm_quantize = False 10 | 11 | # Latent Closeness Losses 12 | MidiAutoencoder.qpitch_f0rec_loss = @qpitch/losses.MarginLoss() 13 | qpitch/MarginLoss.weight = 50.0 14 | qpitch/MarginLoss.margin = 0.5 15 | qpitch/MarginLoss.name = 'q_pitch-f0_rec' 16 | 17 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/mixins/midi_encoder.gin: -------------------------------------------------------------------------------- 1 | # Turn on the MIDI Encoder network (it's off by default) 2 | 3 | 4 | # HACK(emanilow): initialize all the model types (gin will ignore unused) 5 | MidiAutoencoder.midi_encoder = @encoders.HarmonicToMidiEncoder() 6 | ZMidiAutoencoder.midi_encoder = @encoders.HarmonicToMidiEncoder() 7 | HmmMidiAutoencoder.midi_encoder = @encoders.HarmonicToMidiEncoder() 8 | 9 | # --- MIDI Encoder details 10 | HarmonicToMidiEncoder.f0_residual = True 11 | HarmonicToMidiEncoder.net = @enc/nn.DilatedConvStack() 12 | enc/DilatedConvStack.ch = 128 13 | enc/DilatedConvStack.layers_per_stack = 5 14 | enc/DilatedConvStack.stacks = 4 15 | enc/DilatedConvStack.norm_type = 'layer' 16 | enc/DilatedConvStack.conditional = False 17 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/mixins/recon_lossgroup.gin: -------------------------------------------------------------------------------- 1 | # Standard reconstruction losses for the MidiAutoencoder encapsulated in a 2 | # LossGroup 3 | 4 | 5 | recon_lossgroup/LossGroup.dag = [ 6 | ['synth_spectral_loss', ['audio', 'synth_audio']], 7 | ['f0_loss', ['f0_midi', 'f0_midi_pred', 'f0_loss_weights']], 8 | ['amps_loss', ['amps', 'amps_pred']], 9 | ['hd_loss', ['hd', 'hd_pred']], 10 | ['noise_loss', ['noise', 'noise_pred']] 11 | ] 12 | 13 | recon_lossgroup/LossGroup.amps_loss = @amps_loss/losses.ParamLoss() 14 | amps_loss/ParamLoss.weight = 0.5 15 | amps_loss/ParamLoss.loss_type = 'L1' 16 | amps_loss/ParamLoss.name = 'amplitude_reconstruction' 17 | 18 | recon_lossgroup/LossGroup.f0_loss = @f0_loss/losses.ParamLoss() 19 | f0_loss/ParamLoss.weight = 50.0 20 | f0_loss/ParamLoss.loss_type = 'L2' 21 | f0_loss/ParamLoss.name = 'f0_reconstruction' 22 | 23 | recon_lossgroup/LossGroup.hd_loss = @hd_loss/losses.ParamLoss() 24 | hd_loss/ParamLoss.weight = 500.0 25 | hd_loss/ParamLoss.loss_type = 'L1' 26 | hd_loss/ParamLoss.name = 'harmonic_distribution_reconstruction' 27 | 28 | recon_lossgroup/LossGroup.noise_loss = @noise_loss/losses.ParamLoss() 29 | noise_loss/ParamLoss.weight = 0.5 30 | noise_loss/ParamLoss.loss_type = 'L1' 31 | noise_loss/ParamLoss.name = 'noise_reconstruction' 32 | 33 | recon_lossgroup/LossGroup.synth_spectral_loss = @losses.SpectralLoss() 34 | SpectralLoss.loss_type = 'L1' 35 | SpectralLoss.mag_weight = 1.0 36 | SpectralLoss.logmag_weight = 1.0 37 | SpectralLoss.name = 'spectral_loss_synth' 38 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/midiae/z_midiae.gin:
-------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # ZMidiAutoencoder 3 | 4 | import ddsp 5 | import ddsp.training 6 | 7 | include 'eval/midi_ae.gin' 8 | include 'models/midiae/mixins/recon_lossgroup.gin' 9 | 10 | # Training 11 | train.num_steps = 500000 12 | 13 | 14 | # ============================================================================== 15 | # Model 16 | # ============================================================================== 17 | get_model.model = @models.ZMidiAutoencoder() 18 | 19 | 20 | # Preprocessor 21 | ZMidiAutoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor() 22 | F0LoudnessPreprocessor.time_steps = 1000 23 | 24 | 25 | # Synth encoder 26 | ZMidiAutoencoder.z_synth_encoders = [ 27 | @encoders.OneHotEncoder(), 28 | @encoders.MfccTimeDistributedRnnEncoder(), 29 | ] 30 | MfccTimeDistributedRnnEncoder.rnn_channels = 512 31 | MfccTimeDistributedRnnEncoder.rnn_type = 'gru' 32 | MfccTimeDistributedRnnEncoder.z_dims = 128 33 | MfccTimeDistributedRnnEncoder.z_time_steps = 1000 34 | 35 | 36 | # OneHotEncoder 37 | OneHotEncoder.one_hot_key = 'instrument_id' 38 | OneHotEncoder.vocab_size = 13 # num instruments 39 | OneHotEncoder.n_dims = 128 40 | OneHotEncoder.skip_expand = False 41 | 42 | 43 | # Synthcoder 44 | ZMidiAutoencoder.synthcoder = @decoders.DilatedConvDecoder() 45 | DilatedConvDecoder.ch = 128 46 | DilatedConvDecoder.layers_per_stack = 9 47 | DilatedConvDecoder.norm_type = 'layer' 48 | DilatedConvDecoder.input_keys = ('ld_scaled', 'f0_scaled') 49 | DilatedConvDecoder.stacks = 2 50 | DilatedConvDecoder.conditioning_keys = ('z',) 51 | DilatedConvDecoder.precondition_stack = None 52 | DilatedConvDecoder.output_splits = (('amplitudes', 1), 53 | ('harmonic_distribution', 60), 54 | ('magnitudes', 65)) 55 | 56 | 57 | # Stop Gradient 58 | ZMidiAutoencoder.sg_before_midiae = True 59 | 60 | 61 | # MIDI Encoder 62 | ZMidiAutoencoder.midi_encoder = None 63 | 64 | 65 | # MIDI Global Latents 66 | ZMidiAutoencoder.z_global_encoders = [ 67 | @encoders.OneHotEncoder(), 68 | @ee/encoders.ExpressionEncoder(), 69 | ] 70 | 71 | 72 | # Global Expression Encoder 73 | ee/ExpressionEncoder.input_keys = ('f0_scaled', 74 | 'amps_scaled', 75 | 'hd_scaled', 76 | 'noise_scaled',) 77 | ee/ExpressionEncoder.z_dims = 128 78 | ee/ExpressionEncoder.pool_time = True 79 | ee/ExpressionEncoder.net = @z/nn.DilatedConvStack() 80 | 81 | 82 | # Z Note Encoder 83 | ZMidiAutoencoder.z_note_encoder = @zn/encoders.ExpressionEncoder() 84 | zn/ExpressionEncoder.input_keys = ('f0_scaled', 85 | 'amps_scaled', 86 | 'hd_scaled', 87 | 'noise_scaled',) 88 | zn/ExpressionEncoder.z_dims = 128 89 | zn/ExpressionEncoder.pool_time = False 90 | zn/ExpressionEncoder.net = @z/nn.DilatedConvStack() 91 | 92 | 93 | # Same DilatedConvStack for Note encoder and Global Expression Encoder. 
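# (Note: the z/ scope means both @z/nn.DilatedConvStack() references above
# share the hyperparameters below, while each reference still builds its own
# DilatedConvStack instance.)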
94 | z/DilatedConvStack.ch = 128 95 | z/DilatedConvStack.layers_per_stack = 5 96 | z/DilatedConvStack.stacks = 4 97 | z/DilatedConvStack.norm_type = 'layer' 98 | z/DilatedConvStack.conditional = False 99 | 100 | 101 | # Z Preconditioning Stack 102 | ZMidiAutoencoder.z_preconditioning_stack = @nn.FcStackOut() 103 | nn.FcStackOut.ch = 512 104 | nn.FcStackOut.n_out = 256 105 | nn.FcStackOut.layers = 5 106 | 107 | 108 | # MIDI Decoder 109 | ZMidiAutoencoder.midi_decoder = @decoders.MidiToHarmonicDecoder() 110 | MidiToHarmonicDecoder.f0_residual = True 111 | MidiToHarmonicDecoder.norm = True 112 | MidiToHarmonicDecoder.output_splits = (('f0_midi', 1), 113 | ('amplitudes', 1), 114 | ('harmonic_distribution', 60), 115 | ('magnitudes', 65)) 116 | MidiToHarmonicDecoder.net = @dec/nn.DilatedConvStack() 117 | dec/DilatedConvStack.ch = 128 118 | dec/DilatedConvStack.layers_per_stack = 5 119 | dec/DilatedConvStack.stacks = 4 120 | dec/DilatedConvStack.norm_type = 'layer' 121 | dec/DilatedConvStack.conditional = True 122 | 123 | 124 | # ============================================================================== 125 | # Losses 126 | # ============================================================================== 127 | # Reconstruction Loss 128 | ZMidiAutoencoder.reconstruction_losses = @recon_lossgroup/losses.LossGroup() 129 | 130 | 131 | # ZMidiAutoencoder.z_global_prior = @zg/losses.GaussianPrior() 132 | # zg/GaussianPrior.weight = 1.0 133 | # zg/GaussianPrior.name = 'global_kl' 134 | 135 | 136 | # ZMidiAutoencoder.z_note_prior = @zn/losses.GaussianPrior() 137 | # zn/GaussianPrior.weight = 1.0 138 | # zn/GaussianPrior.name = 'note_kl' 139 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/solo_instrument.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # Decodes from (loudness, f0). Has a trainable reverb component as well. 3 | # Since it uses a trainable reverb, training data should all be from the same 4 | # acoustic environment. 5 | 6 | include 'models/ae.gin' 7 | 8 | # Encoder 9 | Autoencoder.encoder = None 10 | 11 | # Decoder 12 | Autoencoder.decoder = @decoders.RnnFcDecoder() 13 | RnnFcDecoder.rnn_channels = 512 14 | RnnFcDecoder.rnn_type = 'gru' 15 | RnnFcDecoder.ch = 512 16 | RnnFcDecoder.layers_per_stack = 3 17 | RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled') 18 | RnnFcDecoder.output_splits = (('amps', 1), 19 | ('harmonic_distribution', 60), 20 | ('noise_magnitudes', 65)) 21 | 22 | # ============== 23 | # ProcessorGroup 24 | # ============== 25 | 26 | ProcessorGroup.dag = [ 27 | (@synths.Harmonic(), 28 | ['amps', 'harmonic_distribution', 'f0_hz']), 29 | (@synths.FilteredNoise(), 30 | ['noise_magnitudes']), 31 | (@processors.Add(), 32 | ['filtered_noise/signal', 'harmonic/signal']), 33 | (@effects.Reverb(), 34 | ['add/signal']), 35 | ] 36 | 37 | # Reverb 38 | Reverb.name = 'reverb' 39 | Reverb.reverb_length = 48000 40 | Reverb.trainable = True 41 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/vst/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/vst/vst.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # Autoencoder that decodes from (power, f0). 3 | 4 | import ddsp 5 | import ddsp.training 6 | 7 | # ===== 8 | # Model 9 | # ===== 10 | get_model.model = @models.Autoencoder() 11 | 12 | # Globals 13 | frame_rate = 50 14 | frame_size = 1024 15 | sample_rate = 16000 16 | n_samples = 64320 # Extra frame for center padding. 17 | 18 | 19 | # Preprocessor 20 | # Use same preprocessor for creating dataset and for training / inference. 21 | Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor() 22 | OnlineF0PowerPreprocessor: 23 | frame_rate = %frame_rate 24 | frame_size = %frame_size 25 | padding = 'center' 26 | compute_power = True 27 | compute_f0 = False 28 | crepe_saved_model_path = None 29 | 30 | 31 | # Encoder 32 | Autoencoder.encoder = None 33 | 34 | 35 | # Decoder 36 | Autoencoder.decoder = @decoders.RnnFcDecoder() 37 | RnnFcDecoder: 38 | rnn_channels = 512 39 | rnn_type = 'gru' 40 | ch = 256 41 | layers_per_stack = 1 42 | input_keys = ('pw_scaled', 'f0_scaled') 43 | output_splits = (('amps', 1), 44 | ('harmonic_distribution', 60), 45 | ('noise_magnitudes', 65)) 46 | 47 | # Losses 48 | Autoencoder.losses = [ 49 | @losses.SpectralLoss(), 50 | ] 51 | SpectralLoss: 52 | loss_type = 'L1' 53 | mag_weight = 1.0 54 | logmag_weight = 1.0 55 | 56 | # ============== 57 | # ProcessorGroup 58 | # ============== 59 | 60 | Autoencoder.processor_group = @processors.ProcessorGroup() 61 | 62 | # ============== 63 | # ProcessorGroup 64 | # ============== 65 | 66 | # Has a "Crop" processor to remove the padding from centered frames. 67 | 68 | ProcessorGroup.dag = [ 69 | (@synths.Harmonic(), 70 | ['amps', 'harmonic_distribution', 'f0_hz']), 71 | (@synths.FilteredNoise(), 72 | ['noise_magnitudes']), 73 | (@processors.Add(), 74 | ['filtered_noise/signal', 'harmonic/signal']), 75 | (@effects.FilteredNoiseReverb(), 76 | ['add/signal']), 77 | (@processors.Crop(), 78 | ['reverb/signal']) 79 | ] 80 | 81 | # Reverb 82 | FilteredNoiseReverb: 83 | name = 'reverb' 84 | reverb_length = 24000 85 | n_frames = 500 86 | n_filter_banks = 32 87 | trainable = True 88 | 89 | # Harmonic Synthesizer 90 | Harmonic: 91 | name = 'harmonic' 92 | n_samples = %n_samples 93 | sample_rate = %sample_rate 94 | normalize_below_nyquist = True 95 | scale_fn = @core.exp_sigmoid 96 | amp_resample_method = 'linear' 97 | 98 | # Filtered Noise Synthesizer 99 | FilteredNoise: 100 | name = 'filtered_noise' 101 | n_samples = %n_samples 102 | window_size = 0 103 | scale_fn = @core.exp_sigmoid 104 | 105 | # Add 106 | processors.Add.name = 'add' 107 | 108 | # Remove the extra frame of synthesis from centering. 109 | # Since generation is forward. 110 | # Frame size is the frame of the "synthesis" which is just the hop size. 
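# E.g. at sample_rate = 16000 with frame_rate = 50, the hop is
# 16000 / 50 = 320 samples, which is the extra frame included in
# n_samples above (64320 = 4 * 16000 + 320) and cropped off here.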
111 | Crop: 112 | frame_size = 320 113 | crop_location = 'back' 114 | 115 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/vst/vst_32k.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # Autoencoder that decodes from (power, f0). 3 | # All lengths from the 16kHz config * 2. 4 | 5 | import ddsp 6 | import ddsp.training 7 | 8 | # ===== 9 | # Model 10 | # ===== 11 | get_model.model = @models.Autoencoder() 12 | 13 | # Globals 14 | frame_rate = 50 15 | frame_size = 1024 16 | sample_rate = 32000 17 | n_samples = 128640 # Extra frame for center padding. 32000 * 4 + 640 18 | 19 | 20 | # Preprocessor 21 | # Use same preprocessor for creating dataset and for training / inference. 22 | Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor() 23 | OnlineF0PowerPreprocessor: 24 | frame_rate = %frame_rate 25 | frame_size = %frame_size 26 | padding = 'center' 27 | compute_power = True 28 | compute_f0 = False 29 | crepe_saved_model_path = None 30 | 31 | 32 | # Encoder 33 | Autoencoder.encoder = None 34 | 35 | 36 | # Decoder 37 | Autoencoder.decoder = @decoders.RnnFcDecoder() 38 | RnnFcDecoder: 39 | rnn_channels = 512 40 | rnn_type = 'gru' 41 | ch = 256 42 | layers_per_stack = 1 43 | input_keys = ('pw_scaled', 'f0_scaled') 44 | output_splits = (('amps', 1), 45 | ('harmonic_distribution', 100), 46 | ('noise_magnitudes', 98)) 47 | 48 | # Losses 49 | Autoencoder.losses = [ 50 | @losses.SpectralLoss(), 51 | ] 52 | SpectralLoss: 53 | loss_type = 'L1' 54 | mag_weight = 1.0 55 | logmag_weight = 1.0 56 | fft_sizes = [4096, 2048, 1024, 512, 256, 128] 57 | 58 | # 16kHz fft sizes: (2048, 1024, 512, 256, 128, 64) 59 | 60 | # ============== 61 | # ProcessorGroup 62 | # ============== 63 | 64 | Autoencoder.processor_group = @processors.ProcessorGroup() 65 | 66 | # ============== 67 | # ProcessorGroup 68 | # ============== 69 | 70 | # Has a "Crop" processor to remove the padding from centered frames. 71 | 72 | ProcessorGroup.dag = [ 73 | (@synths.Harmonic(), 74 | ['amps', 'harmonic_distribution', 'f0_hz']), 75 | (@synths.FilteredNoise(), 76 | ['noise_magnitudes']), 77 | (@processors.Add(), 78 | ['filtered_noise/signal', 'harmonic/signal']), 79 | (@effects.FilteredNoiseReverb(), 80 | ['add/signal']), 81 | (@processors.Crop(), 82 | ['reverb/signal']) 83 | ] 84 | 85 | # Reverb 86 | FilteredNoiseReverb: 87 | name = 'reverb' 88 | reverb_length = 48000 89 | n_frames = 500 90 | n_filter_banks = 32 91 | initial_bias = -4.0 92 | trainable = True 93 | 94 | # Harmonic Synthesizer 95 | Harmonic: 96 | name = 'harmonic' 97 | n_samples = %n_samples 98 | sample_rate = %sample_rate 99 | normalize_below_nyquist = True 100 | scale_fn = @core.exp_sigmoid 101 | amp_resample_method = 'linear' 102 | use_angular_cumsum = True # Necessary at 48k as oscillators precess more. 103 | 104 | # Filtered Noise Synthesizer 105 | FilteredNoise: 106 | name = 'filtered_noise' 107 | n_samples = %n_samples 108 | window_size = 0 109 | scale_fn = @core.exp_sigmoid 110 | 111 | # Add 112 | processors.Add.name = 'add' 113 | 114 | # Remove the extra frame of synthesis from centering. 115 | # Since generation is forward. 116 | # Frame size is the frame of the "synthesis" which is just the hop size. 
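# Hop at 32 kHz: 32000 / 50 = 640 samples, the extra frame in n_samples above.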
117 | Crop: 118 | frame_size = 640 119 | crop_location = 'back' 120 | 121 | -------------------------------------------------------------------------------- /ddsp/training/gin/models/vst/vst_48k.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # Autoencoder that decodes from (power, f0). 3 | # All lengths from the 16kHz config * 3. 4 | 5 | import ddsp 6 | import ddsp.training 7 | 8 | # ===== 9 | # Model 10 | # ===== 11 | get_model.model = @models.Autoencoder() 12 | 13 | # Globals 14 | frame_rate = 50 15 | frame_size = 1024 16 | sample_rate = 48000 17 | n_samples = 192960 # Extra frame for center padding. 48000 * 4 + 960 18 | 19 | 20 | # Preprocessor 21 | # Use same preprocessor for creating dataset and for training / inference. 22 | Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor() 23 | OnlineF0PowerPreprocessor: 24 | frame_rate = %frame_rate 25 | frame_size = %frame_size 26 | padding = 'center' 27 | compute_power = True 28 | compute_f0 = False 29 | crepe_saved_model_path = None 30 | 31 | 32 | # Encoder 33 | Autoencoder.encoder = None 34 | 35 | 36 | # Decoder 37 | Autoencoder.decoder = @decoders.RnnFcDecoder() 38 | RnnFcDecoder: 39 | rnn_channels = 512 40 | rnn_type = 'gru' 41 | ch = 256 42 | layers_per_stack = 1 43 | input_keys = ('pw_scaled', 'f0_scaled') 44 | output_splits = (('amps', 1), 45 | ('harmonic_distribution', 100), 46 | ('noise_magnitudes', 98)) 47 | 48 | # Losses 49 | Autoencoder.losses = [ 50 | @losses.SpectralLoss(), 51 | ] 52 | SpectralLoss: 53 | loss_type = 'L1' 54 | mag_weight = 1.0 55 | logmag_weight = 1.0 56 | fft_sizes = [6144, 3072, 1536, 768, 384, 192] 57 | 58 | # 16kHz fft sizes: (2048, 1024, 512, 256, 128, 64) 59 | 60 | # ============== 61 | # ProcessorGroup 62 | # ============== 63 | 64 | Autoencoder.processor_group = @processors.ProcessorGroup() 65 | 66 | # ============== 67 | # ProcessorGroup 68 | # ============== 69 | 70 | # Has a "Crop" processor to remove the padding from centered frames. 71 | 72 | ProcessorGroup.dag = [ 73 | (@synths.Harmonic(), 74 | ['amps', 'harmonic_distribution', 'f0_hz']), 75 | (@synths.FilteredNoise(), 76 | ['noise_magnitudes']), 77 | (@processors.Add(), 78 | ['filtered_noise/signal', 'harmonic/signal']), 79 | (@effects.FilteredNoiseReverb(), 80 | ['add/signal']), 81 | (@processors.Crop(), 82 | ['reverb/signal']) 83 | ] 84 | 85 | # Reverb 86 | FilteredNoiseReverb: 87 | name = 'reverb' 88 | reverb_length = 72000 89 | n_frames = 500 90 | n_filter_banks = 32 91 | initial_bias = -4.0 92 | trainable = True 93 | 94 | # Harmonic Synthesizer 95 | Harmonic: 96 | name = 'harmonic' 97 | n_samples = %n_samples 98 | sample_rate = %sample_rate 99 | normalize_below_nyquist = True 100 | scale_fn = @core.exp_sigmoid 101 | amp_resample_method = 'linear' 102 | use_angular_cumsum = True # Necessary at 48k as oscillators precess more. 103 | 104 | # Filtered Noise Synthesizer 105 | FilteredNoise: 106 | name = 'filtered_noise' 107 | n_samples = %n_samples 108 | window_size = 0 109 | scale_fn = @core.exp_sigmoid 110 | 111 | # Add 112 | processors.Add.name = 'add' 113 | 114 | # Remove the extra frame of synthesis from centering. 115 | # Since generation is forward. 116 | # Frame size is the frame of the "synthesis" which is just the hop size. 
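# Hop at 48 kHz: 48000 / 50 = 960 samples, the extra frame in n_samples above.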
117 | Crop: 118 | frame_size = 960 119 | crop_location = 'back' 120 | 121 | -------------------------------------------------------------------------------- /ddsp/training/gin/optimization/README.md: -------------------------------------------------------------------------------- 1 | # Gin Configs: Optimization 2 | 3 | This directory contains gin configs for training hyperparameters such as `batch_size` 4 | and `learning_rate`. Default values are loaded via `ddsp_run.py` and can be replaced with the `--gin_param` flag. 5 | -------------------------------------------------------------------------------- /ddsp/training/gin/optimization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/optimization/base.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | import ddsp 3 | import ddsp.training 4 | 5 | # Globals for easier configuration with --gin_param 6 | learning_rate = 3e-4 7 | batch_size = 32 8 | 9 | train.batch_size = %batch_size 10 | train.num_steps = 1000000 11 | train.steps_per_summary = 300 12 | train.steps_per_save = 300 13 | 14 | Trainer.learning_rate = %learning_rate 15 | Trainer.lr_decay_steps = 10000 16 | Trainer.lr_decay_rate = 0.98 17 | Trainer.grad_clip_norm = 3.0 18 | Trainer.checkpoints_to_keep = 100 19 | -------------------------------------------------------------------------------- /ddsp/training/gin/optimization/base_tpu.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'optimization/base.gin' 3 | 4 | # Larger batch size for TPU. 5 | batch_size = 64 # (4x2, 4 per core) 6 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/README.md: -------------------------------------------------------------------------------- 1 | # Gin Configs: Papers 2 | 3 | This directory contains gin configs for replicating the experiments in published 4 | papers using ddsp. 5 | 6 | ## List of papers 7 | 8 | * [iclr2020](./iclr2020/): Experiments from the original DDSP ICLR 2020 paper ([paper](https://openreview.net/forum?id=B1x1ma4tDr), [blog](https://magenta.tensorflow.org/ddsp)). 9 | 10 | ```latex 11 | @inproceedings{ 12 | engel2020ddsp, 13 | title={DDSP: Differentiable Digital Signal Processing}, 14 | author={Jesse Engel and Lamtharn (Hanoi) Hantrakul and Chenjie Gu and Adam Roberts}, 15 | booktitle={International Conference on Learning Representations}, 16 | year={2020}, 17 | url={https://openreview.net/forum?id=B1x1ma4tDr} 18 | } 19 | ``` 20 | 21 | * [icml2020](./icml2020/): Experiments from the ICML 2020 SAS Workshop paper on Inverse Audio Synthesis ([paper](https://goo.gl/magenta/ddsp-inv)).
22 | 23 | ```latex 24 | @inproceedings{ 25 | engel2020selfsupervised, 26 | title={Self-supervised Pitch Detection with Inverse Audio Synthesis}, 27 | author={Jesse Engel and Rigel Swavely and Lamtharn (Hanoi) Hantrakul and Adam Roberts and Curtis Hawthorne}, 28 | booktitle={International Conference on Machine Learning, Self-supervised Audio and Speech Workshop}, 29 | year={2020}, 30 | url={https://openreview.net/forum?id=RlVTYWhsky7} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/iclr2020/README.md: -------------------------------------------------------------------------------- 1 | # DDSP: Differentiable Digital Signal Processing 2 | _ICLR 2020 ([paper](g.co/magenta/ddsp))_ 3 | 4 | Gin configs for reproducing the results of the paper. 5 | 6 | Each config file specifies both the dataset and the model, so only a single `--gin_file` flag is required. 7 | 8 | ## Configs 9 | 10 | * `nsynth_ae.gin`: Autoencoder (with latent z, and f0 given by CREPE) on the NSynth dataset. 11 | * `nsynth_ae_abs.gin`: Deprecated. Please install v0.0.6 to use this config. Autoencoder (with latent z, and f0 inferred by model) on the NSynth dataset. For an improved version, see the [icml2020](./../icml2020/) paper. 12 | * `solo_instrument.gin`: Decoder (with no z, and f0 given by CREPE) on your own dataset of a monophonic instrument. Make dataset with `ddsp/training/data_preparation/ddsp_prepare_tfrecord.py`. 13 | * `tiny_instrument.gin`: Same as `solo_instrument.gin`, but with a much smaller model. 14 | 15 | 16 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/iclr2020/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/iclr2020/nsynth_ae.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'models/ae.gin' 3 | include 'datasets/nsynth.gin' 4 | include 'eval/basic_f0_ld.gin' 5 | 6 | # To recreate original experiment optimization params, uncomment lines below. 7 | # learning_rate = 1e-5 8 | # batch_size = 128 9 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/iclr2020/solo_instrument.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # Make dataset with ddsp/training/data_preparation/ddsp_prepare_tfrecord.py 3 | # --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'" 4 | 5 | include 'models/solo_instrument.gin' 6 | include 'datasets/tfrecord.gin' 7 | include 'eval/basic_f0_ld.gin' 8 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/iclr2020/tiny_instrument.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | # The 'Tiny Model' config shown in 3 | # https://storage.googleapis.com/ddsp/index.html#tiny 4 | include 'papers/iclr2020/solo_instrument.gin' 5 | 6 | RnnFcDecoder.layers_per_stack = 0 7 | RnnFcDecoder.rnn_channels = 256 8 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/icml2020/README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised Pitch Detection by Inverse Audio Synthesis 2 | _ICML SAS Workshop 2020_ ([paper](https://openreview.net/forum?id=RlVTYWhsky7)) 3 | 4 | Instructions for reproducing the results of the paper. 5 | 6 | ## Generate synthetic dataset 7 | 8 | After pip installing ddsp (`pip install -U ddsp[data_preparation]`), you can create a synthetic dataset for the transcribing autoencoder using the installed script. 9 | The command below creates a very small dataset as a test. 10 | 11 | ```bash 12 | ddsp_generate_synthetic_dataset \ 13 | --output_tfrecord_path=/tmp/synthetic_data.tfrecord \ 14 | --gin_param="generate_examples.generate_fn=@data_preparation.synthetic_data.generate_notes_v2" \ 15 | --gin_param="generate_notes_v2.n_timesteps=125" \ 16 | --gin_param="generate_notes_v2.n_harmonics=100" \ 17 | --gin_param="generate_notes_v2.n_mags=65" \ 18 | --num_examples=1000 \ 19 | --num_shards=2 \ 20 | --alsologtostderr 21 | ``` 22 | 23 | The synthetic dataset in the paper has 10,000,000 examples split into 1000 shards. 24 | Because that's a large dataset, we have provided sharded TFRecord files for the synthetic dataset on GCP at `gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*`. 25 | As the name indicates, the dataset was made with 125 timesteps, 100 harmonics, and 65 noise bins. 26 | If training on GCP, it is fast to directly read from these buckets, but if training locally you will probably want to download the files locally (~1.7 TB) using the `gsutil` command line utility from the [gcloud sdk](https://cloud.google.com/sdk/docs/downloads-interactive). 27 | 28 | 29 | If you wish to make your own large dataset, it may require running in a distributed setup such as [Dataflow](https://cloud.google.com/dataflow) on GCP.
30 | If running on Dataflow, you'll need to set the `--pipeline_options` flag using the execution parameters described at https://cloud.google.com/dataflow/docs/guides/specifying-exec-params 31 | E.g., `--pipeline_options=--runner=DataFlowRunner,--project=,--temp_location=gs://,--region=us-central1,--job_name=`... 32 | 33 | ## Pretrain transcribing autoencoder 34 | 35 | ### Train 36 | Point the dataset file_pattern to the synthetic dataset created above. 37 | 38 | ```bash 39 | ddsp_run \ 40 | --mode=train \ 41 | --save_dir=/tmp/$USER-tae-pretrain-0 \ 42 | --gin_file=papers/icml2020/pretrain_model.gin \ 43 | --gin_file=papers/icml2020/pretrain_dataset.gin \ 44 | --gin_param="SyntheticNotes.file_pattern='gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*'" \ 45 | --gin_param="batch_size=64" \ 46 | --alsologtostderr 47 | ``` 48 | 49 | This command points to datasets on GCP, and the gin_params for file_pattern are redundant with the default values in the gin files, but are provided here to show how you would modify them for local dataset paths. 50 | 51 | In the paper we train for ~1.2M steps with a batch size of 64. A single V100 can fit a max batch size of 32, so you will need to use multiple GPUs to exactly reproduce the experiment. Given the large amount of pretraining, a pretrained checkpoint [is available here](https://storage.googleapis.com/ddsp-inv/ckpts/synthetic_pretrained_ckpt.zip) 52 | or on GCP at `gs://ddsp-inv/ckpts/synthetic_pretrained_ckpt`. 53 | 54 | ### Eval and Sample 55 | 56 | This will add evaluation metrics as scalar summaries for TensorBoard. 57 | 58 | ```bash 59 | ddsp_run \ 60 | --mode=eval \ 61 | --run_once \ 62 | --restore_dir=/path/to/trained_ckpt \ 63 | --save_dir=/path/to/trained_ckpt \ 64 | --gin_file=papers/icml2020/pretrain_dataset.gin \ 65 | --alsologtostderr 66 | ``` 67 | 68 | This will add image and audio summaries for TensorBoard. 69 | 70 | ```bash 71 | ddsp_run \ 72 | --mode=sample \ 73 | --run_once \ 74 | --restore_dir=/path/to/trained_ckpt \ 75 | --save_dir=/path/to/trained_ckpt \ 76 | --gin_file=papers/icml2020/pretrain_dataset.gin \ 77 | --alsologtostderr 78 | ``` 79 | 80 | 81 | ## Finetuning transcribing autoencoder 82 | 83 | ### Train 84 | Now we finetune the model from above on a specific dataset. Use the `--restore_dir` flag to point to your pretrained checkpoint. 85 | 86 | A model pretrained for 1.2M steps (batch size=64) on synthetic data [is available here](https://storage.googleapis.com/ddsp-inv/ckpts/synthetic_pretrained_ckpt.zip) 87 | or on GCP.
88 | 89 | ```bash 90 | gsutil cp -r gs://ddsp-inv/ckpts/synthetic_pretrained_ckpt /path/to/synthetic_pretrained_ckpt 91 | ``` 92 | 93 | ```bash 94 | ddsp_run \ 95 | --mode=train \ 96 | --restore_dir=/path/to/synthetic_pretrained_ckpt \ 97 | --save_dir=/tmp/$USER-tae-finetune-0 \ 98 | --gin_file=papers/icml2020/finetune_model.gin \ 99 | --gin_file=papers/icml2020/finetune_dataset.gin \ 100 | --gin_param="SyntheticNotes.file_pattern='gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*'" \ 101 | --gin_param="train_data/TFRecordProvider.file_pattern='gs://ddsp-inv/datasets/all_instruments_train.tfrecord*'" \ 102 | --gin_param="test_data/TFRecordProvider.file_pattern='gs://ddsp-inv/datasets/all_instruments_test.tfrecord*'" \ 103 | --gin_param="batch_size=64" \ 104 | --alsologtostderr 105 | ``` 106 | 107 | This command points to datasets on GCP, and the gin_params for file_pattern are redundant with the default values in the gin files, but are provided here to show how you would modify them for local dataset paths. 108 | We have provided sharded TFRecord files for the [URMP dataset](http://www2.ece.rochester.edu/projects/air/projects/URMP/annotations_5P.html) on GCP at `gs://ddsp-inv/datasets/all_instruments_train.tfrecord*` and `gs://ddsp-inv/datasets/all_instruments_test.tfrecord*`. 109 | If training on GCP, it is fast to directly read from these buckets, but if training locally you will probably want to download the files locally (~16 GB) using the `gsutil` command line utility from the [gcloud sdk](https://cloud.google.com/sdk/docs/downloads-interactive). 110 | 111 | 112 | In the paper, this model was trained with a batch size of 64 on 8 accelerators (8 per accelerator), and typically converges after 200-400k iterations. A single V100 can fit a max batch size of 12, so you will need to use multiple GPUs or TPUs to exactly reproduce the experiment. To use a TPU, start up an instance from the web interface and pass the internal IP address to the TPU flag `--tpu=grpc://`. 113 | 114 | 115 | Finetuned models for +400k steps (batch size=64) are available on GCP for the 116 | [URMP](http://www2.ece.rochester.edu/projects/air/projects/URMP/annotations_5P.html) ([ckpt](https://storage.googleapis.com/ddsp-inv/ckpts/urmp_ckpt.zip)), 117 | [MDB-stem-synth](https://zenodo.org/record/1481172#.Xzouy5NKhTY) ([ckpt](https://storage.googleapis.com/ddsp-inv/ckpts/mdb_stem_synth_ckpt.zip)), 118 | and [MIR1k](https://sites.google.com/site/unvoicedsoundseparation/mir-1k) ([ckpt](https://storage.googleapis.com/ddsp-inv/ckpts/mir1k_ckpt.zip)) datasets. 119 | 120 | ### Eval and Sample 121 | 122 | This will add evaluation metrics as scalar summaries for TensorBoard. 123 | 124 | ```bash 125 | ddsp_run \ 126 | --mode=eval \ 127 | --run_once \ 128 | --restore_dir=/path/to/trained_ckpt \ 129 | --save_dir=/path/to/trained_ckpt \ 130 | --gin_file=papers/icml2020/finetune_dataset.gin \ 131 | --alsologtostderr 132 | ``` 133 | 134 | This will add image and audio summaries for TensorBoard. 135 | 136 | ```bash 137 | ddsp_run \ 138 | --mode=sample \ 139 | --run_once \ 140 | --restore_dir=/path/to/trained_ckpt \ 141 | --save_dir=/path/to/trained_ckpt \ 142 | --gin_file=papers/icml2020/finetune_dataset.gin \ 143 | --alsologtostderr 144 | ``` 145 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/icml2020/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/icml2020/finetune_dataset.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'papers/icml2020/pretrain_dataset.gin' 3 | 4 | # Datasets 5 | train.data_provider = @train_data/data.ZippedProvider() 6 | evaluate.data_provider = @test_data/data.TFRecordProvider() 7 | sample.data_provider = @test_data/data.TFRecordProvider() 8 | 9 | 10 | 11 | # Training 12 | train_data/ZippedProvider.data_providers = [ 13 | @data.SyntheticNotes(), 14 | @train_data/data.TFRecordProvider(), 15 | ] 16 | 17 | 18 | # Make dataset with ddsp/training/data_preparation/ddsp_prepare_tfrecord.py 19 | # --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'" 20 | 21 | train_data/TFRecordProvider.file_pattern = 'gs://ddsp-inv/datasets/all_instruments_train.tfrecord*' 22 | test_data/TFRecordProvider.file_pattern = 'gs://ddsp-inv/datasets/all_instruments_test.tfrecord*' 23 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/icml2020/finetune_model.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | include 'papers/icml2020/pretrain_model.gin' 3 | 4 | SpectralLoss.mag_weight = 1.0 5 | SpectralLoss.logmag_weight = 1.0 6 | 7 | KDEConsistencyLoss.weight_mean_amp = 0.1 8 | KDEConsistencyLoss.weight_a = 0.1 9 | KDEConsistencyLoss.weight_b = 0.1 10 | KDEConsistencyLoss.scale_a = 0.1 11 | KDEConsistencyLoss.scale_b = 0.1 12 | 13 | FilteredNoiseConsistencyLoss.weight = 100.0 14 | 15 | HarmonicConsistencyLoss.amp_weight = 10.0 16 | HarmonicConsistencyLoss.dist_weight = 100.0 17 | HarmonicConsistencyLoss.f0_weight = 1.0 18 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/icml2020/pretrain_dataset.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | import ddsp.training 3 | 4 | 5 | # Evaluate 6 | evaluate.batch_size = 8 7 | evaluate.num_batches = 256 8 | 9 | 10 | # Sample 11 | sample.batch_size = 16 12 | sample.num_batches = 1 13 | sample.ckpt_delay_secs = 300 # 5 minutes 14 | 15 | 16 | # Dataset 17 | train.data_provider = @data.SyntheticNotes() 18 | evaluate.data_provider = @data.SyntheticNotes() 19 | sample.data_provider = @data.SyntheticNotes() 20 | 21 | 22 | ## Create a synthetic dataset with ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py 23 | # Synthetic data generator.
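# (The t125_h100_m65 in the file name below matches these shapes:
# 125 timesteps, 100 harmonics, 65 noise magnitudes.)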
24 | data.SyntheticNotes.file_pattern = 'gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*' 25 | data.SyntheticNotes.n_timesteps = 125 26 | data.SyntheticNotes.n_harmonics = 100 27 | data.SyntheticNotes.n_mags = 65 28 | -------------------------------------------------------------------------------- /ddsp/training/gin/papers/icml2020/pretrain_model.gin: -------------------------------------------------------------------------------- 1 | # -*-Python-*- 2 | 3 | import ddsp 4 | import ddsp.training 5 | 6 | include 'eval/basic_f0_ld_twm.gin' 7 | 8 | # ===== 9 | # Model 10 | # ===== 11 | get_model.model = @models.InverseSynthesis() 12 | 13 | InverseSynthesis.reverb = False 14 | 15 | # Sinusoidal Encoder 16 | InverseSynthesis.sinusoidal_encoder = @encoders.ResnetSinusoidalEncoder() 17 | ResnetSinusoidalEncoder.size = 'small' 18 | ResnetSinusoidalEncoder.output_splits = (('frequencies', 6400), 19 | ('amplitudes', 100), 20 | ('noise_magnitudes', 65)) 21 | 22 | 23 | # FFT input parameters from onsets and frames transcription experiments. 24 | ResnetSinusoidalEncoder.spectral_fn = @f0_spectral/spectral_ops.compute_logmel 25 | f0_spectral/compute_logmel.lo_hz = 0.0 26 | f0_spectral/compute_logmel.hi_hz = 8000.0 27 | f0_spectral/compute_logmel.bins = 229 28 | f0_spectral/compute_logmel.fft_size = 2048 29 | f0_spectral/compute_logmel.overlap = 0.75 30 | f0_spectral/compute_logmel.pad_end = True 31 | 32 | # Harmonic Encoder 33 | InverseSynthesis.harmonic_encoder = @encoders.SinusoidalToHarmonicEncoder() 34 | SinusoidalToHarmonicEncoder.net = @nn.RnnSandwich() 35 | 36 | 37 | # Audio Losses 38 | InverseSynthesis.losses = [ 39 | @losses.SpectralLoss(), 40 | ] 41 | SpectralLoss.loss_type = 'L1' 42 | SpectralLoss.mag_weight = 0.0 43 | SpectralLoss.logmag_weight = 0.0 44 | 45 | 46 | # Sinusoidal Consistency loss: 47 | InverseSynthesis.sinusoidal_consistency_losses = @losses.KDEConsistencyLoss() 48 | KDEConsistencyLoss.weight_a = 1.0 49 | KDEConsistencyLoss.weight_b = 1.0 50 | KDEConsistencyLoss.scale_a = 0.1 51 | KDEConsistencyLoss.scale_b = 0.1 52 | 53 | 54 | # Harmonic Consistency Loss. 55 | InverseSynthesis.harmonic_consistency_losses = @losses.HarmonicConsistencyLoss() 56 | HarmonicConsistencyLoss.amp_weight = 1.0 57 | HarmonicConsistencyLoss.dist_weight = 1.0 58 | HarmonicConsistencyLoss.f0_weight = 1.0 59 | 60 | 61 | # Filtered Noise Consistency loss: 62 | InverseSynthesis.filtered_noise_consistency_loss = @losses.FilteredNoiseConsistencyLoss() 63 | FilteredNoiseConsistencyLoss.weight = 1.0 64 | -------------------------------------------------------------------------------- /ddsp/training/heuristics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ddsp.training.heuristics.""" 16 | 17 | from absl.testing import parameterized 18 | from ddsp.training import heuristics 19 | import numpy as np 20 | import tensorflow.compat.v2 as tf 21 | 22 | 23 | class HeuristicsTest(parameterized.TestCase): 24 | 25 | def test_pad_for_frame(self): 26 | signal = np.zeros((10,)) 27 | signal[0] = -1 28 | signal[-1] = 1 29 | frame_width = 4 30 | front_padded_signal = heuristics.pad_for_frame( 31 | signal, mode='front', frame_width=frame_width) 32 | framed = tf.signal.frame(front_padded_signal, frame_width, 1) 33 | self.assertEqual(framed.shape[0], signal.shape[0]) 34 | np.testing.assert_array_equal(framed[0], [-1, -1, -1, -1]) 35 | np.testing.assert_array_equal(framed[-1], [0, 0, 0, 1]) 36 | 37 | back_padded_signal = heuristics.pad_for_frame( 38 | signal, mode='end', frame_width=frame_width) 39 | framed = tf.signal.frame(back_padded_signal, frame_width, 1) 40 | self.assertEqual(framed.shape[0], signal.shape[0]) 41 | np.testing.assert_array_equal(framed[0], [-1, 0, 0, 0]) 42 | np.testing.assert_array_equal(framed[-1], [1, 1, 1, 1]) 43 | 44 | center_padded_signal = heuristics.pad_for_frame( 45 | signal, mode='center', frame_width=frame_width) 46 | framed = tf.signal.frame(center_padded_signal, frame_width, 1) 47 | self.assertEqual(framed.shape[0], signal.shape[0]) 48 | np.testing.assert_array_equal(framed[0], [-1, -1, -1, 0]) 49 | np.testing.assert_array_equal(framed[-1], [0, 0, 1, 1]) 50 | 51 | def test_pad_for_frame_odd(self): 52 | signal = np.zeros((10,)) 53 | signal[0] = -1 54 | signal[-1] = 1 55 | frame_width = 3 56 | 57 | center_padded_signal = heuristics.pad_for_frame( 58 | signal, mode='center', frame_width=frame_width) 59 | framed = tf.signal.frame(center_padded_signal, frame_width, 1) 60 | self.assertEqual(framed.shape[0], signal.shape[0]) 61 | np.testing.assert_array_equal(framed[0], [-1, -1, 0]) 62 | np.testing.assert_array_equal(framed[-1], [0, 1, 1]) 63 | 64 | @parameterized.parameters( 65 | (10.0, 2.0, 0.5, (11, 32000)), 66 | (10.0, 10.0, 0.5, (3, 160000)), 67 | (10.0, 1.0, 0.5, (21, 16000)), 68 | (10.0, 1.0, 0.75, (14, 16000)), 69 | ) 70 | def test_window_array(self, dur, win_len, frame_step_ratio, expected_shape): 71 | sr = 16000 72 | sw = np.sin(np.linspace(start=0.0, stop=dur, num=int(sr*dur))) 73 | 74 | wa = heuristics.window_array(sw, sr, win_len, frame_step_ratio) 75 | self.assertSequenceEqual(wa.shape, expected_shape) 76 | self.assertEqual(wa.shape[-1], int(sr*win_len)) 77 | 78 | 79 | if __name__ == '__main__': 80 | tf.test.main() 81 | -------------------------------------------------------------------------------- /ddsp/training/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Module with all the global configurable models for training.""" 16 | 17 | from ddsp.training.models.autoencoder import Autoencoder 18 | from ddsp.training.models.inverse_synthesis import InverseSynthesis 19 | from ddsp.training.models.midi_autoencoder import MidiAutoencoder 20 | from ddsp.training.models.midi_autoencoder import ZMidiAutoencoder 21 | from ddsp.training.models.model import Model 22 | import gin 23 | 24 | _configurable = lambda cls: gin.configurable(cls, module=__name__) 25 | 26 | Autoencoder = _configurable(Autoencoder) 27 | InverseSynthesis = _configurable(InverseSynthesis) 28 | MidiAutoencoder = _configurable(MidiAutoencoder) 29 | ZMidiAutoencoder = _configurable(ZMidiAutoencoder) 30 | 31 | 32 | 33 | @gin.configurable 34 | def get_model(model=gin.REQUIRED): 35 | """Gin configurable function get a 'global' model for use in ddsp_run.py. 36 | 37 | Convenience for using the same model in train(), evaluate(), and sample(). 38 | Args: 39 | model: An instantiated model, such as 'models.Autoencoder()'. 40 | 41 | Returns: 42 | The 'global' model specified in the gin config. 43 | """ 44 | return model 45 | -------------------------------------------------------------------------------- /ddsp/training/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Model that encodes audio features and decodes with a ddsp processor group.""" 16 | 17 | import ddsp 18 | from ddsp.training.models.model import Model 19 | 20 | 21 | class Autoencoder(Model): 22 | """Wrap the model function for dependency injection with gin.""" 23 | 24 | def __init__(self, 25 | preprocessor=None, 26 | encoder=None, 27 | decoder=None, 28 | processor_group=None, 29 | losses=None, 30 | **kwargs): 31 | super().__init__(**kwargs) 32 | self.preprocessor = preprocessor 33 | self.encoder = encoder 34 | self.decoder = decoder 35 | self.processor_group = processor_group 36 | self.loss_objs = ddsp.core.make_iterable(losses) 37 | 38 | def encode(self, features, training=True): 39 | """Get conditioning by preprocessing then encoding.""" 40 | if self.preprocessor is not None: 41 | features.update(self.preprocessor(features, training=training)) 42 | if self.encoder is not None: 43 | features.update(self.encoder(features)) 44 | return features 45 | 46 | def decode(self, features, training=True): 47 | """Get generated audio by decoding than processing.""" 48 | features.update(self.decoder(features, training=training)) 49 | return self.processor_group(features) 50 | 51 | def get_audio_from_outputs(self, outputs): 52 | """Extract audio output tensor from outputs dict of call().""" 53 | return outputs['audio_synth'] 54 | 55 | def call(self, features, training=True): 56 | """Run the core of the network, get predictions and loss.""" 57 | features = self.encode(features, training=training) 58 | features.update(self.decoder(features, training=training)) 59 | 60 | # Run through processor group. 61 | pg_out = self.processor_group(features, return_outputs_dict=True) 62 | 63 | # Parse outputs 64 | outputs = pg_out['controls'] 65 | outputs['audio_synth'] = pg_out['signal'] 66 | 67 | if training: 68 | self._update_losses_dict( 69 | self.loss_objs, features['audio'], outputs['audio_synth']) 70 | 71 | return outputs 72 | 73 | -------------------------------------------------------------------------------- /ddsp/training/models/autoencoder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for ddsp.training.models.autoencoder.""" 16 | 17 | from absl.testing import parameterized 18 | from ddsp.core import tf_float32 19 | from ddsp.training import models 20 | import gin 21 | import numpy as np 22 | import pkg_resources 23 | import tensorflow as tf 24 | 25 | GIN_PATH = pkg_resources.resource_filename(__name__, '../gin') 26 | gin.add_config_file_search_path(GIN_PATH) 27 | 28 | 29 | class AutoencoderTest(parameterized.TestCase, tf.test.TestCase): 30 | 31 | @parameterized.named_parameters( 32 | ('nsynth_ae', 16000, 250, False, 'papers/iclr2020/nsynth_ae.gin'), 33 | ('solo_instrument', 16000, 250, False, 34 | 'papers/iclr2020/solo_instrument.gin'), 35 | ('vst_16kHz', 16000, 50, True, 'models/vst/vst.gin'), 36 | ('vst_32kHz', 32000, 50, True, 'models/vst/vst_32k.gin'), 37 | ('vst_48kHz', 48000, 50, True, 'models/vst/vst_48k.gin'), 38 | ) 39 | def test_build_model(self, sample_rate, frame_rate, centered, gin_file): 40 | """Tests if Model builds properly and produces audio of correct shape. 41 | 42 | Args: 43 | sample_rate: Sample rate of audio. 44 | frame_rate: Frame rate of features. 45 | centered: Add an additional frame. 46 | gin_file: Name of gin_file to use. 47 | """ 48 | n_batch = 1 49 | n_secs = 4 50 | # frame_size = sample_rate // frame_rate 51 | n_frames = int(frame_rate * n_secs) 52 | n_samples = int(sample_rate * n_secs) 53 | n_samples_16k = int(16000 * n_secs) 54 | if centered: 55 | n_frames += 1 56 | # n_samples += frame_size 57 | # n_samples_16k += 320 58 | 59 | inputs = { 60 | 'loudness_db': np.zeros([n_batch, n_frames]), 61 | 'f0_hz': np.zeros([n_batch, n_frames]), 62 | 'f0_confidence': np.zeros([n_batch, n_frames]), 63 | 'audio': np.random.randn(n_batch, n_samples), 64 | 'audio_16k': np.random.randn(n_batch, n_samples_16k), 65 | } 66 | inputs = {k: tf_float32(v) for k, v in inputs.items()} 67 | 68 | with gin.unlock_config(): 69 | gin.clear_config() 70 | gin.parse_config_file(gin_file) 71 | 72 | model = models.Autoencoder() 73 | outputs = model(inputs) 74 | self.assertIsInstance(outputs, dict) 75 | # Confirm that model generates correctly sized audio. 76 | audio_gen = model.get_audio_from_outputs(outputs) 77 | audio_gen_shape = audio_gen.shape.as_list() 78 | self.assertEqual(audio_gen_shape, list(inputs['audio'].shape)) 79 | 80 | 81 | if __name__ == '__main__': 82 | tf.test.main() 83 | -------------------------------------------------------------------------------- /ddsp/training/models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Model base class.""" 16 | 17 | import time 18 | 19 | from absl import logging 20 | import ddsp 21 | from ddsp.core import copy_if_tf_function 22 | from ddsp.training import train_util 23 | import tensorflow as tf 24 | 25 | 26 | class Model(tf.keras.Model): 27 | """Base class for all models.""" 28 | 29 | def __init__(self, **kwargs): 30 | super().__init__(**kwargs) 31 | self._losses_dict = {} 32 | 33 | def __call__(self, *args, return_losses=False, **kwargs): 34 | """Reset the losses dict on each call. 35 | 36 | Args: 37 | *args: Arguments passed on to call(). 38 | return_losses: Return a dictionary of losses in addition to the call() 39 | function returns. 40 | **kwargs: Keyword arguments passed on to call(). 41 | 42 | Returns: 43 | outputs: A dictionary of model outputs generated in call(). 44 | {output_name: output_tensor or dict}. 45 | losses: If return_losses=True, also returns a dictionary of losses, 46 | {loss_name: loss_value}. 47 | """ 48 | # Copy mutable dicts if in graph mode to prevent side-effects (pure func). 49 | args = [copy_if_tf_function(a) if isinstance(a, dict) else a for a in args] 50 | 51 | # Run model. 52 | self._losses_dict = {} 53 | outputs = super().__call__(*args, **kwargs) 54 | 55 | # Get total loss. 56 | if not return_losses: 57 | return outputs 58 | else: 59 | self._losses_dict['total_loss'] = self.sum_losses(self._losses_dict) 60 | return outputs, self._losses_dict 61 | 62 | def sum_losses(self, losses_dict): 63 | """Sum all the scalar losses in a dictionary.""" 64 | return tf.reduce_sum(list(losses_dict.values())) 65 | 66 | def _update_losses_dict(self, loss_objs, *args, **kwargs): 67 | """Helper function to run loss objects on args and add to model losses.""" 68 | for loss_obj in ddsp.core.make_iterable(loss_objs): 69 | if hasattr(loss_obj, 'get_losses_dict'): 70 | losses_dict = loss_obj.get_losses_dict(*args, **kwargs) 71 | self._losses_dict.update(losses_dict) 72 | 73 | def restore(self, checkpoint_path, verbose=True, restore_keys=None): 74 | """Restore model and optimizer from a checkpoint. 75 | 76 | Args: 77 | checkpoint_path: Path to checkpoint file or directory. 78 | verbose: Warn about missing variables. 79 | restore_keys: Optional list of strings for submodules to restore. 80 | 81 | Raises: 82 | FileNotFoundError: If no checkpoint is found. 83 | """ 84 | start_time = time.time() 85 | latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path) 86 | 87 | if restore_keys is None: 88 | # If no keys are passed, restore the whole model. 89 | checkpoint = tf.train.Checkpoint(model=self) 90 | logging.info('Model restoring all components.') 91 | if verbose: 92 | checkpoint.restore(latest_checkpoint) 93 | else: 94 | checkpoint.restore(latest_checkpoint).expect_partial() 95 | 96 | else: 97 | # Restore only sub-modules by building a new subgraph. 98 | # Following https://www.tensorflow.org/guide/checkpoint#loading_mechanics. 
99 | logging.info('Model restoring subcomponents:') 100 | for k in restore_keys: 101 | to_restore = {k: getattr(self, k)} 102 | log_str = 'Restoring {}'.format(to_restore) 103 | logging.info(log_str) 104 | fake_model = tf.train.Checkpoint(**to_restore) 105 | new_root = tf.train.Checkpoint(model=fake_model) 106 | status = new_root.restore(latest_checkpoint) 107 | status.assert_existing_objects_matched() 108 | 109 | logging.info('Loaded checkpoint %s', latest_checkpoint) 110 | logging.info('Loading model took %.1f seconds', time.time() - start_time) 111 | 112 | def get_audio_from_outputs(self, outputs): 113 | """Extract audio output tensor from outputs dict of call().""" 114 | raise NotImplementedError('Must implement `self.get_audio_from_outputs()`.') 115 | 116 | def call(self, *args, training=False, **kwargs): 117 | """Run the forward pass, add losses, and create a dictionary of outputs. 118 | 119 | This function must run the forward pass, add losses to self._losses_dict and 120 | return a dictionary of all the relevant output tensors. 121 | 122 | Args: 123 | *args: Args for forward pass. 124 | training: Required `training` kwarg passed in by keras. 125 | **kwargs: kwargs for forward pass. 126 | 127 | Returns: 128 | Dictionary of all relevant tensors. 129 | """ 130 | raise NotImplementedError('Must implement a `self.call()` method.') 131 | 132 | 133 | -------------------------------------------------------------------------------- /ddsp/training/plotting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Plotting utilities for the DDSP library. Useful in colab and elsewhere.""" 16 | 17 | from ddsp import core 18 | from ddsp import spectral_ops 19 | from matplotlib import gridspec 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | import tensorflow.compat.v2 as tf 23 | 24 | DEFAULT_SAMPLE_RATE = spectral_ops.CREPE_SAMPLE_RATE 25 | 26 | 27 | def specplot(audio, 28 | vmin=-5, 29 | vmax=1, 30 | rotate=True, 31 | size=512 + 256, 32 | **matshow_kwargs): 33 | """Plot the log magnitude spectrogram of audio.""" 34 | # If batched, take first element. 35 | if len(audio.shape) == 2: 36 | audio = audio[0] 37 | 38 | logmag = spectral_ops.compute_logmag(core.tf_float32(audio), size=size) 39 | if rotate: 40 | logmag = np.rot90(logmag) 41 | # Plotting.
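  # matshow renders the 2D log-magnitude array directly; vmin/vmax clamp the
  # color scale so background noise doesn't wash out the harmonics.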
42 | plt.matshow(logmag, 43 | vmin=vmin, 44 | vmax=vmax, 45 | cmap=plt.cm.magma, 46 | aspect='auto', 47 | **matshow_kwargs) 48 | plt.xticks([]) 49 | plt.yticks([]) 50 | plt.xlabel('Time') 51 | plt.ylabel('Frequency') 52 | 53 | 54 | def transfer_function(ir, sample_rate=DEFAULT_SAMPLE_RATE): 55 | """Get true transfer function from an impulse_response.""" 56 | n_fft = core.get_fft_size(0, ir.shape.as_list()[-1]) 57 | frequencies = np.abs( 58 | np.fft.fftfreq(n_fft, 1 / sample_rate)[:int(n_fft / 2) + 1]) 59 | magnitudes = tf.abs(tf.signal.rfft(ir, [n_fft])) 60 | return frequencies, magnitudes 61 | 62 | 63 | def plot_impulse_responses(impulse_response, 64 | desired_magnitudes, 65 | sample_rate=DEFAULT_SAMPLE_RATE): 66 | """Plot a target frequency response, and that of an impulse response.""" 67 | n_fft = desired_magnitudes.shape[-1] * 2 68 | frequencies = np.fft.fftfreq(n_fft, 1 / sample_rate)[:n_fft // 2] 69 | true_frequencies, true_magnitudes = transfer_function(impulse_response) 70 | 71 | # Plot it. 72 | plt.figure(figsize=(12, 6)) 73 | plt.subplot(121) 74 | # Desired transfer function. 75 | plt.semilogy(frequencies, desired_magnitudes, label='Desired') 76 | # True transfer function. 77 | plt.semilogy(true_frequencies, true_magnitudes[0, 0, :], label='True') 78 | plt.title('Transfer Function') 79 | plt.legend() 80 | 81 | plt.subplot(122) 82 | plt.plot(impulse_response[0, 0, :]) 83 | plt.title('Impulse Response') 84 | 85 | 86 | def pianoroll_plot_setup(figsize=None, side_piano_ratio=0.025, 87 | faint_pr=True, xlim=None): 88 | """Makes a tiny piano left of the y-axis and a faint piano on the main figure. 89 | 90 | This function sets up the figure for pretty plotting a piano roll. It makes a 91 | small imshow plot to the left of the main plot that looks like a piano. This 92 | piano side plot is aligned along the y-axis of the main plot, such that y 93 | values align with MIDI values (y=0 is the lowest C-1, y=12 is C0, etc.). 94 | Additionally, a main figure is set up that shares the y-axis of the piano side 95 | plot. Optionally, a set of faint horizontal lines are drawn on the main figure 96 | that correspond to the black keys on the piano (and a line separating B & C 97 | and E & F). This function returns the formatted figure, the main axis for 98 | plotting your data, and the side piano axis. 99 | 100 | By default, this will draw 11 octaves of piano keys along the y-axis; you will 101 | probably want to reduce what is visible using `ax.set_ylim()` on either 102 | returned axis. 103 | 104 | Using with imshow piano roll data: 105 | A common use case is using imshow() on the main axis to display a piano 106 | roll alongside the piano side plot AND the faint piano roll behind your 107 | data. In this case, if your data is a 2D array you have to use a masked 108 | numpy array to make certain values invisible on the plot, and therefore make 109 | the faint piano roll visible. Here's an example: 110 | 111 | midi = np.flipud([ 112 | [0.0, 0.0, 1.0], 113 | [0.0, 1.0, 0.0], 114 | [1.0, 0.0, 0.0], 115 | ]) 116 | 117 | midi_masked = np.ma.masked_values(midi, 0.0) # Mask out all 0.0's 118 | fig, ax, sp = plotting.pianoroll_plot_setup() 119 | ax.imshow(midi_masked, origin='lower', aspect='auto') # main subplot axis 120 | sp.set_ylabel('My favorite MIDI data') # side piano axis 121 | fig.show() 122 | 123 | The other option is to use imshow in RGBA mode, where your data is split 124 | into 4 channels. Every alpha value that is 0.0 will be transparent and show 125 | the faint piano roll below your data.
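    A sketch of the RGBA route (shapes and channel choices are illustrative):

      rgba = np.zeros(midi.shape + (4,))
      rgba[..., 0] = 1.0   # Red channel carries the note color.
      rgba[..., 3] = midi  # Alpha is 0.0 where empty, showing the faint piano.
      fig, ax, sp = plotting.pianoroll_plot_setup()
      ax.imshow(rgba, origin='lower', aspect='auto')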
126 | 127 | Args: 128 | figsize: Size of the matplotlib figure. Will be passed to `plt.figure()`. 129 | Defaults to None. 130 | side_piano_ratio: Width of the y-axis piano as a ratio of the whole 131 | figure. Defaults to 1/40th. 132 | faint_pr: Whether to draw faint black & white keys across the main plot. 133 | Defaults to True. 134 | xlim: Tuple containing the min and max of the x values for the main plot. 135 | Only used to determine the x limits for the faint piano roll in the main 136 | plot. Defaults to (0, 1000). 137 | 138 | Returns: 139 | (figure, main_axis, left_piano_axis) 140 | figure: A matplotlib figure object containing both subplots set up with an 141 | aligned piano roll. 142 | main_axis: A matplotlib axis object to be used for plotting. Optionally 143 | has a faint piano roll in the background. 144 | left_piano_axis: A matplotlib axis object that has a small, aligned piano 145 | along the left side y-axis of the main_axis subplot. 146 | """ 147 | octaves = 11 148 | 149 | # Setup figure and gridspec. 150 | fig = plt.figure(figsize=figsize) 151 | gs_ratio = int(1 / side_piano_ratio) 152 | gs = gridspec.GridSpec(1, 2, width_ratios=[1, gs_ratio]) 153 | left_piano_ax = fig.add_subplot(gs[0]) 154 | 155 | # Make a piano on the left side of the y-axis with imshow(). 156 | keys = np.array( 157 | [0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0] # notes in descending order; B -> C 158 | ) 159 | keys = np.tile(keys, octaves)[:, None] 160 | left_piano_ax.imshow(keys, cmap='binary', aspect='auto', 161 | extent=[0, 0.625, -0.5, octaves*12-0.5]) 162 | 163 | # Make the lines between keys. 164 | for i in range(octaves): 165 | left_piano_ax.hlines(i*12 - 0.5, -0.5, 1, colors='black', linewidth=0.5) 166 | left_piano_ax.hlines(i*12 + 1.0, -0.5, 1, colors='black', linewidth=0.5) 167 | left_piano_ax.hlines(i*12 + 3.0, -0.5, 1, colors='black', linewidth=0.5) 168 | left_piano_ax.hlines(i*12 + 4.5, -0.5, 1, colors='black', linewidth=0.5) 169 | left_piano_ax.hlines(i*12 + 6.0, -0.5, 1, colors='black', linewidth=0.5) 170 | left_piano_ax.hlines(i*12 + 8.0, -0.5, 1, colors='black', linewidth=0.5) 171 | left_piano_ax.hlines(i*12 + 10.0, -0.5, 1, colors='black', linewidth=0.5) 172 | 173 | # Set the limits of the side piano and remove ticks so it looks nice. 174 | left_piano_ax.set_xlim(0, 0.995) 175 | left_piano_ax.set_xticks([]) 176 | 177 | # Create the aligned axis we'll return to the user. 178 | main_ax = fig.add_subplot(gs[1], sharey=left_piano_ax) 179 | 180 | # Draw a faint piano roll behind the main axes (if the user wants). 181 | if faint_pr: 182 | xlim = (0, 1000) if xlim is None else xlim 183 | x_min, x_max = xlim 184 | x_delta = x_max - x_min 185 | main_ax.imshow(np.tile(keys, x_delta), cmap='binary', aspect='auto', 186 | alpha=0.05, extent=[x_min, x_max, -0.5, octaves*12-0.5]) 187 | for i in range(octaves): 188 | main_ax.hlines(i * 12 + 4.5, x_min, x_max, colors='black', 189 | linewidth=0.5, alpha=0.25) 190 | main_ax.hlines(i * 12 - 0.5, x_min, x_max, colors='black', 191 | linewidth=0.5, alpha=0.25) 192 | 193 | main_ax.set_xlim(*xlim) 194 | 195 | # Some final cosmetic tweaks before returning the axis obj's and figure. 196 | plt.setp(main_ax.get_yticklabels(), visible=False) 197 | gs.tight_layout(fig) 198 | return fig, main_ax, left_piano_ax 199 | -------------------------------------------------------------------------------- /ddsp/training/preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library of preprocessing functions.""" 16 | 17 | import ddsp 18 | from ddsp.training import nn 19 | import gin 20 | import tensorflow as tf 21 | 22 | F0_RANGE = ddsp.spectral_ops.F0_RANGE 23 | DB_RANGE = ddsp.spectral_ops.DB_RANGE 24 | 25 | tfkl = tf.keras.layers 26 | 27 | 28 | # ---------------------- Preprocess Helpers ------------------------------------ 29 | def at_least_3d(x): 30 | """Optionally adds time, batch, then channel dimension.""" 31 | x = x[tf.newaxis] if not x.shape else x 32 | x = x[tf.newaxis, :] if len(x.shape) == 1 else x 33 | x = x[:, :, tf.newaxis] if len(x.shape) == 2 else x 34 | return x 35 | 36 | 37 | def scale_db(db): 38 | """Scales [-DB_RANGE, 0] to [0, 1].""" 39 | return (db / DB_RANGE) + 1.0 40 | 41 | 42 | def inv_scale_db(db_scaled): 43 | """Scales [0, 1] to [-DB_RANGE, 0].""" 44 | return (db_scaled - 1.0) * DB_RANGE 45 | 46 | 47 | def scale_f0_hz(f0_hz): 48 | """Scales [0, Nyquist] Hz to [0, 1.0] MIDI-scaled.""" 49 | return ddsp.core.hz_to_midi(f0_hz) / F0_RANGE 50 | 51 | 52 | def inv_scale_f0_hz(f0_scaled): 53 | """Scales [0, 1.0] MIDI-scaled to [0, Nyquist] Hz.""" 54 | return ddsp.core.midi_to_hz(f0_scaled * F0_RANGE) 55 | 56 | 57 | # ---------------------- Preprocess objects ------------------------------------ 58 | @gin.register 59 | class F0LoudnessPreprocessor(nn.DictLayer): 60 | """Resamples and scales 'f0_hz' and 'loudness_db' features.""" 61 | 62 | def __init__(self, 63 | time_steps=1000, 64 | frame_rate=250, 65 | sample_rate=16000, 66 | compute_loudness=True, 67 | **kwargs): 68 | super().__init__(**kwargs) 69 | self.time_steps = time_steps 70 | self.frame_rate = frame_rate 71 | self.sample_rate = sample_rate 72 | self.compute_loudness = compute_loudness 73 | 74 | def call(self, loudness_db, f0_hz, audio=None) -> [ 75 | 'f0_hz', 'loudness_db', 'f0_scaled', 'ld_scaled']: 76 | # Compute loudness fresh (it's fast). 77 | if self.compute_loudness: 78 | loudness_db = ddsp.spectral_ops.compute_loudness( 79 | audio, 80 | sample_rate=self.sample_rate, 81 | frame_rate=self.frame_rate) 82 | 83 | # Resample features to the frame_rate. 84 | f0_hz = self.resample(f0_hz) 85 | loudness_db = self.resample(loudness_db) 86 | # For NN training, scale frequency and loudness to the range [0, 1]. 87 | # Log-scale f0 features. Scale loudness from [-DB_RANGE, 0] to [0, 1].
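    # Worked example (F0_RANGE is assumed to be 127.0, the full MIDI range):
    # f0_hz = 440.0 -> MIDI 69.0 -> f0_scaled = 69.0 / 127.0 ~= 0.543.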
88 | f0_scaled = scale_f0_hz(f0_hz) 89 | ld_scaled = scale_db(loudness_db) 90 | return f0_hz, loudness_db, f0_scaled, ld_scaled 91 | 92 | @staticmethod 93 | def invert_scaling(f0_scaled, ld_scaled): 94 | """Takes in scaled f0 and loudness, and puts them back to hz & db scales.""" 95 | f0_hz = inv_scale_f0_hz(f0_scaled) 96 | loudness_db = inv_scale_db(ld_scaled) 97 | return f0_hz, loudness_db 98 | 99 | def resample(self, x): 100 | x = at_least_3d(x) 101 | return ddsp.core.resample(x, self.time_steps) 102 | 103 | 104 | @gin.register 105 | class F0PowerPreprocessor(F0LoudnessPreprocessor): 106 | """Dynamically compute additional power_db feature.""" 107 | 108 | def __init__(self, 109 | time_steps=1000, 110 | frame_rate=250, 111 | sample_rate=16000, 112 | frame_size=64, 113 | **kwargs): 114 | super().__init__(time_steps, **kwargs) 115 | self.frame_rate = frame_rate 116 | self.sample_rate = sample_rate 117 | self.frame_size = frame_size 118 | 119 | def call(self, f0_hz, power_db=None, audio=None) -> [ 120 | 'f0_hz', 'pw_db', 'f0_scaled', 'pw_scaled']: 121 | """Compute power on the fly if it's not in the inputs.""" 122 | # For NN training, scale frequency and power to the range [0, 1]. 123 | f0_hz = self.resample(f0_hz) 124 | f0_scaled = scale_f0_hz(f0_hz) 125 | 126 | if power_db is not None: 127 | # Use dataset values if present. 128 | pw_db = power_db 129 | elif audio is not None: 130 | # Otherwise, compute power on the fly. 131 | pw_db = ddsp.spectral_ops.compute_power(audio, 132 | sample_rate=self.sample_rate, 133 | frame_rate=self.frame_rate, 134 | frame_size=self.frame_size) 135 | else: 136 | raise ValueError('Power preprocessing requires either ' 137 | '"power_db" or "audio" keys to be provided ' 138 | 'in the dataset.') 139 | # Resample. 140 | pw_db = self.resample(pw_db) 141 | # Scale power. 142 | pw_scaled = scale_db(pw_db) 143 | 144 | return f0_hz, pw_db, f0_scaled, pw_scaled 145 | 146 | @staticmethod 147 | def invert_scaling(f0_scaled, pw_scaled): 148 | """Puts scaled f0 and power back to hz & db scales.""" 149 | f0_hz = inv_scale_f0_hz(f0_scaled) 150 | power_db = inv_scale_db(pw_scaled) 151 | return f0_hz, power_db 152 | 153 | 154 | @gin.register 155 | class OnlineF0PowerPreprocessor(nn.DictLayer): 156 | """Dynamically compute power_db and f0_hz with optional centered frames.""" 157 | 158 | def __init__(self, 159 | frame_rate=250, 160 | frame_size=1024, 161 | padding='center', 162 | compute_power=True, 163 | compute_f0=True, 164 | crepe_saved_model_path='full', 165 | viterbi=False, 166 | **kwargs): 167 | super().__init__(**kwargs) 168 | # Preprocessing must happen at 16kHz because CREPE was trained at 16kHz. 169 | self.sample_rate = ddsp.spectral_ops.CREPE_SAMPLE_RATE 170 | self.frame_rate = frame_rate 171 | self.frame_size = frame_size 172 | self.hop_size = self.sample_rate // frame_rate 173 | 174 | self.compute_f0 = compute_f0 175 | self.compute_power = compute_power 176 | 177 | self.padding = padding 178 | 179 | # Crepe model, must either be a model size or path to SavedModel. 180 | if crepe_saved_model_path: 181 | self.crepe_model = ddsp.spectral_ops.PretrainedCREPE( 182 | model_size_or_path=crepe_saved_model_path, hop_size=self.hop_size 183 | ) 184 | 185 | # Use viterbi decoding.
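    # (Viterbi decoding smooths CREPE's frame-wise pitch posterior with an HMM,
    # trading extra compute for a more stable f0 track.)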
186 | self.viterbi = viterbi 187 | 188 | def call( 189 | self, audio, f0_hz=None, f0_confidence=None, audio_16k=None, pw_db=None 190 | ) -> ['f0_hz', 'pw_db', 'f0_scaled', 'pw_scaled', 'f0_confidence']: 191 | """Compute f0 and power on the fly if not provided in the inputs.""" 192 | # Compute features at 16kHz (needed for CREPE). 193 | if audio_16k is not None: 194 | audio = audio_16k 195 | 196 | # Compute power and f0 on the fly. 197 | if self.compute_power: 198 | pw_db = ddsp.spectral_ops.compute_power(audio, 199 | sample_rate=self.sample_rate, 200 | frame_rate=self.frame_rate, 201 | frame_size=self.frame_size, 202 | padding=self.padding) 203 | 204 | if self.compute_f0: 205 | f0_hz, f0_confidence = self.crepe_model.predict_f0_and_confidence( 206 | audio, viterbi=self.viterbi, padding=self.padding) 207 | # Stop gradients from flowing to CREPE. 208 | f0_hz = tf.stop_gradient(f0_hz) 209 | f0_confidence = tf.stop_gradient(f0_confidence) 210 | elif f0_hz is None or f0_confidence is None: 211 | raise ValueError('Preprocessor must either have `compute_f0=True`, or ' 212 | '__call__ must be supplied all 3 arguments: ' 213 | '[audio, f0_hz, f0_confidence].') 214 | 215 | # For NN training, scale frequency and power to the range [0, 1]. 216 | pw_db = at_least_3d(pw_db) 217 | f0_hz = at_least_3d(f0_hz) 218 | 219 | pw_scaled = scale_db(pw_db) 220 | f0_scaled = scale_f0_hz(f0_hz) 221 | 222 | # Sanity check the output shapes. 223 | # Every output must have the expected number of frames for the given 224 | # frame_size, hop_size, and padding (no interpolation is performed). 225 | n_t = audio.shape[1] 226 | time_steps, _ = ddsp.spectral_ops.get_framed_lengths( 227 | n_t, self.frame_size, self.hop_size, self.padding) 228 | for k, output in { 229 | 'f0_hz': f0_hz, 230 | 'pw_db': pw_db, 231 | 'f0_scaled': f0_scaled, 232 | 'pw_scaled': pw_scaled, 233 | 'f0_confidence': f0_confidence}.items(): 234 | if output.shape[1] != time_steps: 235 | raise ValueError( 236 | f'OnlineF0PowerPreprocessor output: ({k}) does not have ' 237 | f'{time_steps} timesteps. Output shape: {output.shape}. ' 238 | f'\nInputs: seconds ({n_t/self.sample_rate}), ' 239 | f'frame_rate ({self.frame_rate}), ' 240 | f'padding ("{self.padding}").') 241 | 242 | return f0_hz, pw_db, f0_scaled, pw_scaled, f0_confidence 243 | 244 | 245 | -------------------------------------------------------------------------------- /ddsp/training/preprocessing_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for ddsp.training.preprocessing.""" 16 | 17 | from absl.testing import parameterized 18 | from ddsp.core import resample 19 | from ddsp.spectral_ops import compute_power 20 | from ddsp.training import preprocessing 21 | import tensorflow as tf 22 | 23 | tfkl = tf.keras.layers 24 | 25 | 26 | class F0PowerPreprocessorTest(parameterized.TestCase, tf.test.TestCase): 27 | 28 | def setUp(self): 29 | """Create input dictionary and preprocessor.""" 30 | super().setUp() 31 | sr = 16000 32 | frame_rate = 250 33 | frame_size = 256 34 | n_samples = 16000 35 | n_t = 250 36 | # Replicate preprocessor computations. 37 | audio = 0.5 * tf.sin(tf.range(0, n_samples, dtype=tf.float32))[None, :] 38 | power_db = compute_power(audio, 39 | sample_rate=sr, 40 | frame_rate=frame_rate, 41 | frame_size=frame_size) 42 | power_db = preprocessing.at_least_3d(power_db) 43 | power_db = resample(power_db, n_t) 44 | self.input_dict = { 45 | 'f0_hz': tf.ones([1, n_t]), 46 | 'audio': audio, 47 | 'power_db': power_db, 48 | } 49 | self.preprocessor = preprocessing.F0PowerPreprocessor( 50 | time_steps=n_t, 51 | frame_rate=frame_rate, 52 | sample_rate=sr) 53 | 54 | @parameterized.named_parameters( 55 | ('audio_only', ['audio']), 56 | ('power_only', ['power_db']), 57 | ('audio_and_power', ['audio', 'power_db']), 58 | ) 59 | def test_audio_only(self, input_keys): 60 | input_keys += ['f0_hz'] 61 | inputs = {k: v for k, v in self.input_dict.items() if k in input_keys} 62 | outputs = self.preprocessor(inputs) 63 | self.assertAllClose(self.input_dict['power_db'], 64 | outputs['pw_db'], 65 | rtol=0.5, 66 | atol=30) 67 | 68 | if __name__ == '__main__': 69 | tf.test.main() 70 | -------------------------------------------------------------------------------- /ddsp/training/trainers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library of Trainer objects that define traning step and wrap optimizer.""" 16 | 17 | import time 18 | 19 | from absl import logging 20 | from ddsp.training import train_util 21 | import gin 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | @gin.configurable 26 | class Trainer(object): 27 | """Class to bind an optimizer, model, strategy, and training step function.""" 28 | 29 | def __init__(self, 30 | model, 31 | strategy, 32 | checkpoints_to_keep=100, 33 | learning_rate=0.001, 34 | lr_decay_steps=10000, 35 | lr_decay_rate=0.98, 36 | grad_clip_norm=3.0, 37 | restore_keys=None): 38 | """Constructor. 39 | 40 | Args: 41 | model: Model to train. 42 | strategy: A distribution strategy. 43 | checkpoints_to_keep: Max number of checkpoints before deleting oldest. 44 | learning_rate: Scalar initial learning rate. 45 | lr_decay_steps: Exponential decay timescale. 46 | lr_decay_rate: Exponential decay magnitude. 47 | grad_clip_norm: Norm level by which to clip gradients. 
48 | restore_keys: List of names of model properties to restore. If no keys are 49 | passed, restore the whole model. 50 | """ 51 | self.model = model 52 | self.strategy = strategy 53 | self.checkpoints_to_keep = checkpoints_to_keep 54 | self.grad_clip_norm = grad_clip_norm 55 | self.restore_keys = restore_keys 56 | 57 | # Create an optimizer. 58 | lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( 59 | initial_learning_rate=learning_rate, 60 | decay_steps=lr_decay_steps, 61 | decay_rate=lr_decay_rate) 62 | 63 | with self.strategy.scope(): 64 | self.optimizer = tf.keras.optimizers.Adam(lr_schedule) 65 | 66 | def get_checkpoint(self, model=None): 67 | """Model arg can also be a tf.train.Checkpoint(**dict(submodules)).""" 68 | model = model or self.model # Default to full model. 69 | return tf.train.Checkpoint(model=model, optimizer=self.optimizer) 70 | 71 | def save(self, save_dir): 72 | """Saves model and optimizer to a checkpoint.""" 73 | # Saving weights in checkpoint format because saved_model requires handling 74 | # a variable batch size, which some synths and effects don't support. 75 | start_time = time.time() 76 | checkpoint = self.get_checkpoint() 77 | manager = tf.train.CheckpointManager( 78 | checkpoint, directory=save_dir, max_to_keep=self.checkpoints_to_keep) 79 | step = self.step.numpy() 80 | manager.save(checkpoint_number=step) 81 | logging.info('Saved checkpoint to %s at step %s', save_dir, step) 82 | logging.info('Saving model took %.1f seconds', time.time() - start_time) 83 | 84 | def restore(self, checkpoint_path, restore_keys=None): 85 | """Restore model and optimizer from a checkpoint if it exists. 86 | 87 | Args: 88 | checkpoint_path: Path to checkpoint file or directory. 89 | restore_keys: Optional list of strings for submodules to restore. 90 | 91 | Raises: 92 | FileNotFoundError: If no checkpoint is found. 93 | """ 94 | logging.info('Restoring from checkpoint...') 95 | start_time = time.time() 96 | 97 | # Prefer function args over object properties. 98 | restore_keys = restore_keys or self.restore_keys 99 | if restore_keys is None: 100 | # If no keys are passed, restore the whole model. 101 | model = self.model 102 | logging.info('Trainer restoring the full model') 103 | else: 104 | # Restore only sub-modules by building a new subgraph. 105 | restore_dict = {k: getattr(self.model, k) for k in restore_keys} 106 | model = tf.train.Checkpoint(**restore_dict) 107 | 108 | logging.info('Trainer restoring model subcomponents:') 109 | for k, v in restore_dict.items(): 110 | log_str = 'Restoring {}: {}'.format(k, v) 111 | logging.info(log_str) 112 | 113 | # Restore from latest checkpoint. 114 | checkpoint = self.get_checkpoint(model) 115 | latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path) 116 | # checkpoint.restore must be within a strategy.scope() so that optimizer 117 | # slot variables are mirrored.
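    # (Restoring outside the scope could create unmirrored optimizer slot
    # variables that no longer line up with the checkpoint when distributed.)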
118 | with self.strategy.scope(): 119 | if restore_keys is None: 120 | checkpoint.restore(latest_checkpoint) 121 | else: 122 | checkpoint.restore(latest_checkpoint).expect_partial() 123 | logging.info('Loaded checkpoint %s', latest_checkpoint) 124 | logging.info('Loading model took %.1f seconds', time.time() - start_time) 125 | 126 | @property 127 | def step(self): 128 | """The number of training steps completed.""" 129 | return self.optimizer.iterations 130 | 131 | def psum(self, x, axis=None): 132 | """Sum across processors.""" 133 | return self.strategy.reduce(tf.distribute.ReduceOp.SUM, x, axis=axis) 134 | 135 | def run(self, fn, *args, **kwargs): 136 | """Distribute and run function on processors.""" 137 | return self.strategy.run(fn, args=args, kwargs=kwargs) 138 | 139 | def build(self, batch): 140 | """Build the model by running a distributed batch through it.""" 141 | logging.info('Building the model...') 142 | _ = self.run(tf.function(self.model.__call__), batch) 143 | self.model.summary() 144 | 145 | def distribute_dataset(self, dataset): 146 | """Create a distributed dataset.""" 147 | if isinstance(dataset, tf.data.Dataset): 148 | return self.strategy.experimental_distribute_dataset(dataset) 149 | else: 150 | return dataset 151 | 152 | @tf.function 153 | def train_step(self, inputs): 154 | """Distributed training step.""" 155 | # Wrap iterator in tf.function, slight speedup passing in iter vs batch. 156 | batch = next(inputs) if hasattr(inputs, '__next__') else inputs 157 | losses = self.run(self.step_fn, batch) 158 | # Add up the scalar losses across replicas. 159 | n_replicas = self.strategy.num_replicas_in_sync 160 | return {k: self.psum(v, axis=None) / n_replicas for k, v in losses.items()} 161 | 162 | @tf.function 163 | def step_fn(self, batch): 164 | """Per-Replica training step.""" 165 | with tf.GradientTape() as tape: 166 | _, losses = self.model(batch, return_losses=True, training=True) 167 | # Clip and apply gradients. 168 | grads = tape.gradient(losses['total_loss'], self.model.trainable_variables) 169 | grads, _ = tf.clip_by_global_norm(grads, self.grad_clip_norm) 170 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 171 | return losses 172 | 173 | 174 | @gin.configurable 175 | def get_trainer_class(trainer_class=Trainer): 176 | """Gin configurable function to get a 'global' trainer for use in ddsp_run.py. 177 | 178 | Args: 179 | trainer_class: A trainer class such as `Trainer`. 180 | 181 | Returns: 182 | The 'global' trainer class specified in the gin config. 183 | """ 184 | return trainer_class 185 | -------------------------------------------------------------------------------- /ddsp/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Separate file for storing the current version of DDSP.
16 | 17 | Stored in a separate file so that setup.py can reference the version without 18 | pulling in all the dependencies in __init__.py. 19 | """ 20 | 21 | __version__ = '3.7.0' 22 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Install ddsp.""" 16 | 17 | import os 18 | import sys 19 | import setuptools 20 | 21 | # To enable importing version.py directly, we add its path to sys.path. 22 | version_path = os.path.join(os.path.dirname(__file__), 'ddsp') 23 | sys.path.append(version_path) 24 | from version import __version__ # pylint: disable=g-import-not-at-top 25 | 26 | setuptools.setup( 27 | name='ddsp', 28 | version=__version__, 29 | description='Differentiable Digital Signal Processing ', 30 | author='Google Inc.', 31 | author_email='no-reply@google.com', 32 | url='http://github.com/magenta/ddsp', 33 | license='Apache 2.0', 34 | packages=setuptools.find_packages(), 35 | package_data={ 36 | '': ['*.gin'], 37 | }, 38 | scripts=[], 39 | install_requires=[ 40 | 'absl-py', 41 | 'apache-beam', 42 | 'cloudml-hypertune<=0.1.0.dev6', 43 | 'crepe<=0.0.12', 44 | 'dill<=0.3.4', 45 | 'future', 46 | 'gin-config>=0.3.0', 47 | 'google-cloud-storage', 48 | 'hmmlearn<=0.2.7', 49 | 'librosa<=0.10', 50 | 'pydub<=0.25.1', 51 | 'protobuf<=3.20', # temporary fix for proto dependency bug 52 | 'mir_eval<=0.7', 53 | 'note_seq<0.0.4', 54 | 'numpy<1.24', 55 | 'scipy<=1.10.1', 56 | 'six', 57 | 'tensorflow<=2.11', 58 | 'tensorflowjs<3.19', 59 | 'tensorflow-probability<=0.19', 60 | 'tensorflow-datasets<=4.9', 61 | 'tflite_support<=0.1' 62 | ], 63 | extras_require={ 64 | 'gcp': [ 65 | 'gevent', 'google-api-python-client', 'google-compute-engine', 66 | 'oauth2client' 67 | ], 68 | 'data_preparation': [ 69 | 'apache_beam', 70 | # TODO(jesseengel): Remove versioning when beam import is fixed. 
71 | 'pyparsing<=2.4.7' 72 | ], 73 | 'test': ['pytest', 'pylint!=2.5.0'], 74 | }, 75 | # pylint: disable=line-too-long 76 | entry_points={ 77 | 'console_scripts': [ 78 | 'ddsp_export = ddsp.training.ddsp_export:console_entry_point', 79 | 'ddsp_run = ddsp.training.ddsp_run:console_entry_point', 80 | 'ddsp_prepare_tfrecord = ddsp.training.data_preparation.ddsp_prepare_tfrecord:console_entry_point', 81 | 'ddsp_generate_synthetic_dataset = ddsp.training.data_preparation.ddsp_generate_synthetic_dataset:console_entry_point', 82 | 'ddsp_ai_platform = ddsp.training.docker.ddsp_ai_platform:console_entry_point', 83 | ], 84 | }, 85 | # pylint: enable=line-too-long 86 | classifiers=[ 87 | 'Development Status :: 4 - Beta', 88 | 'Intended Audience :: Developers', 89 | 'Intended Audience :: Science/Research', 90 | 'License :: OSI Approved :: Apache Software License', 91 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 92 | ], 93 | tests_require=['pytest'], 94 | setup_requires=['pytest-runner'], 95 | keywords='audio dsp signalprocessing machinelearning music', 96 | ) 97 | -------------------------------------------------------------------------------- /update_gin_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The DDSP Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Script to update old '.gin' config files to work with current codebase. 16 | 17 | When there are changes in a function signature (adding or removing kwargs), 18 | those changes need to be reflected in gin config files. There is no way to 19 | automatically do this from within gin. 20 | 21 | This script maintains some backwards compatibility by keeping track of changes 22 | to function signatures and automatically updating gin config files to 23 | incorporate those changes. 24 | 25 | In particular, this is useful for operative_config-*.gin files that are created 26 | during training, to allow loading of old models. 27 | 28 | This script will update any target gin config files and leave copies of the old 29 | files for safe-keeping. 30 | 31 | Usage: 32 | ====== 33 | python update_gin_config.py path/to/config/operative_config-*.gin 34 | """ 35 | 36 | import os 37 | import re 38 | 39 | from absl import app 40 | from absl import flags 41 | from tensorflow import io 42 | 43 | flags.DEFINE_string('path', '/tmp/operative_config-*.gin', 'Path to gin ' 44 | 'config files. May be a path to a file or a glob ' 45 | 'expression.') 46 | flags.DEFINE_boolean('overwrite', False, 'Overwrite original files.') 47 | 48 | 49 | FLAGS = flags.FLAGS 50 | gfile = io.gfile 51 | 52 | 53 | # ============================================================================== 54 | # Updates to perform 55 | # ============================================================================== 56 | # Remove lines with any of these strings in them.
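# For example, an old config line such as
#   'SpectralLoss.delta_delta_freq_weight = 1.0'
# contains the first entry below, so the update loop in main() drops it.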
57 | REMOVE = [ 58 | 'SpectralLoss.delta_delta_freq_weight', 59 | 'SpectralLoss.delta_delta_time_weight', 60 | 'DilatedConvEncoder.resample', 61 | 'DilatedConvDecoder.resample', 62 | ] 63 | 64 | 65 | # Perform these line-by-line substitutions, (Regex, Replacement String). 66 | SUBSTITUTE = [ 67 | ('ZRnnFcDecoder', 'RnnFcDecoder'), 68 | ] 69 | 70 | 71 | # Add the following kwargs, (GinConfigurable, kwarg, value). 72 | # -> 'GinConfigurable.kwarg = value' 73 | ADD = [ 74 | ('RnnFcDecoder', 'input_keys', '("f0_scaled", "ld_scaled")'), 75 | ] 76 | 77 | 78 | # ============================================================================== 79 | # Main Program 80 | # ============================================================================== 81 | def add_kwarg(lines, gin_configurable, kwarg, value): 82 | """Add 'GinConfigurable.kwarg = value' to the config if not already present.""" 83 | gin_kwarg = gin_configurable + '.' + kwarg 84 | new_line = gin_kwarg + ' = ' + value + '\n' 85 | configurable_present = any([gin_configurable in line for line in lines]) 86 | kwarg_present = any([gin_kwarg in line for line in lines]) 87 | if configurable_present and not kwarg_present: 88 | # Add to the bottom of the config. 89 | lines.append('\n' + new_line) 90 | print(f'Added: {new_line.rstrip()}') 91 | elif configurable_present and kwarg_present: 92 | print(f'Skipped Add: {new_line.rstrip()}, {gin_kwarg} already present.') 93 | else: 94 | print(f'Skipped Add: {new_line.rstrip()}, {gin_configurable} not present.') 95 | 96 | 97 | def main(argv): 98 | # Parse input args. 99 | if len(argv) > 2: 100 | raise app.UsageError('Too many command-line arguments.') 101 | elif len(argv) == 2: 102 | path = argv[1] 103 | else: 104 | path = FLAGS.path 105 | 106 | # Get a list of files that match the pattern. 107 | files = gfile.glob(path) 108 | 109 | for fpath in files: 110 | # Create a new file path. 111 | dirname, filename = os.path.split(fpath) 112 | new_filename = filename if FLAGS.overwrite else 'updated_' + filename 113 | new_fpath = os.path.join(dirname, new_filename) 114 | print(f'\nUpdating: \n{fpath} -> \n{new_fpath}') 115 | print('================') 116 | 117 | # Read old config. 118 | with gfile.GFile(fpath, 'r') as f: 119 | lines = f.readlines() 120 | 121 | # Make new config. 122 | new_lines = [] 123 | for line in lines: 124 | # Remove lines with old arguments. 125 | if any([tag in line for tag in REMOVE]): 126 | print(f'Removed: {line.rstrip()}') 127 | continue 128 | 129 | # Substitute. 130 | for regex, sub in SUBSTITUTE: 131 | old_line = line 132 | line, n = re.subn(regex, sub, line) 133 | if n: 134 | print(f'Swapped: {old_line.rstrip()} -> {line.rstrip()}') 135 | 136 | # Append the new line. 137 | new_lines.append(line) 138 | 139 | # Add new lines after substitutions. 140 | for gin_configurable, kwarg, value in ADD: 141 | add_kwarg(new_lines, gin_configurable, kwarg, value) 142 | 143 | # Delete target file if it exists. 144 | if gfile.exists(new_fpath): 145 | gfile.remove(new_fpath) 146 | 147 | # Write to a new file.
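    # (gfile handles both local paths and remote filesystems such as gs://,
    # so updated operative configs can be written back next to checkpoints.)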
148 | with gfile.GFile(new_fpath, 'w') as f: 149 | _ = f.write(''.join(new_lines)) 150 | 151 | 152 | if __name__ == '__main__': 153 | app.run(main) 154 | -------------------------------------------------------------------------------- /update_pip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Usage: 4 | # sh third_party/py/ddsp/update_pip.sh 5 | set -ex 6 | 7 | orig_dir=$(pwd) 8 | tmp_dir=$(mktemp -d -t ddsp-XXXX) 9 | git clone https://github.com/magenta/ddsp.git $tmp_dir 10 | cd $tmp_dir 11 | 12 | python setup.py sdist 13 | python setup.py bdist_wheel --universal 14 | twine upload dist/* 15 | 16 | cd $orig_dir 17 | rm -rf $tmp_dir 18 | --------------------------------------------------------------------------------