├── .github
└── workflows
│ └── build.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── ddsp
├── __init__.py
├── colab
│ ├── README.md
│ ├── __init__.py
│ ├── colab_utils.py
│ ├── demos
│ │ ├── README.md
│ │ ├── Train_VST.ipynb
│ │ ├── pitch_detection.ipynb
│ │ ├── timbre_transfer.ipynb
│ │ └── train_autoencoder.ipynb
│ └── tutorials
│ │ ├── 0_processor.ipynb
│ │ ├── 1_synths_and_effects.ipynb
│ │ ├── 2_processor_group.ipynb
│ │ ├── 3_training.ipynb
│ │ ├── 4_core_functions.ipynb
│ │ └── README.md
├── core.py
├── core_test.py
├── dags.py
├── dags_test.py
├── effects.py
├── effects_test.py
├── losses.py
├── losses_test.py
├── processors.py
├── processors_test.py
├── spectral_ops.py
├── spectral_ops_test.py
├── synths.py
├── synths_test.py
├── test_util.py
├── training
│ ├── README.md
│ ├── __init__.py
│ ├── cloud.py
│ ├── cloud_test.py
│ ├── data.py
│ ├── data_preparation
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── ddsp_generate_synthetic_dataset.py
│ │ ├── ddsp_prepare_tfrecord.py
│ │ ├── prepare_tfrecord_lib.py
│ │ ├── prepare_tfrecord_lib_test.py
│ │ └── synthetic_data.py
│ ├── ddsp_export.py
│ ├── ddsp_run.py
│ ├── decoders.py
│ ├── decoders_test.py
│ ├── docker
│ │ ├── Dockerfile
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── config_hypertune.yaml
│ │ ├── config_multiple_vms.yaml
│ │ ├── config_single_vm.yaml
│ │ ├── ddsp_ai_platform.py
│ │ ├── task.py
│ │ └── task_test.py
│ ├── encoders.py
│ ├── eval_util.py
│ ├── evaluators.py
│ ├── gin
│ │ ├── __init__.py
│ │ ├── datasets
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── base.gin
│ │ │ ├── nsynth.gin
│ │ │ ├── tfrecord.gin
│ │ │ └── urmp
│ │ │ │ ├── README.md
│ │ │ │ ├── all.gin
│ │ │ │ ├── all_midi.gin
│ │ │ │ ├── base.gin
│ │ │ │ └── midi_base.gin
│ │ ├── eval
│ │ │ ├── __init__.py
│ │ │ ├── basic.gin
│ │ │ ├── basic_f0_ld.gin
│ │ │ ├── basic_f0_ld_twm.gin
│ │ │ ├── heuristic.gin
│ │ │ ├── heuristic_power.gin
│ │ │ └── midi_ae.gin
│ │ ├── models
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── ae.gin
│ │ │ ├── midiae
│ │ │ │ ├── README.md
│ │ │ │ ├── midiae.gin
│ │ │ │ ├── mixins
│ │ │ │ │ ├── README.md
│ │ │ │ │ ├── _.gin
│ │ │ │ │ ├── hmm_prior.gin
│ │ │ │ │ ├── midi_encoder.gin
│ │ │ │ │ └── recon_lossgroup.gin
│ │ │ │ └── z_midiae.gin
│ │ │ ├── solo_instrument.gin
│ │ │ └── vst
│ │ │ │ ├── __init__.py
│ │ │ │ ├── vst.gin
│ │ │ │ ├── vst_32k.gin
│ │ │ │ └── vst_48k.gin
│ │ ├── optimization
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── base.gin
│ │ │ └── base_tpu.gin
│ │ └── papers
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── iclr2020
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── nsynth_ae.gin
│ │ │ │ ├── solo_instrument.gin
│ │ │ │ └── tiny_instrument.gin
│ │ │ └── icml2020
│ │ │ │ ├── README.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── finetune_dataset.gin
│ │ │ │ ├── finetune_model.gin
│ │ │ │ ├── pretrain_dataset.gin
│ │ │ │ └── pretrain_model.gin
│ ├── heuristics.py
│ ├── heuristics_test.py
│ ├── inference.py
│ ├── metrics.py
│ ├── metrics_test.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── autoencoder.py
│ │ ├── autoencoder_test.py
│ │ ├── inverse_synthesis.py
│ │ ├── midi_autoencoder.py
│ │ └── model.py
│ ├── nn.py
│ ├── nn_test.py
│ ├── plotting.py
│ ├── postprocessing.py
│ ├── preprocessing.py
│ ├── preprocessing_test.py
│ ├── summaries.py
│ ├── train_util.py
│ └── trainers.py
└── version.py
├── pylintrc
├── setup.cfg
├── setup.py
├── update_gin_config.py
└── update_pip.sh
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: build
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - name: Set up Python
11 | uses: actions/setup-python@v2
12 | with:
13 | python-version: '3.8'
14 | - name: Install dependencies
15 | # TODO(jesseengel): Remove `legacy-resolver` when pip dependency resolution
16 | # is no longer broken. Currently stalls (backtracking), only on GitHub, not
17 | # locally.
18 | run: |
19 | sudo apt update
20 | sudo apt-get -y install ffmpeg
21 | pip install -U pip
22 | pip install -e .[data_preparation,test] --use-deprecated=legacy-resolver
23 | - name: Test with pytest
24 | run: pytest
25 | - name: Lint with pylint
26 | run: pylint ddsp
27 | # The below step just reports the success or failure of tests as a "commit status".
28 | # This is needed for copybara integration.
29 | - name: Report success or failure as github status
30 | if: always()
31 | shell: bash
32 | run: |
33 | status="${{ job.status }}"
34 | lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
35 | curl -sS --request POST \
36 | --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
37 | --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
38 | --header 'content-type: application/json' \
39 | --data '{
40 | "state": "'$lowercase_status'",
41 | "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
42 | "description": "'$status'",
43 | "context": "github-actions/build"
44 | }'
45 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 | DSP can be subtle to get completely right, so we particularly appreciate the
5 | contributions of those with expertise in signal processing to help fix any
6 | mistakes we may have made 😄.
7 |
8 | # Versioning
9 |
10 | We'll do our best to keep the version updated. This repo contains two code bases,
11 | which makes versioning a bit tricky: the core code base `ddsp/`, and a more
12 | experimental training code base `ddsp/training/` that is used for active
13 | research. We thus adopt the following scheme for incrementing versions:
14 |
15 | `vMajor.Minor.Revision`
16 |
17 | * Major: Breaking change in `ddsp/`
18 | * Minor: New feature in `ddsp/`, breaking change in `training/`
19 | * Revision: New feature in `training/`, minor bug fix anywhere
20 |
21 | ## Code Design Goals
22 | As much as we can, we would like the DDSP library to be approachable,
23 | well-tested, well-documented, and full of useful examples. Thus, PRs that add
24 | new functionality should be accompanied by ample documentation and tests to
25 | help newcomers understand a typical use case, and guard against silent failures
26 | from breaking changes in the future. Please follow the existing doc/testing
27 | style when you can.
28 |
29 | To ensure a consistent style, new code should follow [Google's Python Style Guide](https://google.github.io/styleguide/pyguide.html)
30 | and will need to pass a Google-style linter before acceptance. While this can
31 | add a little work up front, and occasionally make things more verbose, it helps
32 | reduce mental overhead and makes the code more readable.
33 |
34 | ## Code reviews
35 |
36 | All submissions, including submissions by project members, require review. We
37 | use GitHub pull requests for this purpose. Consult
38 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
39 | information on using pull requests.
40 |
41 | Please be sure to test your code by running `pytest` and `pylint` before
42 | submitting a pull request for review. Note that code cannot be merged until
43 | these tests pass on [GitHub Actions](https://github.com/magenta/ddsp/actions?query=workflow%3Abuild).
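The same checks that CI runs (see `.github/workflows/build.yml`) can be
reproduced locally:

```bash
pip install -e .[data_preparation,test]
pytest
pylint ddsp
```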
44 |
45 |
46 | ## Getting Started
47 |
48 | If you're looking for a way to contribute, but not sure where to start, you
49 | could:
50 |
51 | * Add some documentation to an existing function.
52 | * Add a missing test to improve coverage.
53 | * Add type hints to functions in a new file.
54 | * Add a new colab tutorial or demo, covering a typical use case or showing something cool.
55 | * Respond to a bug or feature request in the GitHub [Issues](https://github.com/magenta/ddsp/issues).
56 | * Add a new signal `Processor` and corresponding test (see the sketch below).
57 |
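As a rough illustration of that last item, here is a minimal sketch of a custom
`Processor` (the `Gain` class and its control names are hypothetical, not part
of the library). A `Processor` implements `get_controls()` to constrain raw
network outputs and `get_signal()` to synthesize a signal from those controls:

```python
import ddsp
import tensorflow as tf


class Gain(ddsp.processors.Processor):
  """Hypothetical processor that scales audio by a single gain control."""

  def __init__(self, name='gain'):
    super().__init__(name=name)

  def get_controls(self, audio, gain):
    # Constrain the raw network output to [0, 1] with a sigmoid.
    return {'audio': audio, 'gain': tf.nn.sigmoid(gain)}

  def get_signal(self, audio, gain):
    # Scale the [batch, n_samples] audio by the [batch, 1] gain.
    return audio * gain
```
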
58 | ## Contributor License Agreement
59 |
60 | Contributions to this project must be accompanied by a Contributor License
61 | Agreement. You (or your employer) retain the copyright to your contribution;
62 | this simply gives us permission to use and redistribute your contributions as
63 | part of the project. Head over to <https://cla.developers.google.com/> to see
64 | your current agreements on file or to sign a new one.
65 |
66 | You generally only need to submit a CLA once, so if you've already submitted one
67 | (even if it was for a different project), you probably don't need to do it
68 | again.
69 |
70 |
71 | ## Community Guidelines
72 |
73 | This project follows
74 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
75 |
--------------------------------------------------------------------------------
/ddsp/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Base module for the differentiable digital signal processing library."""
16 |
17 | # Module imports.
18 | from ddsp import core
19 | from ddsp import dags
20 | from ddsp import effects
21 | from ddsp import losses
22 | from ddsp import processors
23 | from ddsp import spectral_ops
24 | from ddsp import synths
25 |
26 | # Version number.
27 | from ddsp.version import __version__
28 |
--------------------------------------------------------------------------------
/ddsp/colab/README.md:
--------------------------------------------------------------------------------
1 | # Colab Notebooks
2 |
3 | Interactive notebooks to demonstrate DDSP.
4 |
5 | * [demos](./demos/): Self-contained demonstrations for training models and showing them in action.
6 |
7 | * [tutorials](./tutorials/): Interactive walkthroughs of the DDSP functions and APIs.
8 |
9 |
--------------------------------------------------------------------------------
/ddsp/colab/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/colab/demos/README.md:
--------------------------------------------------------------------------------
1 | # Demos
2 |
3 | Here are colab notebooks for demonstrating neat things you can do with DDSP.
4 |
5 | * [timbre_transfer](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/timbre_transfer.ipynb):
6 | Convert audio between sound sources with pretrained models. Try turning your voice into a violin, or scratching your laptop and seeing how it sounds as a flute :). Pick from a selection of pretrained models or upload your own that you can train with the `train_autoencoder` demo.
7 |
8 | * [train_autoencoder](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/train_autoencoder.ipynb):
9 | Takes you through all the steps to convert audio files into a dataset and train your own DDSP autoencoder model. You can transfer data and models to/from google drive, and download a .zip file of your trained model to be used with the `timbre_transfer` demo.
10 |
11 | * [pitch_detection](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/pitch_detection.ipynb):
12 | Demonstration of self-supervised pitch detection models from the [2020 ICML Workshop paper](https://openreview.net/forum?id=RlVTYWhsky7).
13 |
14 | * [Train_VST](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/demos/Train_VST.ipynb):
15 | Simplified training colab for the real-time audio plugin (WIP).
16 |
--------------------------------------------------------------------------------
/ddsp/colab/tutorials/README.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 | The best place to start is the step-by-step tutorials for all the major library components.
4 |
5 | * [0_processor](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/0_processor.ipynb):
6 | Introduction to the Processor class.
7 |
8 | * [1_synths_and_effects](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/1_synths_and_effects.ipynb):
9 | Example usage of processors.
10 |
11 | * [2_processor_group](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/2_processor_group.ipynb):
12 | Stringing processors together in a ProcessorGroup.
13 |
14 | * [3_training](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/3_training.ipynb):
15 | Example of training on a single sound.
16 |
17 | * [4_core_functions](https://colab.research.google.com/github/magenta/ddsp/blob/main/ddsp/colab/tutorials/4_core_functions.ipynb):
18 | Extensive examples for most of the core DDSP functions.
19 |
--------------------------------------------------------------------------------
/ddsp/dags.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library of functions and layers of Directed Acyclical Graphs.
16 |
17 | DAGLayer exists as an alternative to manually specifying the forward pass in
18 | python. The advantage is that a variety of configurations can be
19 | programmatically specified via external dependency injection, such as with the
20 | `gin` library.
21 | """
22 |
23 | from typing import Dict, Sequence, Tuple, Text, TypeVar
24 |
25 | from absl import logging
26 | from ddsp import core
27 | import gin
28 | import tensorflow.compat.v2 as tf
29 |
30 | tfkl = tf.keras.layers
31 |
32 | # Define Types.
33 | TensorDict = Dict[Text, tf.Tensor]
34 | KeyOrModule = TypeVar('KeyOrModule', Text, tf.Module)
35 | Node = Tuple[KeyOrModule, Sequence[Text], Sequence[Text]]
36 | DAG = Sequence[Node]
37 |
38 | # Helper Functions for DAGs ---------------------------------------------------
39 | filter_by_value = lambda d, cond: dict(filter(lambda e: cond(e[1]), d.items()))
40 | is_module = lambda v: isinstance(v, tf.Module)
41 |
42 | # Duck typing.
43 | is_loss = lambda v: hasattr(v, 'get_losses_dict')
44 | is_processor = lambda v: hasattr(v, 'get_signal') and hasattr(v, 'get_controls')
45 |
46 |
47 | def split_keras_kwargs(kwargs):
48 | """Strip keras specific kwargs."""
49 | keras_kwargs = {}
50 | for key in ['training', 'mask', 'name']:
51 | if kwargs.get(key) is not None:
52 | keras_kwargs[key] = kwargs.pop(key)
53 | return keras_kwargs, kwargs
54 |
55 |
56 | # DAG and ProcessorGroup Classes -----------------------------------------------
57 | @gin.register
58 | class DAGLayer(tfkl.Layer):
59 | """String modules together."""
60 |
61 | def __init__(self, dag: DAG, **kwarg_modules):
62 | """Constructor.
63 |
64 | Args:
65 | dag: A directed acyclical graph in the form of a list of nodes. Each node
66 | has the form
67 |
68 | ['module', ['input_key', ...], ['output_key', ...]]
69 |
70 | 'module': Module instance or string name of module. For example,
71 | 'encoder' would access the attribute `dag_layer.encoder`.
72 | 'input_key': List of strings, nested keys in dictionary of dag outputs.
73 | For example, 'inputs/f0_hz' would access `outputs[inputs]['f0_hz']`.
74 | Inputs to the dag are wrapped in an `inputs` dict as shown in the
75 | example. This list is ordered and has one key per module input
76 | argument. Each node's outputs are prefixed by their module name.
77 | 'output_key': List of strings, keys for each return value of the module.
78 | For example, ['amps', 'freqs'] would have the module return a dict
79 | {'module_name': {'amps': return_value_0, 'freqs': return_value_1}}.
80 | If the module returns a dictionary, the keys of the dictionary will be
81 | used and these values (if provided) will be ignored.
82 |
83 | The graph is read sequentially and must be topologically sorted. This
84 | means that all inputs for a module must already be generated by earlier
85 | modules (or in the input dictionary).
86 | **kwarg_modules: A series of modules to add to DAGLayer. Each kwarg that
87 | is a tf.Module will be added as a property of the layer, so that it will
88 | be accessible as `dag_layer.kwarg`. Also, other keras kwargs such as
89 | 'name' are split off before adding modules.
90 | """
91 | keras_kwargs, kwarg_modules = split_keras_kwargs(kwarg_modules)
92 | super().__init__(**keras_kwargs)
93 |
94 | # Create properties/submodules from other kwargs.
95 | modules = filter_by_value(kwarg_modules, is_module)
96 |
97 | # Remove modules from the dag, make properties of dag_layer.
98 | dag, dag_modules = self.format_dag(dag)
99 | # DAG is now just strings.
100 | self.dag = dag
101 | modules.update(dag_modules)
102 |
103 | # Make properties of DAGLayer to keep track of variables in checkpoints.
104 | self.module_names = list(modules.keys())
105 | for module_name, module in modules.items():
106 | setattr(self, module_name, module)
107 |
108 | @property
109 | def modules(self):
110 | """Module getter."""
111 | return [getattr(self, name) for name in self.module_names]
112 |
113 | @staticmethod
114 | def format_dag(dag):
115 | """Remove modules from dag, and replace with module names."""
116 | modules = {}
117 | dag = list(dag) # Make mutable in case it's a tuple.
118 | for i, node in enumerate(dag):
119 | node = list(node) # Make mutable in case it's a tuple.
120 | module = node[0]
121 | if is_module(module):
122 | # Strip module from the dag.
123 | modules[module.name] = module
124 | # Replace with module name.
125 | node[0] = module.name
126 | dag[i] = node
127 | return dag, modules
128 |
129 | def call(self, inputs: TensorDict, **kwargs) -> TensorDict:
130 | """Run dag for an input dictionary."""
131 | return self.run_dag(inputs, **kwargs)
132 |
133 | @gin.configurable(allowlist=['verbose']) # For debugging.
134 | def run_dag(self,
135 | inputs: TensorDict,
136 | verbose: bool = False,
137 | **kwargs) -> TensorDict:
138 | """Connects and runs submodules of dag.
139 |
140 | Args:
141 | inputs: A dictionary of input tensors fed to the dag.
142 | verbose: Print out dag routing when running.
143 | **kwargs: Other kwargs to pass to submodules, such as keras kwargs.
144 |
145 | Returns:
146 | A nested dictionary of all the output tensors.
147 | """
148 | # Initialize the outputs with inputs to the dag.
149 | outputs = {'inputs': inputs}
150 | # TODO(jesseengel): Remove this cluttering of the base namespace. Only there
151 | # for backwards compatibility.
152 | outputs.update(inputs)
153 |
154 | # Run through the DAG nodes in sequential order.
155 | for node in self.dag:
156 | # The first element of the node can be either a module or module_key.
157 | module_key, input_keys = node[0], node[1]
158 | module = getattr(self, module_key)
159 | # Optionally specify output keys if module does not return dict.
160 | output_keys = node[2] if len(node) > 2 else None
161 |
162 | # Get the inputs to the node.
163 | inputs = [core.nested_lookup(key, outputs) for key in input_keys]
164 |
165 | if verbose:
166 | shape = lambda d: tf.nest.map_structure(lambda x: list(x.shape), d)
167 | logging.info('Input to Module: %s\nKeys: %s\nIn: %s\n',
168 | module_key, input_keys, shape(inputs))
169 |
170 | # Duck typing to avoid dealing with multiple inheritance of Group modules.
171 | if is_processor(module):
172 | # Processor modules.
173 | module_outputs = module(*inputs, return_outputs_dict=True, **kwargs)
174 | elif is_loss(module):
175 | # Loss modules.
176 | module_outputs = module.get_losses_dict(*inputs, **kwargs)
177 | else:
178 | # Network modules.
179 | module_outputs = module(*inputs, **kwargs)
180 |
181 | if not isinstance(module_outputs, dict):
182 | module_outputs = core.to_dict(module_outputs, output_keys)
183 |
184 | if verbose:
185 | logging.info('Output from Module: %s\nOut: %s\n',
186 | module_key, shape(module_outputs))
187 |
188 | # Add module outputs to the dictionary.
189 | outputs[module_key] = module_outputs
190 |
191 | # Alias final module output as dag output.
192 | # 'out' is a reserved key for final dag output.
193 | outputs['out'] = module_outputs
194 |
195 | return outputs
196 |
--------------------------------------------------------------------------------
/ddsp/dags_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.dags.py."""
16 |
17 | from absl.testing import parameterized
18 | from ddsp import dags
19 | import gin
20 | import tensorflow as tf
21 |
22 | # Make dense layers configurable for this test.
23 | gin.external_configurable(tf.keras.layers.Dense, 'tf.keras.layers.Dense')
24 |
25 |
26 | @gin.configurable
27 | class ConfigurableDAGLayer(dags.DAGLayer):
28 | """Configurable wrapper DAGLayer encapsulated for this test."""
29 | pass
30 |
31 |
32 | class DAGLayerTest(parameterized.TestCase, tf.test.TestCase):
33 |
34 | def setUp(self):
35 | """Create some dummy input data for the chain."""
36 | super().setUp()
37 | # Create inputs.
38 | self.n_batch = 4
39 | self.x_dims = 5
40 | self.z_dims = 2
41 | self.x = tf.ones([self.n_batch, self.x_dims])
42 | self.inputs = {'test_data': self.x}
43 | self.gin_config_kwarg_modules = f"""
44 | import ddsp
45 |
46 | ### Modules
47 | ConfigurableDAGLayer.dag = [
48 | ('encoder', ['inputs/test_data'], ['z']),
49 | ('bottleneck', ['encoder/z'], ['z_bottleneck']),
50 | ('decoder', ['bottleneck/z_bottleneck'], ['reconstruction']),
51 | ]
52 | ConfigurableDAGLayer.encoder = @encoder/layers.Dense()
53 | encoder/layers.Dense.units = {self.x_dims}
54 |
55 | ConfigurableDAGLayer.bottleneck = @bottleneck/layers.Dense()
56 | bottleneck/layers.Dense.units = {self.z_dims}
57 |
58 | ConfigurableDAGLayer.decoder = @decoder/layers.Dense()
59 | decoder/layers.Dense.units = {self.x_dims}
60 | """
61 | self.gin_config_dag_modules = f"""
62 | import ddsp
63 |
64 | ### Modules
65 | ConfigurableDAGLayer.dag = [
66 | (@encoder/layers.Dense(), ['inputs/test_data'], ['z']),
67 | (@bottleneck/layers.Dense(), ['encoder/z'], ['z_bottleneck']),
68 | (@decoder/layers.Dense(), ['bottleneck/z_bottleneck'], ['reconstruction']),
69 | ]
70 | encoder/layers.Dense.name = 'encoder'
71 | encoder/layers.Dense.units = {self.x_dims}
72 |
73 | bottleneck/layers.Dense.name = 'bottleneck'
74 | bottleneck/layers.Dense.units = {self.z_dims}
75 |
76 | decoder/layers.Dense.name = 'decoder'
77 | decoder/layers.Dense.units = {self.x_dims}
78 | """
79 |
80 | @parameterized.named_parameters(
81 | ('kwarg_modules', True),
82 | ('dag_modules', False),
83 | )
84 | def test_build_layer(self, kwarg_modules):
85 | """Tests if layer builds properly and produces outputs of correct shape."""
86 | gin_config = (self.gin_config_kwarg_modules if kwarg_modules else
87 | self.gin_config_dag_modules)
88 | with gin.unlock_config():
89 | gin.clear_config()
90 | gin.parse_config(gin_config)
91 |
92 | dag_layer = ConfigurableDAGLayer()
93 | outputs = dag_layer(self.inputs)
94 | self.assertIsInstance(outputs, dict)
95 |
96 | z = outputs['bottleneck']['z_bottleneck']
97 | x_rec = outputs['decoder']['reconstruction']
98 | x_rec2 = outputs['out']['reconstruction']
99 |
100 | # Confirm that layer generates correctly sized tensors.
101 | self.assertEqual(outputs['test_data'].shape, self.x.shape)
102 | self.assertEqual(outputs['inputs']['test_data'].shape, self.x.shape)
103 | self.assertEqual(x_rec.shape, self.x.shape)
104 | self.assertEqual(z.shape[-1], self.z_dims)
105 | self.assertAllClose(x_rec, x_rec2)
106 |
107 | # Confirm that variables are inherited by DAGLayer.
108 | self.assertLen(dag_layer.trainable_variables, 6) # 3 weights, 3 biases.
109 |
110 | if __name__ == '__main__':
111 | tf.test.main()
112 |
--------------------------------------------------------------------------------
/ddsp/effects_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.effects."""
16 |
17 | from absl.testing import parameterized
18 | from ddsp import effects
19 | import tensorflow.compat.v2 as tf
20 |
21 |
22 | class ReverbTest(parameterized.TestCase, tf.test.TestCase):
23 |
24 | def setUp(self):
25 | """Creates some test specific attributes."""
26 | super().setUp()
27 | self.reverb_class = effects.Reverb
28 | self.audio = tf.zeros((3, 16000))
29 | self.construct_args = {'reverb_length': 100}
30 | self.call_args = {'ir': tf.zeros((3, 100, 1))}
31 | self.controls_keys = ['audio', 'ir']
32 |
33 | @parameterized.named_parameters(
34 | ('trainable', True),
35 | ('not_trainable', False),
36 | )
37 | def test_output_shape_and_variables_are_correct(self, trainable):
38 | reverb = self.reverb_class(trainable=trainable, **self.construct_args)
39 | if trainable:
40 | output = reverb(self.audio)
41 | else:
42 | output = reverb(self.audio, **self.call_args)
43 |
44 | self.assertListEqual(list(self.audio.shape), output.shape.as_list())
45 | self.assertEqual(reverb.trainable, trainable)
46 | self.assertEmpty(reverb.non_trainable_variables)
47 | assert_variables = self.assertNotEmpty if trainable else self.assertEmpty
48 | assert_variables(reverb.trainable_variables)
49 |
50 | def test_non_trainable_raises_value_error(self):
51 | reverb = self.reverb_class(trainable=False, **self.construct_args)
52 | with self.assertRaises(ValueError):
53 | _ = reverb(self.audio)
54 |
55 | @parameterized.named_parameters(
56 | ('trainable', True),
57 | ('not_trainable', False),
58 | )
59 | def test_get_controls_returns_correct_keys(self, trainable):
60 | reverb = self.reverb_class(trainable=trainable, **self.construct_args)
61 | reverb.build(self.audio.shape)
62 | if trainable:
63 | controls = reverb.get_controls(self.audio)
64 | else:
65 | controls = reverb.get_controls(self.audio, **self.call_args)
66 |
67 | self.assertListEqual(list(controls.keys()), self.controls_keys)
68 |
69 |
70 | class ExpDecayReverbTest(ReverbTest):
71 |
72 | def setUp(self):
73 | """Creates some test specific attributes."""
74 | super().setUp()
75 | self.reverb_class = effects.ExpDecayReverb
76 | self.audio = tf.zeros((3, 16000))
77 | self.construct_args = {'reverb_length': 100}
78 | self.call_args = {'gain': tf.zeros((3, 1)),
79 | 'decay': tf.zeros((3, 1))}
80 |
81 |
82 | class FilteredNoiseReverbTest(ReverbTest):
83 |
84 | def setUp(self):
85 | """Creates some test specific attributes."""
86 | super().setUp()
87 | self.reverb_class = effects.FilteredNoiseReverb
88 | self.audio = tf.zeros((3, 16000))
89 | self.construct_args = {'reverb_length': 100,
90 | 'n_frames': 10,
91 | 'n_filter_banks': 20}
92 | self.call_args = {'magnitudes': tf.zeros((3, 10, 20))}
93 |
94 |
95 | class FIRFilterTest(tf.test.TestCase):
96 |
97 | def test_output_shape_is_correct(self):
98 | processor = effects.FIRFilter()
99 |
100 | audio = tf.zeros((3, 16000))
101 | magnitudes = tf.zeros((3, 100, 30))
102 | output = processor(audio, magnitudes)
103 |
104 | self.assertListEqual([3, 16000], output.shape.as_list())
105 |
106 |
107 | class ModDelayTest(tf.test.TestCase):
108 |
109 | def test_output_shape_is_correct(self):
110 | processor = effects.ModDelay()
111 |
112 | audio = tf.zeros((3, 16000))
113 | gain = tf.zeros((3, 16000, 1))
114 | phase = tf.zeros((3, 16000, 1))
115 | output = processor(audio, gain, phase)
116 |
117 | self.assertListEqual([3, 16000], output.shape.as_list())
118 |
119 |
120 | if __name__ == '__main__':
121 | tf.test.main()
122 |
--------------------------------------------------------------------------------
/ddsp/losses_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.losses."""
16 |
17 | from ddsp import core
18 | from ddsp import losses
19 | import numpy as np
20 | import tensorflow as tf
21 |
22 |
23 | class LossGroupTest(tf.test.TestCase):
24 |
25 | def setUp(self):
26 | """Create some dummy input data for the chain."""
27 | super().setUp()
28 |
29 | # Create a network output dictionary.
30 | self.nn_outputs = {
31 | 'audio': tf.ones((3, 8000), dtype=tf.float32),
32 | 'audio_synth': tf.ones((3, 8000), dtype=tf.float32),
33 | 'magnitudes': tf.ones((3, 200, 2), dtype=tf.float32),
34 | 'f0_hz': 200 + tf.ones((3, 200, 1), dtype=tf.float32),
35 | }
36 |
37 | # Create Processors.
38 | spectral_loss = losses.SpectralLoss()
39 | crepe_loss = losses.PretrainedCREPEEmbeddingLoss(name='crepe_loss')
40 |
41 | # Create DAG for testing.
42 | self.dag = [
43 | (spectral_loss, ['audio', 'audio_synth']),
44 | (crepe_loss, ['audio', 'audio_synth']),
45 | ]
46 | self.expected_outputs = [
47 | 'spectral_loss',
48 | 'crepe_loss'
49 | ]
50 |
51 | def _check_tensor_outputs(self, strings_to_check, outputs):
52 | for tensor_string in strings_to_check:
53 | tensor = core.nested_lookup(tensor_string, outputs)
54 | self.assertIsInstance(tensor, (np.ndarray, tf.Tensor))
55 |
56 | def test_dag_construction(self):
57 | """Tests if DAG is built properly and runs.
58 | """
59 | loss_group = losses.LossGroup(dag=self.dag)
61 | loss_outputs = loss_group(self.nn_outputs)
62 | self.assertIsInstance(loss_outputs, dict)
63 | self._check_tensor_outputs(self.expected_outputs, loss_outputs)
64 |
65 |
66 | class SpectralLossTest(tf.test.TestCase):
67 |
68 | def test_output_shape_is_correct(self):
69 | """Test correct shape with all losses active."""
70 | loss_obj = losses.SpectralLoss(
71 | mag_weight=1.0,
72 | delta_time_weight=1.0,
73 | delta_freq_weight=1.0,
74 | cumsum_freq_weight=1.0,
75 | logmag_weight=1.0,
76 | loudness_weight=1.0,
77 | )
78 |
79 | input_audio = tf.ones((3, 8000), dtype=tf.float32)
80 | target_audio = tf.ones((3, 8000), dtype=tf.float32)
81 |
82 | loss = loss_obj(input_audio, target_audio)
83 |
84 | self.assertListEqual([], loss.shape.as_list())
85 | self.assertTrue(np.isfinite(loss))
86 |
87 |
90 | class PretrainedCREPEEmbeddingLossTest(tf.test.TestCase):
91 |
92 | def test_output_shape_is_correct(self):
93 | loss_obj = losses.PretrainedCREPEEmbeddingLoss()
94 |
95 | input_audio = tf.ones((3, 16000), dtype=tf.float32)
96 | target_audio = tf.ones((3, 16000), dtype=tf.float32)
97 |
98 | loss = loss_obj(input_audio, target_audio)
99 |
100 | self.assertListEqual([], loss.shape.as_list())
101 | self.assertTrue(np.isfinite(loss))
102 |
103 |
104 | if __name__ == '__main__':
105 | tf.test.main()
106 |
--------------------------------------------------------------------------------
/ddsp/processors_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.processors."""
16 |
17 | from absl.testing import parameterized
18 | from ddsp import core
19 | from ddsp import effects
20 | from ddsp import processors
21 | from ddsp import synths
22 | import numpy as np
23 | import tensorflow.compat.v2 as tf
24 |
25 |
26 | class ProcessorGroupTest(parameterized.TestCase, tf.test.TestCase):
27 |
28 | def setUp(self):
29 | """Create some dummy input data for the chain."""
30 | super().setUp()
31 | # Create inputs.
32 | self.n_batch = 4
33 | self.n_frames = 1000
34 | self.n_time = 64000
35 | rand_signal = lambda ch: np.random.randn(self.n_batch, self.n_frames, ch)
36 | self.nn_outputs = {
37 | 'amps': rand_signal(1),
38 | 'harmonic_distribution': rand_signal(99),
39 | 'magnitudes': rand_signal(256),
40 | 'f0_hz': 200 + rand_signal(1),
41 | 'target_audio': np.random.randn(self.n_batch, self.n_time)
42 | }
43 |
44 | # Create Processors.
45 | harmonic = synths.Harmonic(name='harmonic')
46 | noise = synths.FilteredNoise(name='noise')
47 | add = processors.Add(name='add')
48 | reverb = effects.Reverb(trainable=True, name='reverb')
49 |
50 | # Create DAG for testing.
51 | self.dag = [
52 | (harmonic, ['amps', 'harmonic_distribution', 'f0_hz']),
53 | (noise, ['magnitudes']),
54 | (add, ['noise/signal', 'harmonic/signal']),
55 | (reverb, ['add/signal']),
56 | ]
57 | self.expected_outputs = [
58 | 'amps',
59 | 'harmonic_distribution',
60 | 'magnitudes',
61 | 'f0_hz',
62 | 'target_audio',
63 | 'harmonic/signal',
64 | 'harmonic/controls/amplitudes',
65 | 'harmonic/controls/harmonic_distribution',
66 | 'harmonic/controls/f0_hz',
67 | 'noise/signal',
68 | 'noise/controls/magnitudes',
69 | 'add/signal',
70 | 'reverb/signal',
71 | 'reverb/controls/ir',
72 | 'out/signal',
73 | ]
74 |
75 | def _check_tensor_outputs(self, strings_to_check, outputs):
76 | for tensor_string in strings_to_check:
77 | tensor = core.nested_lookup(tensor_string, outputs)
78 | self.assertIsInstance(tensor, (np.ndarray, tf.Tensor))
79 |
80 | def test_dag_construction(self):
81 | """Tests if DAG is built properly and runs.
82 | """
83 | processor_group = processors.ProcessorGroup(dag=self.dag,
84 | name='processor_group')
85 | outputs = processor_group.get_controls(self.nn_outputs)
86 | self.assertIsInstance(outputs, dict)
87 | self._check_tensor_outputs(self.expected_outputs, outputs)
88 |
89 |
90 | class AddTest(tf.test.TestCase):
91 |
92 | def test_output_is_correct(self):
93 | processor = processors.Add(name='add')
94 | x = tf.zeros((2, 3), dtype=tf.float32) + 1.0
95 | y = tf.zeros((2, 3), dtype=tf.float32) + 2.0
96 |
97 | output = processor(x, y)
98 |
99 | expected = np.zeros((2, 3), dtype=np.float32) + 3.0
100 | self.assertAllEqual(expected, output)
101 |
102 |
103 | class MixTest(tf.test.TestCase):
104 |
105 | def test_output_shape_is_correct(self):
106 | processor = processors.Mix(name='mix')
107 | x1 = np.zeros((2, 100, 3), dtype=np.float32) + 1.0
108 | x2 = np.zeros((2, 100, 3), dtype=np.float32) + 2.0
109 | mix_level = np.zeros(
110 | (2, 100, 1), dtype=np.float32) + 0.1 # will be passed to sigmoid
111 |
112 | output = processor(x1, x2, mix_level)
113 |
114 | self.assertListEqual([2, 100, 3], output.shape.as_list())
115 |
116 |
117 | if __name__ == '__main__':
118 | tf.test.main()
119 |
--------------------------------------------------------------------------------
/ddsp/synths_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.synths."""
16 |
17 | from ddsp import core
18 | from ddsp import synths
19 | import numpy as np
20 | import tensorflow.compat.v2 as tf
21 |
22 |
23 | class HarmonicTest(tf.test.TestCase):
24 |
25 | def test_output_shape_is_correct(self):
26 | synthesizer = synths.Harmonic(
27 | n_samples=64000,
28 | sample_rate=16000,
29 | scale_fn=None,
30 | normalize_below_nyquist=True)
31 | batch_size = 3
32 | num_frames = 1000
33 | amp = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 1.0
34 | harmonic_distribution = tf.zeros(
35 | (batch_size, num_frames, 16), dtype=tf.float32) + 1.0 / 16
36 | f0_hz = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 16000
37 |
38 | output = synthesizer(amp, harmonic_distribution, f0_hz)
39 |
40 | self.assertAllEqual([batch_size, 64000], output.shape.as_list())
41 |
42 |
43 | class FilteredNoiseTest(tf.test.TestCase):
44 |
45 | def test_output_shape_is_correct(self):
46 | synthesizer = synths.FilteredNoise(n_samples=16000)
47 | filter_bank_magnitudes = tf.zeros((3, 16000, 100), dtype=tf.float32) + 3.0
48 | output = synthesizer(filter_bank_magnitudes)
49 |
50 | self.assertAllEqual([3, 16000], output.shape.as_list())
51 |
52 |
53 | class WavetableTest(tf.test.TestCase):
54 |
55 | def test_output_shape_is_correct(self):
56 | synthesizer = synths.Wavetable(
57 | n_samples=64000,
58 | sample_rate=16000,
59 | scale_fn=None)
60 | batch_size = 3
61 | num_frames = 1000
62 | n_wavetable = 1024
63 | amp = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 1.0
64 | wavetables = tf.zeros(
65 | (batch_size, num_frames, n_wavetable), dtype=tf.float32)
66 | f0_hz = tf.zeros((batch_size, num_frames, 1), dtype=tf.float32) + 440
67 |
68 | output = synthesizer(amp, wavetables, f0_hz)
69 |
70 | self.assertAllEqual([batch_size, 64000], output.shape.as_list())
71 |
72 |
73 | class SinusoidalTest(tf.test.TestCase):
74 |
75 | def test_output_shape_is_correct(self):
76 | synthesizer = synths.Sinusoidal(n_samples=32000, sample_rate=16000)
77 | batch_size = 3
78 | num_frames = 1000
79 | n_partials = 10
80 | amps = tf.zeros((batch_size, num_frames, n_partials),
81 | dtype=tf.float32)
82 | freqs = tf.zeros((batch_size, num_frames, n_partials),
83 | dtype=tf.float32)
84 |
85 | output = synthesizer(amps, freqs)
86 |
87 | self.assertAllEqual([batch_size, 32000], output.shape.as_list())
88 |
89 | def test_frequencies_controls_are_bounded(self):
90 | depth = 10
91 | def freq_scale_fn(x):
92 | return core.frequencies_sigmoid(x, depth=depth, hz_min=0.0, hz_max=8000.0)
93 |
94 | synthesizer = synths.Sinusoidal(
95 | n_samples=32000, sample_rate=16000, freq_scale_fn=freq_scale_fn)
96 | batch_size = 3
97 | num_frames = 10
98 | n_partials = 100
99 | amps = tf.zeros((batch_size, num_frames, n_partials), dtype=tf.float32)
100 | freqs = tf.linspace(-100.0, 100.0, n_partials)
101 | freqs = tf.tile(freqs[tf.newaxis, tf.newaxis, :, tf.newaxis],
102 | [batch_size, num_frames, 1, depth])
103 |
104 | controls = synthesizer.get_controls(amps, freqs)
105 | freqs = controls['frequencies']
106 | lt_nyquist = (freqs <= 8000.0)
107 | gt_zero = (freqs >= 0.0)
108 | both_conditions = np.logical_and(lt_nyquist, gt_zero)
109 |
110 | self.assertTrue(np.all(both_conditions))
111 |
112 |
113 | if __name__ == '__main__':
114 | tf.test.main()
115 |
--------------------------------------------------------------------------------
/ddsp/test_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library of helper functions for testing."""
16 |
17 | import numpy as np
18 |
19 |
20 | def gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec):
21 | x = np.linspace(0, audio_len_sec, int(audio_len_sec * sample_rate))
22 | audio_sin = amp * (np.sin(2 * np.pi * frequency * x))
23 | return audio_sin
24 |
25 |
26 | def gen_np_batched_sinusoids(frequency, amp, sample_rate, audio_len_sec,
27 | batch_size):
28 | batch_sinusoids = [
29 | gen_np_sinusoid(frequency, amp, sample_rate, audio_len_sec)
30 | for _ in range(batch_size)
31 | ]
32 | return np.array(batch_sinusoids)
33 |
34 |
--------------------------------------------------------------------------------
/ddsp/training/README.md:
--------------------------------------------------------------------------------
1 | # DDSP Training
2 |
3 |
4 | This directory contains the code for training models using DDSP modules.
5 | The current supported models are variants of an audio autoencoder.
6 |
7 |
8 |
9 |
10 |
11 | ## Disclaimer
12 | *Unlike the base `ddsp/` library, this folder is actively modified for new
13 | experiments and has a higher chance of making breaking changes in the future.*
14 |
15 | _Functions and classes marked **EXPERIMENTAL** in their doc string are under active development and very likely to change. They should not be expected to be maintained in their current state._
16 |
17 | ## Modules
18 |
19 | The DDSP training libraries are separated into several modules:
20 |
21 | * [data](./data.py):
22 | DataProvider objects provide tf.data.Dataset.
23 | * [models](./models/):
24 | Model objects define the full forward pass and losses.
25 | * [preprocessing](./preprocessing.py):
26 | Preprocessor objects format and scale model inputs.
27 | * [encoders](./encoders.py):
28 | Layers to turn preprocessor outputs into latents.
29 | * [decoders](./decoders.py):
30 | Layers to turn latents into ddsp processor inputs.
31 | * [nn](./nn.py):
32 | Network functions and layers.
33 | * [inference](./inference.py):
34 | Model wrappers for efficient inference and the ability to store as
35 | SavedModels.
36 |
37 |
38 | The main training file is `ddsp_run.py` and its helper libraries:
39 |
40 | * [ddsp_run](./ddsp_run.py):
41 | Main file for launching training, evaluation, and sampling runs.
42 | * [train_util](./train_util.py):
43 | Training loop and helper functions.
44 | * [trainers](./trainers.py):
45 | Training step defined by helper objects that bind the strategy, optimizer, and model.
46 | * [eval_util](./eval_util.py):
47 | Evaluation and sampling loop.
48 | * [evaluators](./evaluators.py):
49 | Evaluator objects responsible for computing metrics and summaries.
50 | * [metrics](./metrics.py):
51 | Metrics for evaluation.
52 | * [summaries](./summaries.py):
53 | Summaries for tensorboard images and audio of samples.
54 |
55 | While the modules in the `ddsp/` base directory can be used to train models
56 | with `tf.compat.v1` or `tf.compat.v2`, this directory only uses `tf.compat.v2`.
57 |
58 | ## Quickstart
59 |
60 | The [pip installation](../../README.md#installation) includes several scripts that can be called directly from
61 | the command line.
62 |
63 | Hyperparameters are configured via gin, and `ddsp_run.py` must be given three
64 | `--gin_file` flags: one from `gin/models`, one from `gin/datasets`, and one from `gin/eval`. The
65 | files in `gin/papers` combine the dataset, model, and evaluation files for reproducing experiments from a specific paper.
66 |
67 | By default, the program searches for gin files in the installed `ddsp/training/gin` location, but additional search paths can be added with `--gin_search_path`
68 | flags. Individual parameters can also be set with multiple `--gin_param` flags.
69 |
70 | The example below streams a version of the NSynth dataset from GCS.
71 | If not running on GCP, it is much faster to first download the dataset with
72 | [tensorflow_datasets](https://www.tensorflow.org/datasets), and add the flag
73 | `--gin_param="NSynthTfds.data_dir='/path/to/tfds/dir'"`:
74 |
75 | ### Train
76 | ```bash
77 | ddsp_run \
78 | --mode=train \
79 | --save_dir=/tmp/$USER-ddsp-0 \
80 | --gin_file=papers/iclr2020/nsynth_ae.gin \
81 | --gin_param="batch_size=16" \
82 | --alsologtostderr
83 | ```
84 |
85 | ### Evaluate
86 | ```bash
87 | ddsp_run \
88 | --mode=eval \
89 | --save_dir=/tmp/$USER-ddsp-0 \
90 | --gin_file=datasets/nsynth.gin \
91 | --gin_file=eval/basic_f0_ld.gin \
92 | --alsologtostderr
93 | ```
94 |
95 | ### Sample
96 | ```bash
97 | ddsp_run \
98 | --mode=sample \
99 | --save_dir=/tmp/$USER-ddsp-0 \
100 | --gin_file=datasets/nsynth.gin \
101 | --gin_file=eval/basic_f0_ld.gin \
102 | --alsologtostderr
103 | ```
104 |
105 | When training, all gin parameters in the
106 | [operative configuration](https://github.com/google/gin-config/blob/master/docs/index.md#retrieving-operative-parameter-values)
107 | will be saved to the `${MODEL_DIR}/operative_config-0.gin` file, which is then loaded for evaluation, sampling, or further training. The operative config is also visible as a text summary in tensorboard. See
108 | [this doc](https://github.com/google/gin-config/blob/master/docs/index.md#saving-gins-operative-config-to-a-file-and-tensorboard)
109 | for more details.
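For example, to re-create the exact training setup later (say, in a notebook),
the saved operative config can be parsed back in with gin. A minimal sketch,
assuming a placeholder save directory:

```python
import gin

# Load every hyperparameter exactly as it was used during training.
with gin.unlock_config():
  gin.parse_config_file('/path/to/save_dir/operative_config-0.gin')
```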
110 |
111 | ### Backwards compatibility
112 |
113 |
114 | For backwards compatibility, we keep track of changes in function signatures in `update_gin_config.py`, which can be used to update old operative configs to work with the current library.
115 |
116 |
117 | ### Using Cloud TPU
118 |
119 | To use a [Cloud TPU](https://cloud.google.com/tpu/) for any of the above commands, there are a few minor changes.
120 |
121 | First, your model directory will need to be accessible to the TPU. This means it will need to be located in a [GCS bucket with proper permissions](https://cloud.google.com/tpu/docs/storage-buckets).
122 |
123 | Second, you will need to add the following flag:
124 |
125 | ```
126 | --tpu=grpc://<TPU internal IP address>:8470 \
127 | ```
128 |
129 | The TPU internal IP address can be found in the Cloud Console.
130 |
131 |
132 | ## Training a model on your own data
133 | ### Prepare dataset
134 | Make a TFRecord dataset out of a folder of .wav or .mp3 files:
135 |
136 | ```bash
137 | ddsp_prepare_tfrecord \
138 | --input_audio_filepatterns=/path/to/wavs/*wav \
139 | --output_tfrecord_path=/path/to/dataset_name.tfrecord \
140 | --num_shards=10 \
141 | --alsologtostderr
142 | ```
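Before training, you can sanity-check the new dataset. A minimal sketch,
assuming the default `TFRecordProvider` settings and the placeholder output
path above:

```python
from ddsp.training import data

# Point a provider at the shards written by ddsp_prepare_tfrecord.
data_provider = data.TFRecordProvider('/path/to/dataset_name.tfrecord*')

# Pull one batch and inspect the feature shapes (audio, f0, loudness, etc.).
dataset = data_provider.get_batch(batch_size=1, shuffle=False)
batch = next(iter(dataset))
print({k: v.shape for k, v in batch.items()})
```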
143 |
144 | ### Train
145 | ```bash
146 | ddsp_run \
147 | --mode=train \
148 | --save_dir=/tmp/$USER-ddsp-0 \
149 | --gin_file=models/solo_instrument.gin \
150 | --gin_file=datasets/tfrecord.gin \
151 | --gin_file=eval/basic_f0_ld.gin \
152 | --gin_param="TFRecordProvider.file_pattern='/path/to/dataset_name.tfrecord*'" \
153 | --gin_param="batch_size=16" \
154 | --alsologtostderr
155 | ```
156 |
157 |
158 |
159 |
--------------------------------------------------------------------------------
/ddsp/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Training code for DDSP models."""
16 |
17 | from ddsp.training import cloud
18 | from ddsp.training import data
19 | from ddsp.training import decoders
20 | from ddsp.training import encoders
21 | from ddsp.training import eval_util
22 | from ddsp.training import evaluators
23 | from ddsp.training import inference
24 | from ddsp.training import metrics
25 | from ddsp.training import models
26 | from ddsp.training import nn
27 | from ddsp.training import plotting
28 | from ddsp.training import postprocessing
29 | from ddsp.training import preprocessing
30 | from ddsp.training import summaries
31 | from ddsp.training import train_util
32 | from ddsp.training import trainers
33 |
--------------------------------------------------------------------------------
/ddsp/training/cloud.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library of functions for training on Google Cloud AI-Platform."""
16 |
17 | import os
18 | import re
19 |
20 | from absl import logging
21 | from google.cloud import storage
22 | import hypertune
23 |
24 |
25 | def download_from_gstorage(gstorage_path, local_path):
26 | """Downloads a file from the bucket.
27 |
28 | Args:
29 | gstorage_path: Path to the file inside the bucket that needs to be
30 | downloaded. Format: gs://bucket-name/path/to/file.txt
31 | local_path: Local path where downloaded file should be stored.
32 | """
33 | gstorage_path = gstorage_path[len('gs://'):]  # strip('gs:/') would also eat trailing chars.
34 | bucket_name = gstorage_path.split('/')[0]
35 | blob_name = os.path.relpath(gstorage_path, bucket_name)
36 |
37 | storage_client = storage.Client()
38 |
39 | bucket = storage_client.bucket(bucket_name)
40 | blob = bucket.blob(blob_name)
41 |
42 | blob.download_to_filename(local_path)
43 | logging.info(
44 | 'Downloaded file. Source: %s, Destination: %s',
45 | gstorage_path, local_path)
46 |
47 |
48 | def make_file_paths_local(paths, local_directory):
49 | """Makes sure that given files are locally available.
50 |
51 | If a Cloud Storage path is provided, downloads the file and returns the new
52 | path relative to local_directory. If a local path is provided, it is returned
53 | with no modification.
54 |
55 | Args:
56 | paths: Single path or a list of paths.
57 | local_directory: Local path to the directory where downloaded files will be
58 | stored; if downloading gin configuration files, this directory should also be on the gin search path.
59 |
60 | Returns:
61 | Single local path or a list of local paths.
62 | """
63 | if isinstance(paths, str):
64 | if re.match('gs://*', paths):
65 | local_name = os.path.basename(paths)
66 | download_from_gstorage(paths, os.path.join(local_directory, local_name))
67 | return local_name
68 | else:
69 | return paths
70 | else:
71 | local_paths = []
72 | for path in paths:
73 | if re.match('gs://*', path):
74 | local_name = os.path.basename(path)
75 | download_from_gstorage(path, os.path.join(local_directory, local_name))
76 | local_paths.append(local_name)
77 | else:
78 | local_paths.append(path)
79 | return local_paths
80 |
81 |
82 | def report_metric_to_hypertune(metric_value, step, tag='Loss'):
83 | """Use hypertune to report metrics for hyperparameter tuning."""
84 | hpt = hypertune.HyperTune()
85 | hpt.report_hyperparameter_tuning_metric(
86 | hyperparameter_metric_tag=tag,
87 | metric_value=metric_value,
88 | global_step=step)
89 |
--------------------------------------------------------------------------------
/ddsp/training/cloud_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.training.cloud."""
16 |
17 | from unittest import mock
18 |
19 | from ddsp.training import cloud
20 | import tensorflow.compat.v2 as tf
21 |
22 |
23 | class MakeFilePathsLocalTest(tf.test.TestCase):
24 |
25 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True)
26 | def test_single_path_handling(self, download_from_gstorage_function):
27 | """Tests that function returns a single value if given single value."""
28 | path = cloud.make_file_paths_local(
29 | 'gs://bucket-name/bucket/dir/some_file.gin',
30 | 'gin/search/path')
31 | download_from_gstorage_function.assert_called_once()
32 | self.assertEqual(path, 'some_file.gin')
33 |
34 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True)
35 | def test_single_local_path_handling(self, download_from_gstorage_function):
36 | """Tests that function does nothing if given local file path."""
37 | path = cloud.make_file_paths_local(
38 | 'local_file.gin',
39 | 'gin/search/path')
40 | download_from_gstorage_function.assert_not_called()
41 | self.assertEqual(path, 'local_file.gin')
42 |
43 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True)
44 | def test_single_path_in_list_handling(self, download_from_gstorage_function):
45 | """Tests that function returns a single-element list if given one."""
46 | path = cloud.make_file_paths_local(
47 | ['gs://bucket-name/bucket/dir/some_file.gin'],
48 | 'gin/search/path')
49 | download_from_gstorage_function.assert_called_once()
50 | self.assertNotIsInstance(path, str)
51 | self.assertListEqual(path, ['some_file.gin'])
52 |
53 | @mock.patch.object(cloud, 'download_from_gstorage', autospec=True)
54 | def test_more_paths_in_list_handling(self, download_from_gstorage_function):
55 |     """Tests that the function handles local and gstorage paths in one list."""
56 | paths = cloud.make_file_paths_local(
57 | ['gs://bucket-name/bucket/dir/first_file.gin',
58 | 'local_file.gin',
59 | 'gs://bucket-name/bucket/dir/second_file.gin'],
60 | 'gin/search/path')
61 | self.assertEqual(download_from_gstorage_function.call_count, 2)
62 | download_from_gstorage_function.assert_has_calls(
63 | [mock.call('gs://bucket-name/bucket/dir/first_file.gin', mock.ANY),
64 | mock.call('gs://bucket-name/bucket/dir/second_file.gin', mock.ANY)])
65 | self.assertListEqual(
66 | paths,
67 | ['first_file.gin', 'local_file.gin', 'second_file.gin'])
68 |
69 |
70 | if __name__ == '__main__':
71 | tf.test.main()
72 |
--------------------------------------------------------------------------------
/ddsp/training/data_preparation/README.md:
--------------------------------------------------------------------------------
1 | # Data Preparation
2 |
3 | Scripts and libraries to prepare datasets for training. For some more example usage, check out the [train_autoencoder](../../colab/demos/train_autoencoder.ipynb) demo.
4 |
5 |
6 | ## Making a TFRecord dataset from your own sounds
7 |
8 | For experiments from the original [ICLR 2020 paper](https://openreview.net/forum?id=B1x1ma4tDr), we need to do some preprocessing on the raw audio to get it into the correct format for training. This involves turning the full audio into short examples (4 seconds by default, but adjustable with flags), inferring the fundamental frequency (or "pitch") with [CREPE](http://github.com/marl/crepe), and computing the loudness. These features will then be stored in a sharded [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) file for easier loading. Depending on the amount of input audio, this process usually takes a few minutes.
9 |
10 | ```
11 | ddsp_prepare_tfrecord \
12 | --input_audio_filepatterns=/path/to/wavs/*wav,/path/to/mp3s/*mp3 \
13 | --output_tfrecord_path=/path/to/output.tfrecord \
14 | --num_shards=10 \
15 | --alsologtostderr
16 | ```
17 |
18 | ## Making a TFRecord dataset from synthetic data
19 |
20 | For experiments from the [ICML 2020 workshop paper](https://goo.gl/magenta/ddsp-inv) on pitch detection via inverse synthesis, we need to create a synthetic dataset to use for supervision. We must also specify the data generation function (`generate_examples.generate_fn`) we want to use, either with a `--gin_param` flag or in a `--gin_file`.
21 |
22 | ```
23 | ddsp_generate_synthetic_dataset \
24 | --output_tfrecord_path=/path/to/output.tfrecord \
25 | --num_shards=1000 \
26 | --gin_param="generate_examples.generate_fn = @generate_notes_v2" \
27 | --num_examples=10000000 \
28 | --alsologtostderr
29 | ```
30 |
31 |
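32 | ## Using the prepared dataset
33 | 
34 | Once created, the TFRecord can be loaded for training with `ddsp.training.data.TFRecordProvider`. A minimal sketch, assuming the default preparation settings above (16kHz audio, 250Hz frame rate); the file pattern and batch size are placeholders:
35 | 
36 | ```
37 | from ddsp.training import data
38 | 
39 | data_provider = data.TFRecordProvider('/path/to/output.tfrecord*')
40 | dataset = data_provider.get_batch(batch_size=16, shuffle=True)
41 | ```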
--------------------------------------------------------------------------------
/ddsp/training/data_preparation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Data preparation code for DDSP models."""
16 |
17 | from ddsp.training.data_preparation import prepare_tfrecord_lib
18 |
--------------------------------------------------------------------------------
/ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Apache Beam pipeline for computing TFRecord dataset of synthetic examples.
16 |
17 | Example usage:
18 | ==============
19 | ddsp_generate_synthetic_dataset \
20 | --output_tfrecord_path=/tmp/synthetic_data.tfrecord \
21 | --num_shards=1 \
22 | --gin_param="generate_examples.generate_fn = @generate_notes_v2" \
23 | --num_examples=100 \
24 | --alsologtostderr
25 |
26 | For the ICML workshop paper, we created 10,000,000 examples in 1000 shards.
27 | """
28 |
29 | from absl import app
30 | from absl import flags
31 | import apache_beam as beam
32 | from ddsp.training.data_preparation import synthetic_data # pylint:disable=unused-import
33 | import gin
34 | import numpy as np
35 | import pkg_resources
36 | import tensorflow.compat.v2 as tf
37 |
38 |
39 | GIN_PATH = pkg_resources.resource_filename(__name__, '../gin')
40 |
41 | FLAGS = flags.FLAGS
42 |
43 | flags.DEFINE_string(
44 | 'output_tfrecord_path', None,
45 | 'The prefix path to the output TFRecord. Shard numbers will be added to '
46 |     'the actual path(s).')
47 | flags.DEFINE_integer(
48 | 'num_shards', None,
49 | 'The number of shards to use for the TFRecord. If None, this number will '
50 | 'be determined automatically.')
51 | flags.DEFINE_integer(
52 | 'num_examples', 1000000,
53 | 'The total number of synthetic examples to generate.')
54 | flags.DEFINE_integer(
55 | 'random_seed', 42,
56 | 'Random seed to use for deterministic generation.')
57 | flags.DEFINE_list(
58 | 'pipeline_options', '--runner=DirectRunner',
59 | 'A comma-separated list of command line arguments to be used as options '
60 | 'for the Beam Pipeline.')
61 |
62 | # Gin config flags.
63 | flags.DEFINE_multi_string('gin_search_path', [],
64 | 'Additional gin file search paths.')
65 | flags.DEFINE_multi_string('gin_file', [], 'List of paths to the config files.')
66 | flags.DEFINE_multi_string('gin_param', [],
67 | 'Newline separated list of Gin parameter bindings.')
68 |
69 |
70 | class GenerateExampleFn(beam.DoFn):
71 | """Gin-configurable wrapper to generate synthetic examples."""
72 |
73 | def __init__(self, gin_str, **kwargs):
74 | super().__init__(**kwargs)
75 | self._gin_str = gin_str
76 |
77 | def start_bundle(self):
78 | with gin.unlock_config():
79 | gin.parse_config(self._gin_str)
80 |
81 | @gin.configurable('generate_examples')
82 | def process(self, seed, generate_fn=gin.REQUIRED):
83 | np.random.seed(seed)
84 | batch = generate_fn(n_batch=1)
85 | beam.metrics.Metrics.counter('GenerateExampleFn', 'generated').inc()
86 | yield {k: v[0].numpy() for k, v in batch.items()}
87 |
88 |
89 | def _float_dict_to_tfexample(float_dict):
90 | """Convert dictionary of float arrays to tf.train.Example proto."""
91 | return tf.train.Example(
92 | features=tf.train.Features(
93 | feature={
94 | k: tf.train.Feature(
95 | float_list=tf.train.FloatList(value=v.flatten()))
96 | for k, v in float_dict.items()
97 | }
98 | ))
99 |
100 |
101 | def run():
102 | """Run the beam pipeline to create synthetic dataset."""
103 | pipeline_options = beam.options.pipeline_options.PipelineOptions(
104 | FLAGS.pipeline_options)
105 | with beam.Pipeline(options=pipeline_options) as pipeline:
106 | for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
107 | gin.add_config_file_search_path(gin_search_path)
108 | gin.parse_config_files_and_bindings(
109 | FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)
110 |
111 | np.random.seed(FLAGS.random_seed)
112 | _ = (
113 | pipeline
114 | | beam.Create(np.random.randint(2**32, size=FLAGS.num_examples))
115 | | beam.ParDo(GenerateExampleFn(gin.config_str()))
116 | | beam.Reshuffle()
117 | | beam.Map(_float_dict_to_tfexample)
118 | | beam.io.tfrecordio.WriteToTFRecord(
119 | FLAGS.output_tfrecord_path,
120 | num_shards=FLAGS.num_shards,
121 | coder=beam.coders.ProtoCoder(tf.train.Example))
122 | )
123 |
124 |
125 | def main(unused_argv):
126 | """From command line."""
127 | run()
128 |
129 |
130 | def console_entry_point():
131 | """From pip installed script."""
132 | app.run(main)
133 |
134 |
135 | if __name__ == '__main__':
136 | console_entry_point()
137 |
--------------------------------------------------------------------------------
/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Create a TFRecord dataset from audio files.
16 |
17 | Usage:
18 | ====================
19 | ddsp_prepare_tfrecord \
20 | --input_audio_filepatterns=/path/to/wavs/*wav,/path/to/mp3s/*mp3 \
21 | --output_tfrecord_path=/path/to/output.tfrecord \
22 | --num_shards=10 \
23 | --alsologtostderr
24 |
25 | """
26 |
27 | from absl import app
28 | from absl import flags
29 | from ddsp.training.data_preparation.prepare_tfrecord_lib import prepare_tfrecord
30 | import tensorflow.compat.v2 as tf
31 |
32 | FLAGS = flags.FLAGS
33 |
34 | flags.DEFINE_list('input_audio_filepatterns', [],
35 | 'List of filepatterns to glob for input audio files.')
36 | flags.DEFINE_string(
37 | 'output_tfrecord_path', None,
38 | 'The prefix path to the output TFRecord. Shard numbers will be added to '
39 |     'the actual path(s).')
40 | flags.DEFINE_integer(
41 | 'num_shards', None,
42 | 'The number of shards to use for the TFRecord. If None, this number will '
43 | 'be determined automatically.')
44 | flags.DEFINE_integer('sample_rate', 16000,
45 | 'The sample rate to use for the audio.')
46 | flags.DEFINE_integer(
47 | 'frame_rate', 250,
48 | 'The frame rate to use for f0 and loudness features. If set to 0, '
49 | 'these features will not be computed.')
50 | flags.DEFINE_float(
51 | 'example_secs', 4,
52 | 'The length of each example in seconds. Input audio will be split to this '
53 | 'length using a sliding window. If 0, each full piece of audio will be '
54 | 'used as an example.')
55 | flags.DEFINE_float(
56 | 'hop_secs', 1,
57 | 'The hop size between example start points (in seconds), when splitting '
58 | 'audio into constant-length examples.')
59 | flags.DEFINE_float(
60 | 'eval_split_fraction', 0.0,
61 | 'Fraction of the dataset to reserve for eval split. If set to 0, no eval '
62 | 'split is created.'
63 | )
64 | flags.DEFINE_float(
65 | 'chunk_secs', 20.0,
66 | 'Chunk size in seconds used to split the input audio files. These '
67 | 'non-overlapping chunks are partitioned into train and eval sets if '
68 | 'eval_split_fraction > 0. This is used to split large audio files into '
69 | 'manageable chunks for better parallelization and to enable '
70 | 'non-overlapping train/eval splits.')
71 | flags.DEFINE_boolean(
72 | 'center', False,
73 | 'Add padding to audio such that frame timestamps are centered. Increases '
74 | 'number of frames by one.')
75 | flags.DEFINE_boolean(
76 | 'viterbi', True,
77 | 'Use viterbi decoding of pitch.')
78 | flags.DEFINE_list(
79 | 'pipeline_options', '--runner=DirectRunner',
80 | 'A comma-separated list of command line arguments to be used as options '
81 | 'for the Beam Pipeline.')
82 |
83 |
84 | def run():
85 | input_audio_paths = []
86 | for filepattern in FLAGS.input_audio_filepatterns:
87 | input_audio_paths.extend(tf.io.gfile.glob(filepattern))
88 |
89 | prepare_tfrecord(
90 | input_audio_paths,
91 | FLAGS.output_tfrecord_path,
92 | num_shards=FLAGS.num_shards,
93 | sample_rate=FLAGS.sample_rate,
94 | frame_rate=FLAGS.frame_rate,
95 | example_secs=FLAGS.example_secs,
96 | hop_secs=FLAGS.hop_secs,
97 | eval_split_fraction=FLAGS.eval_split_fraction,
98 | chunk_secs=FLAGS.chunk_secs,
99 | center=FLAGS.center,
100 | viterbi=FLAGS.viterbi,
101 | pipeline_options=FLAGS.pipeline_options)
102 |
103 |
104 | def main(unused_argv):
105 | """From command line."""
106 | run()
107 |
108 |
109 | def console_entry_point():
110 | """From pip installed script."""
111 | app.run(main)
112 |
113 |
114 | if __name__ == '__main__':
115 | console_entry_point()
116 |
--------------------------------------------------------------------------------
/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.training.data_preparation.prepare_tfrecord_lib."""
16 |
17 | import os
18 | import sys
19 |
20 | from absl import flags
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 | from ddsp import spectral_ops
24 | from ddsp.training.data_preparation import prepare_tfrecord_lib
25 | import numpy as np
26 | import scipy.io.wavfile
27 | import tensorflow.compat.v2 as tf
28 |
29 | CREPE_SAMPLE_RATE = spectral_ops.CREPE_SAMPLE_RATE
30 |
31 |
32 | class PrepareTFRecordBeamTest(parameterized.TestCase):
33 |
34 | def get_tempdir(self):
35 | try:
36 | flags.FLAGS.test_tmpdir
37 | except flags.UnparsedFlagAccessError:
38 | # Need to initialize flags when running `pytest`.
39 | flags.FLAGS(sys.argv)
40 | return self.create_tempdir().full_path
41 |
42 | def setUp(self):
43 | super().setUp()
44 | self.test_dir = self.get_tempdir()
45 |
46 | # Write test wav file.
47 | self.wav_sr = 22050
48 | self.wav_secs = 0.5
49 | self.wav_path = os.path.join(self.test_dir, 'test.wav')
50 | scipy.io.wavfile.write(
51 | self.wav_path,
52 | self.wav_sr,
53 | np.random.randint(
54 | np.iinfo(np.int16).min, np.iinfo(np.int16).max,
55 | size=int(self.wav_sr * self.wav_secs), dtype=np.int16))
56 |
57 | def parse_tfrecord(self, path):
58 | return [tf.train.Example.FromString(record.numpy()) for record in
59 | tf.data.TFRecordDataset(os.path.join(self.test_dir, path))]
60 |
61 | def validate_outputs(self, expected_num_examples, expected_feature_lengths):
62 | all_examples = (
63 | self.parse_tfrecord('output.tfrecord-00000-of-00002') +
64 | self.parse_tfrecord('output.tfrecord-00001-of-00002'))
65 |
66 | self.assertLen(all_examples, expected_num_examples)
67 | for ex in all_examples:
68 | self.assertCountEqual(expected_feature_lengths, ex.features.feature)
69 |
70 | for feat, expected_len in expected_feature_lengths.items():
71 | arr = ex.features.feature[feat].float_list.value
72 | try:
73 | self.assertLen(arr, expected_len)
74 | except AssertionError as e:
75 | raise AssertionError('feature: %s' % feat) from e
76 | self.assertFalse(any(np.isinf(arr)))
77 |
78 | def get_expected_length(self, input_length, frame_rate, center=False):
79 | sample_rate = 16000 # Features at CREPE_SAMPLE_RATE.
80 | frame_size = 1024 # Unused for this calculation.
81 | hop_size = sample_rate // frame_rate
82 | padding = 'center' if center else 'same'
83 | n_frames, _ = spectral_ops.get_framed_lengths(
84 | input_length, frame_size, hop_size, padding)
85 | return n_frames
86 |
87 | @staticmethod
88 | def get_n_per_chunk(chunk_length, example_secs, hop_secs):
89 |     """Calculates the number of examples produced by striding."""
90 | n = (chunk_length - example_secs) / hop_secs
91 | # Deal with limited float precision that causes (.3 / .1) = 2.9999....
92 | return int(np.floor(np.round(n, decimals=3))) + 1
93 |
94 | @parameterized.named_parameters(
95 | ('chunk_and_split', 0.3, 0.2),
96 | ('no_chunk', None, 0.2),
97 | ('no_split', 0.3, None),
98 | ('no_chunk_no_split', None, None),
99 | )
100 | def test_prepare_tfrecord(self, chunk_secs, example_secs):
101 | sample_rate = 16000
102 | frame_rate = 250
103 | hop_secs = 0.1
104 |
105 | # Calculate expected batch size.
106 | if example_secs:
107 | length = chunk_secs if chunk_secs else self.wav_secs
108 | n_per_chunk = self.get_n_per_chunk(length, example_secs, hop_secs)
109 | else:
110 | n_per_chunk = 1
111 |
112 | n_chunks = int(np.ceil(self.wav_secs / chunk_secs)) if chunk_secs else 1
113 | expected_n_batch = n_per_chunk * n_chunks
114 | print('n_chunks, n_per_chunk, chunk_secs, example_secs',
115 | n_chunks, n_per_chunk, chunk_secs, example_secs)
116 |
117 | # Calculate expected lengths.
118 | if example_secs:
119 | length = example_secs
120 | elif chunk_secs:
121 | length = chunk_secs
122 | else:
123 | length = self.wav_secs
124 |
125 | expected_n_t = int(length * sample_rate)
126 | expected_n_frames = self.get_expected_length(expected_n_t, frame_rate)
127 |
128 | # Make the actual records.
129 | prepare_tfrecord_lib.prepare_tfrecord(
130 | [self.wav_path],
131 | os.path.join(self.test_dir, 'output.tfrecord'),
132 | num_shards=2,
133 | sample_rate=sample_rate,
134 | frame_rate=frame_rate,
135 | example_secs=example_secs,
136 | hop_secs=hop_secs,
137 | chunk_secs=chunk_secs,
138 | center=False)
139 |
140 | self.validate_outputs(
141 | expected_n_batch,
142 | {
143 | 'audio': expected_n_t,
144 | 'audio_16k': expected_n_t,
145 | 'f0_hz': expected_n_frames,
146 | 'f0_confidence': expected_n_frames,
147 | 'loudness_db': expected_n_frames,
148 | })
149 |
150 | @parameterized.named_parameters(('no_center', False), ('center', True))
151 | def test_centering(self, center):
152 | frame_rate = 250
153 | sample_rate = 16000
154 | example_secs = 0.3
155 | hop_secs = 0.1
156 | n_batch = self.get_n_per_chunk(self.wav_secs, example_secs, hop_secs)
157 | prepare_tfrecord_lib.prepare_tfrecord(
158 | [self.wav_path],
159 | os.path.join(self.test_dir, 'output.tfrecord'),
160 | num_shards=2,
161 | sample_rate=sample_rate,
162 | frame_rate=frame_rate,
163 | example_secs=example_secs,
164 | hop_secs=hop_secs,
165 | center=center,
166 | chunk_secs=None)
167 |
168 | n_t = int(example_secs * sample_rate)
169 | n_frames = self.get_expected_length(n_t, frame_rate, center)
170 | n_expected_frames = 76 if center else 75 # (250 * 0.3) [+1].
171 | self.assertEqual(n_frames, n_expected_frames)
172 | self.validate_outputs(
173 | n_batch, {
174 | 'audio': n_t,
175 | 'audio_16k': n_t,
176 | 'f0_hz': n_frames,
177 | 'f0_confidence': n_frames,
178 | 'loudness_db': n_frames,
179 | })
180 |
181 | @parameterized.named_parameters(
182 | ('16kHz', 16000),
183 | ('32kHz', 32000),
184 | ('48kHz', 48000))
185 | def test_sample_rate(self, sample_rate):
186 | frame_rate = 250
187 | example_secs = 0.3
188 | hop_secs = 0.1
189 | center = True
190 | n_batch = self.get_n_per_chunk(self.wav_secs, example_secs, hop_secs)
191 | prepare_tfrecord_lib.prepare_tfrecord(
192 | [self.wav_path],
193 | os.path.join(self.test_dir, 'output.tfrecord'),
194 | num_shards=2,
195 | sample_rate=sample_rate,
196 | frame_rate=frame_rate,
197 | example_secs=example_secs,
198 | hop_secs=hop_secs,
199 | center=center,
200 | chunk_secs=None)
201 |
202 | n_t = int(example_secs * sample_rate)
203 | n_t_16k = int(example_secs * CREPE_SAMPLE_RATE)
204 | n_frames = self.get_expected_length(n_t_16k, frame_rate, center)
205 | n_expected_frames = 76 # (250 * 0.3) + 1.
206 | self.assertEqual(n_frames, n_expected_frames)
207 | self.validate_outputs(
208 | n_batch, {
209 | 'audio': n_t,
210 | 'audio_16k': n_t_16k,
211 | 'f0_hz': n_frames,
212 | 'f0_confidence': n_frames,
213 | 'loudness_db': n_frames,
214 | })
215 |
216 | @parameterized.named_parameters(('16kHz', 16000), ('44.1kHz', 44100),
217 | ('48kHz', 48000))
218 | def test_audio_only(self, sample_rate):
219 | prepare_tfrecord_lib.prepare_tfrecord(
220 | [self.wav_path],
221 | os.path.join(self.test_dir, 'output.tfrecord'),
222 | num_shards=2,
223 | sample_rate=sample_rate,
224 | frame_rate=None,
225 | example_secs=None,
226 | chunk_secs=None)
227 |
228 | self.validate_outputs(
229 | 1, {
230 | 'audio': int(self.wav_secs * sample_rate),
231 | 'audio_16k': int(self.wav_secs * CREPE_SAMPLE_RATE),
232 | })
233 |
234 |
235 | if __name__ == '__main__':
236 | absltest.main()
237 |
--------------------------------------------------------------------------------
/ddsp/training/ddsp_run.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Train, evaluate, or sample (from) a ddsp model.
16 |
17 | Usage:
18 | ================================================================================
19 | For training, you need to specify --gin_file for both the model and the dataset.
20 | You can optionally specify additional params with --gin_param.
21 | The pip install installs a `ddsp_run` script that can be called directly.
22 | ================================================================================
23 | ddsp_run \
24 | --mode=train \
25 | --alsologtostderr \
26 | --save_dir=/tmp/$USER-ddsp-0 \
27 | --gin_file=models/ae.gin \
28 | --gin_file=datasets/nsynth.gin \
29 | --gin_param=batch_size=16
30 |
31 |
32 | ================================================================================
33 | For evaluation and sampling, only the dataset file is required.
34 | ================================================================================
35 | ddsp_run \
36 | --mode=eval \
37 | --alsologtostderr \
38 | --save_dir=/tmp/$USER-ddsp-0 \
39 | --gin_file=datasets/nsynth.gin
40 |
41 | ddsp_run \
42 | --mode=sample \
43 | --alsologtostderr \
44 | --save_dir=/tmp/$USER-ddsp-0 \
45 | --gin_file=datasets/nsynth.gin
46 |
47 |
48 | ================================================================================
49 | The directory `gin/papers/` stores configs that give the specific models and
50 | datasets used for a paper's experiments, so they only require one gin file to train.
51 | ================================================================================
52 | ddsp_run \
53 | --mode=train \
54 | --alsologtostderr \
55 | --save_dir=/tmp/$USER-ddsp-0 \
56 | --gin_file=papers/iclr2020/nsynth_ae.gin
57 |
58 |
59 | """
60 |
61 | import os
62 | import time
63 |
64 | from absl import app
65 | from absl import flags
66 | from absl import logging
67 | from ddsp.training import cloud
68 | from ddsp.training import eval_util
69 | from ddsp.training import models
70 | from ddsp.training import train_util
71 | from ddsp.training import trainers
72 | import gin
73 | import pkg_resources
74 | import tensorflow as tf
75 |
76 | gfile = tf.io.gfile
77 | FLAGS = flags.FLAGS
78 |
79 | # Program flags.
80 | flags.DEFINE_enum('mode', 'train', ['train', 'eval', 'sample'],
81 | 'Whether to train, evaluate, or sample from the model.')
82 | flags.DEFINE_string('save_dir', '/tmp/ddsp',
83 | 'Path where checkpoints and summary events will be saved '
84 | 'during training and evaluation.')
85 | flags.DEFINE_string('restore_dir', '',
86 | 'Path from which checkpoints will be restored before '
87 | 'training. Can be different than the save_dir.')
88 | flags.DEFINE_string('tpu', '', 'Address of the TPU. No TPU if left blank.')
89 | flags.DEFINE_string('cluster_config', '',
90 | 'Worker-specific JSON string for multiworker setup. '
91 | 'For more information see train_util.get_strategy().')
92 | flags.DEFINE_boolean('allow_memory_growth', False,
93 | 'Whether to grow the GPU memory usage as is needed by the '
94 | 'process. Prevents crashes on GPUs with smaller memory.')
95 | flags.DEFINE_boolean('hypertune', False,
96 | 'Enable metric reporting for hyperparameter tuning, such '
97 | 'as on Google Cloud AI-Platform.')
98 | flags.DEFINE_float('early_stop_loss_value', None,
99 | 'Stops training early when the `total_loss` reaches below '
100 | 'this value during training.')
101 |
102 | # Gin config flags.
103 | flags.DEFINE_multi_string('gin_search_path', [],
104 | 'Additional gin file search paths.')
105 | flags.DEFINE_multi_string('gin_file', [],
106 | 'List of paths to the config files. If file '
107 | 'in gstorage bucket specify whole gstorage path: '
108 | 'gs://bucket-name/dir/in/bucket/file.gin.')
109 | flags.DEFINE_multi_string('gin_param', [],
110 | 'Newline separated list of Gin parameter bindings.')
111 |
112 | # Evaluation/sampling specific flags.
113 | flags.DEFINE_boolean('run_once', False, 'Whether evaluation will run once.')
114 | flags.DEFINE_integer('initial_delay_secs', None,
115 |                      'Time to wait before evaluation starts.')
116 |
117 | GIN_PATH = pkg_resources.resource_filename(__name__, 'gin')
118 |
119 |
120 | def delay_start():
121 | """Optionally delay the start of the run."""
122 | delay_time = FLAGS.initial_delay_secs
123 | if delay_time:
124 | logging.info('Waiting for %i second(s)', delay_time)
125 | time.sleep(delay_time)
126 |
127 |
128 | def parse_gin(restore_dir):
129 | """Parse gin config from --gin_file, --gin_param, and the model directory."""
130 | # Enable parsing gin files on Google Cloud.
131 | gin.config.register_file_reader(tf.io.gfile.GFile, tf.io.gfile.exists)
132 | # Add user folders to the gin search path.
133 | for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
134 | gin.add_config_file_search_path(gin_search_path)
135 |
136 | # Parse gin configs, later calls override earlier ones.
137 | with gin.unlock_config():
138 | # Optimization defaults.
139 | use_tpu = bool(FLAGS.tpu)
140 | opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
141 | gin.parse_config_file(os.path.join('optimization', opt_default))
142 | eval_default = 'eval/basic.gin'
143 | gin.parse_config_file(eval_default)
144 |
145 | # Load operative_config if it exists (model has already trained).
146 | try:
147 | operative_config = train_util.get_latest_operative_config(restore_dir)
148 | logging.info('Using operative config: %s', operative_config)
149 | operative_config = cloud.make_file_paths_local(operative_config, GIN_PATH)
150 | gin.parse_config_file(operative_config, skip_unknown=True)
151 | except FileNotFoundError:
152 | logging.info('Operative config not found in %s', restore_dir)
153 |
154 | # User gin config and user hyperparameters from flags.
155 | gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
156 | gin.parse_config_files_and_bindings(
157 | gin_file, FLAGS.gin_param, skip_unknown=True)
158 |
159 |
160 | def allow_memory_growth():
161 | """Sets the GPUs to grow the memory usage as is needed by the process."""
162 | gpus = tf.config.experimental.list_physical_devices('GPU')
163 | if gpus:
164 | try:
165 | # Currently, memory growth needs to be the same across GPUs.
166 | for gpu in gpus:
167 | tf.config.experimental.set_memory_growth(gpu, True)
168 | except RuntimeError as e:
169 | # Memory growth must be set before GPUs have been initialized.
170 | print(e)
171 |
172 |
173 | def main(unused_argv):
174 | """Parse gin config and run ddsp training, evaluation, or sampling."""
175 | restore_dir = os.path.expanduser(FLAGS.restore_dir)
176 | save_dir = os.path.expanduser(FLAGS.save_dir)
177 | # If no separate restore directory is given, use the save directory.
178 | restore_dir = save_dir if not restore_dir else restore_dir
179 | logging.info('Restore Dir: %s', restore_dir)
180 | logging.info('Save Dir: %s', save_dir)
181 |
182 | gfile.makedirs(restore_dir) # Only makes dirs if they don't exist.
183 | parse_gin(restore_dir)
184 | logging.info('Operative Gin Config:\n%s', gin.config.config_str())
185 |
186 | if FLAGS.allow_memory_growth:
187 | allow_memory_growth()
188 |
189 | # Training.
190 | if FLAGS.mode == 'train':
191 | strategy = train_util.get_strategy(tpu=FLAGS.tpu,
192 | cluster_config=FLAGS.cluster_config)
193 | with strategy.scope():
194 | model = models.get_model()
195 | trainer = trainers.get_trainer_class()(model, strategy)
196 |
197 | train_util.train(data_provider=gin.REQUIRED,
198 | trainer=trainer,
199 | save_dir=save_dir,
200 | restore_dir=restore_dir,
201 | early_stop_loss_value=FLAGS.early_stop_loss_value,
202 | report_loss_to_hypertune=FLAGS.hypertune)
203 |
204 | # Evaluation.
205 | elif FLAGS.mode == 'eval':
206 | model = models.get_model()
207 | delay_start()
208 | eval_util.evaluate(data_provider=gin.REQUIRED,
209 | model=model,
210 | save_dir=save_dir,
211 | restore_dir=restore_dir,
212 | run_once=FLAGS.run_once)
213 |
214 | # Sampling.
215 | elif FLAGS.mode == 'sample':
216 | model = models.get_model()
217 | delay_start()
218 | eval_util.sample(data_provider=gin.REQUIRED,
219 | model=model,
220 | save_dir=save_dir,
221 | restore_dir=restore_dir,
222 | run_once=FLAGS.run_once)
223 |
224 |
225 | def console_entry_point():
226 | """From pip installed script."""
227 | app.run(main)
228 |
229 |
230 | if __name__ == '__main__':
231 | console_entry_point()
232 |
--------------------------------------------------------------------------------
/ddsp/training/decoders_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.training.decoders."""
16 |
17 | import functools
18 |
19 | from absl.testing import parameterized
20 | from ddsp.training import decoders
21 | import numpy as np
22 | import tensorflow.compat.v2 as tf
23 |
24 |
25 | class DilatedConvDecoderTest(parameterized.TestCase, tf.test.TestCase):
26 |
27 | def setUp(self):
28 | """Create some common default values for decoder."""
29 | super().setUp()
30 | # For decoder.
31 | self.ch = 4
32 | self.layers_per_stack = 3
33 | self.stacks = 2
34 | self.output_splits = (('amps', 1), ('harmonic_distribution', 10),
35 | ('noise_magnitudes', 10))
36 |
37 | # For audio features and conditioning.
38 | self.frame_rate = 100
39 | self.length_in_sec = 0.20
40 | self.time_steps = int(self.frame_rate * self.length_in_sec)
41 |
42 | def _gen_dummy_conditioning(self):
43 | """Generate dummy scaled f0 and ld conditioning."""
44 | conditioning = {}
45 | # Generate dummy `f0_hz` with batch and channel dims.
46 | f0_hz_dummy = np.repeat(1.0,
47 | self.length_in_sec * self.frame_rate)[np.newaxis, :,
48 | np.newaxis]
49 | conditioning['f0_scaled'] = f0_hz_dummy # Testing correct shapes only.
50 | # Generate dummy `loudness_db` with batch and channel dims.
51 | loudness_db_dummy = np.repeat(1.0, self.length_in_sec *
52 | self.frame_rate)[np.newaxis, :, np.newaxis]
53 | conditioning[
54 | 'ld_scaled'] = loudness_db_dummy # Testing correct shapes only.
55 | return conditioning
56 |
57 | def test_correct_output_splits_and_shapes_dilated_conv_decoder(self):
58 | decoder = decoders.DilatedConvDecoder(
59 | ch=self.ch,
60 | layers_per_stack=self.layers_per_stack,
61 | stacks=self.stacks,
62 | conditioning_keys=None,
63 | output_splits=self.output_splits)
64 |
65 | conditioning = self._gen_dummy_conditioning()
66 | output = decoder(conditioning)
67 | for output_name, output_dim in self.output_splits:
68 | dummy_output = np.zeros((1, self.time_steps, output_dim))
69 | self.assertShapeEqual(dummy_output, output[output_name])
70 |
71 |
72 | class RnnFcDecoderTest(parameterized.TestCase, tf.test.TestCase):
73 |
74 | def setUp(self):
75 | """Create some common default values for decoder."""
76 | super().setUp()
77 | # For decoder.
78 | self.input_keys = ('pw_scaled', 'f0_scaled')
79 | self.output_splits = (('amps', 1), ('harmonic_distribution', 10))
80 | self.n_batch = 2
81 | self.n_t = 4
82 | self.inputs = {
83 | 'pw_scaled': tf.ones([self.n_batch, self.n_t, 1]),
84 | 'f0_scaled': tf.ones([self.n_batch, self.n_t, 1]),
85 | }
86 | self.rnn_ch = 3
87 | self.get_decoder = functools.partial(
88 | decoders.RnnFcDecoder,
89 | rnn_channels=self.rnn_ch,
90 | ch=2,
91 | layers_per_stack=1,
92 | input_keys=self.input_keys,
93 | output_splits=self.output_splits,
94 | rnn_type='gru',
95 | )
96 |
97 | @parameterized.named_parameters(
98 | ('stateful', False),
99 | ('stateless', True),
100 | )
101 | def test_correct_outputs(self, stateless=False):
102 | decoder = self.get_decoder(stateless=stateless)
103 |
104 | # Add state.
105 | inputs = self.inputs
106 | if stateless:
107 | inputs['state'] = tf.ones([self.n_batch, self.rnn_ch])
108 |
109 | # Run through the network
110 | outputs = decoder(inputs)
111 |
112 | # Check normal outputs.
113 | for name, dim in self.output_splits:
114 | dummy_output = np.zeros((self.n_batch, self.n_t, dim))
115 | self.assertShapeEqual(dummy_output, outputs[name])
116 |
117 | # Check the explicit state.
118 | if stateless:
119 | self.assertShapeEqual(inputs['state'], outputs['state'])
120 | self.assertNotAllEqual(inputs['state'], outputs['state'])
121 |
122 |
123 | if __name__ == '__main__':
124 | tf.test.main()
125 |
--------------------------------------------------------------------------------
/ddsp/training/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Uses one of the AI Platform base images.
2 | # You can try different base images; however, only this one has been tested.
3 | FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-2
4 |
5 | # Installs sndfile library for reading and writing audio files.
6 | RUN apt-get update && \
7 | apt-get install --no-install-recommends -y libsndfile-dev
8 |
9 | # Upgrades TensorFlow and TensorFlow Probability.
10 | # A newer version of TensorFlow is needed for multi-VM training.
11 | RUN pip install --upgrade pip && \
12 | pip install --upgrade tensorflow tensorflow-probability
13 |
14 | # Installs cloudml-hypertune package needed for hyperparameter tuning
15 | RUN pip install cloudml-hypertune
16 |
17 | WORKDIR /root
18 | # Installs Magenta DDSP from Github.
19 | RUN wget https://github.com/magenta/ddsp/archive/main.zip && \
20 | unzip main.zip && \
21 | cd ddsp-main && \
22 | python setup.py install
23 |
24 | # Copies running script.
25 | COPY task.py task.py
26 |
27 | ENTRYPOINT ["python", "task.py"]
28 |
--------------------------------------------------------------------------------
/ddsp/training/docker/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Docker code."""
16 |
17 | from ddsp.training.docker import task
18 |
--------------------------------------------------------------------------------
/ddsp/training/docker/config_hypertune.yaml:
--------------------------------------------------------------------------------
1 | trainingInput:
2 | scaleTier: CUSTOM
3 | masterType: n1-highmem-64
4 | masterConfig:
5 | acceleratorConfig:
6 | count: 4
7 | type: NVIDIA_TESLA_T4
8 | useChiefInTfConfig: True
9 | hyperparameters:
10 | goal: MINIMIZE
11 | hyperparameterMetricTag: "Loss"
12 | maxTrials: 4
13 | maxParallelTrials: 4
14 | enableTrialEarlyStopping: True
15 | params:
16 | - parameterName: learning_rate
17 | type: DISCRETE
18 | discreteValues:
19 | - 0.0001
20 | - 3e-4
21 | - 0.001
22 | - 0.01
23 |
--------------------------------------------------------------------------------
/ddsp/training/docker/config_multiple_vms.yaml:
--------------------------------------------------------------------------------
1 | # Recommended batch_size: 128
2 | # Recommended learning_rate: 0.001
3 | # Recommended number of steps: 15000
4 | #
5 | trainingInput:
6 | scaleTier: CUSTOM
7 | # Configures a chief worker with 2 T4 GPUs
8 | masterType: n1-highcpu-32
9 | masterConfig:
10 | acceleratorConfig:
11 | count: 2
12 | type: NVIDIA_TESLA_T4
13 | # Configures 3 workers, each with 2 T4 GPUs
14 | workerCount: 3
15 | workerType: n1-highcpu-32
16 | workerConfig:
17 | acceleratorConfig:
18 | count: 2
19 | type: NVIDIA_TESLA_T4
20 | # Makes AI Platform naming compatible with TensorFlow naming
21 | useChiefInTfConfig: True
22 |
--------------------------------------------------------------------------------
/ddsp/training/docker/config_single_vm.yaml:
--------------------------------------------------------------------------------
1 | # Recommended batch_size: 16
2 | # Recommended learning_rate: 0.0001
3 | # Recommended number of steps: 40000
4 | trainingInput:
5 | scaleTier: CUSTOM
6 | # Configures a single worker with 1 NVIDIA T4 GPU
7 | masterType: n1-highcpu-16
8 | masterConfig:
9 | acceleratorConfig:
10 | count: 1
11 | type: NVIDIA_TESLA_T4
12 | useChiefInTfConfig: True
13 |
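14 | # Example job submission using this config (a sketch: the job name, region,
15 | # image URI, and bucket paths below are placeholders):
16 | #
17 | # gcloud ai-platform jobs submit training ddsp_single_vm_job \
18 | #   --region=us-central1 \
19 | #   --master-image-uri=gcr.io/my-project/ddsp-train:latest \
20 | #   --config=config_single_vm.yaml \
21 | #   -- \
22 | #   --save_dir=gs://my-bucket/model \
23 | #   --file_pattern=gs://my-bucket/data/train.tfrecord*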
--------------------------------------------------------------------------------
/ddsp/training/docker/task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Script for running a containerized training on Google Cloud AI Platform."""
16 |
17 | import json
18 | import os
19 | import subprocess
20 |
21 | from absl import app
22 | from absl import flags
23 |
24 | FLAGS = flags.FLAGS
25 |
26 | flags.DEFINE_string('save_dir', None,
27 | 'Path where checkpoints and summary events will be saved '
28 | 'during training and evaluation.')
29 | flags.DEFINE_string('restore_dir', '',
30 | 'Path from which checkpoints will be restored before '
31 | 'training. Can be different than the save_dir.')
32 | flags.DEFINE_string('file_pattern', None, 'Data file pattern')
33 |
34 | flags.DEFINE_integer('batch_size', 32, 'Batch size')
35 | flags.DEFINE_float('learning_rate', 0.003, 'Learning rate')
36 | flags.DEFINE_integer('num_steps', 30000, 'Number of training steps')
37 | flags.DEFINE_float('early_stop_loss_value', 0.0,
38 | 'Early stopping. When the total_loss reaches below this '
39 | 'value training stops.')
40 |
41 | flags.DEFINE_integer('steps_per_summary', 300, 'Steps per summary')
42 | flags.DEFINE_integer('steps_per_save', 300, 'Steps per save')
43 | flags.DEFINE_boolean('hypertune', False,
44 | 'Enable metric reporting for hyperparameter tuning.')
45 |
46 |
47 | flags.DEFINE_multi_string('gin_search_path', [],
48 | 'Additional gin file search paths. '
49 |                           'Must be paths inside the Docker container; any '
50 | 'necessary gin configs should be added at the '
51 | 'Docker image building stage.')
52 | flags.DEFINE_multi_string('gin_file', [],
53 | 'List of paths to the config files. If file '
54 | 'in gstorage bucket specify whole gstorage path: '
55 |                           'gs://bucket-name/dir/in/bucket/file.gin. If the '
56 |                           'path is local, remember to copy the file into '
57 |                           'the Docker container at the image building stage.')
58 | flags.DEFINE_multi_string('gin_param', [],
59 | 'Newline separated list of Gin parameter bindings.')
60 |
61 |
62 | def get_worker_behavior_info(save_dir):
63 | """Infers worker behavior from the environment.
64 |
65 |   Checks if the TF_CONFIG environment variable is set
66 |   and infers the cluster configuration and save_dir
67 |   from it.
68 |
69 | Args:
70 | save_dir: Save directory given by the user.
71 |
72 | Returns:
73 | cluster_config: Inferred cluster configuration.
74 | save_dir: Inferred save directory.
75 | """
76 | if 'TF_CONFIG' in os.environ:
77 | cluster_config = os.environ['TF_CONFIG']
78 | cluster_config_dict = json.loads(cluster_config)
79 | if ('cluster' not in cluster_config_dict.keys() or
80 | 'task' not in cluster_config_dict or
81 | len(cluster_config_dict['cluster']) <= 1):
82 | cluster_config = ''
83 | elif cluster_config_dict['task']['type'] != 'chief':
84 | save_dir = ''
85 | else:
86 | cluster_config = ''
87 |
88 | return cluster_config, save_dir
89 |
90 |
91 | def parse_list_params(list_of_params, param_name):
92 | return [f'--{param_name}={param}' for param in list_of_params]
93 |
94 |
95 | def main(unused_argv):
96 | restore_dir = FLAGS.save_dir if not FLAGS.restore_dir else FLAGS.restore_dir
97 |
98 | cluster_config, save_dir = get_worker_behavior_info(FLAGS.save_dir)
99 | gin_search_path = parse_list_params(FLAGS.gin_search_path, 'gin_search_path')
100 | gin_file = parse_list_params(FLAGS.gin_file, 'gin_file')
101 | gin_param = parse_list_params(FLAGS.gin_param, 'gin_param')
102 |
103 | ddsp_run_command = (
104 | ['ddsp_run',
105 | '--mode=train',
106 | '--alsologtostderr',
107 | '--gin_file=models/solo_instrument.gin',
108 |        '--gin_file=datasets/tfrecord.gin',
109 | f'--cluster_config={cluster_config}',
110 | f'--save_dir={save_dir}',
111 | f'--restore_dir={restore_dir}',
112 | f'--hypertune={FLAGS.hypertune}',
113 | f'--early_stop_loss_value={FLAGS.early_stop_loss_value}',
114 | f'--gin_param=batch_size={FLAGS.batch_size}',
115 | f'--gin_param=learning_rate={FLAGS.learning_rate}',
116 | f'--gin_param=TFRecordProvider.file_pattern=\'{FLAGS.file_pattern}\'',
117 | f'--gin_param=train_util.train.num_steps={FLAGS.num_steps}',
118 | f'--gin_param=train_util.train.steps_per_save={FLAGS.steps_per_save}',
119 | ('--gin_param=train_util.train.steps_per_summary='
120 | f'{FLAGS.steps_per_summary}')]
121 | + gin_search_path + gin_file + gin_param)
122 |
123 | subprocess.run(args=ddsp_run_command, check=True)
124 |
125 | if __name__ == '__main__':
126 | flags.mark_flag_as_required('file_pattern')
127 | flags.mark_flag_as_required('save_dir')
128 | app.run(main)
129 |
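130 | # Example invocation inside the container (a sketch; the flag values are
131 | # placeholders). The flags below map directly onto the `ddsp_run` command
132 | # assembled in `main()`:
133 | #
134 | #   python task.py \
135 | #     --save_dir=gs://my-bucket/model \
136 | #     --file_pattern=gs://my-bucket/data/train.tfrecord* \
137 | #     --batch_size=16 \
138 | #     --learning_rate=0.0001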
--------------------------------------------------------------------------------
/ddsp/training/docker/task_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for task.py."""
16 |
17 | import os
18 | from unittest import mock
19 |
20 | from ddsp.training.docker import task
21 | import tensorflow.compat.v2 as tf
22 |
23 |
24 | class GetWorkerBehaviorInfoTest(tf.test.TestCase):
25 |
26 | def test_no_tf_config(self):
27 | """Tests behavior when there is no TF_CONFIG set."""
28 | cluster_config, save_dir = task.get_worker_behavior_info('some/dir/')
29 | self.assertEqual(save_dir, 'some/dir/')
30 | self.assertEqual(cluster_config, '')
31 |
32 | def test_incomplete_tf_config(self):
33 | """Test behavior when set TF_CONFIG is incomplete."""
34 | with mock.patch.dict(os.environ, {'TF_CONFIG': '{"cluster": {}}'}):
35 | cluster_config, save_dir = task.get_worker_behavior_info('some/dir/')
36 | self.assertEqual(save_dir, 'some/dir/')
37 | self.assertEqual(cluster_config, '')
38 |
39 | with mock.patch.dict(os.environ, {'TF_CONFIG': '{"task": {}}'}):
40 | cluster_config, save_dir = task.get_worker_behavior_info('some/dir/')
41 | self.assertEqual(save_dir, 'some/dir/')
42 | self.assertEqual(cluster_config, '')
43 |
44 | @mock.patch.dict(
45 | os.environ,
46 | {'TF_CONFIG': ('{"cluster": {"worker": ["worker0.example.com:2221"]},'
47 | '"task": {"type": "worker", "index": 0}}')})
48 | def test_single_worker(self):
49 | """Tests behavior when cluster has only one worker."""
50 | cluster_config, save_dir = task.get_worker_behavior_info('some/dir/')
51 | self.assertEqual(save_dir, 'some/dir/')
52 | self.assertEqual(cluster_config, '')
53 |
54 | @mock.patch.dict(
55 | os.environ,
56 | {'TF_CONFIG': ('{"cluster": {"worker": ["worker0.example.com:2221"],'
57 | '"chief": ["chief.example.com:2222"]},'
58 | '"task": {"type": "chief", "index": 0}}')})
59 | def test_multi_worker_as_chief(self):
60 | """Tests multi-worker behavior when task type chief is set in TF_CONFIG."""
61 | cluster_config, save_dir = task.get_worker_behavior_info('some/dir/')
62 | self.assertEqual(save_dir, 'some/dir/')
63 | self.assertEqual(
64 | cluster_config,
65 | ('{"cluster": {"worker": ["worker0.example.com:2221"],'
66 | '"chief": ["chief.example.com:2222"]},'
67 | '"task": {"type": "chief", "index": 0}}'))
68 |
69 | @mock.patch.dict(
70 | os.environ,
71 | {'TF_CONFIG': ('{"cluster": {"worker": ["worker0.example.com:2221"],'
72 | '"chief": ["chief.example.com:2222"]},'
73 | '"task": {"type": "worker", "index": 0}}')})
74 | def test_multi_worker_as_worker(self):
75 | """Tests multi-worker behavior when task type worker is set in TF_CONFIG."""
76 | cluster_config, save_dir = task.get_worker_behavior_info('some/dir/')
77 | self.assertEqual(save_dir, '')
78 | self.assertEqual(
79 | cluster_config,
80 | ('{"cluster": {"worker": ["worker0.example.com:2221"],'
81 | '"chief": ["chief.example.com:2222"]},'
82 | '"task": {"type": "worker", "index": 0}}'))
83 |
84 |
85 | if __name__ == '__main__':
86 | tf.test.main()
87 |
--------------------------------------------------------------------------------
/ddsp/training/evaluators.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library of evaluator implementations for use in eval_util."""
16 | import ddsp
17 | from ddsp.training import heuristics
18 | from ddsp.training import metrics
19 | from ddsp.training import summaries
20 | import gin
21 | import numpy as np
22 | import tensorflow.compat.v2 as tf
23 |
24 |
25 | class BaseEvaluator(object):
26 | """Base class for evaluators."""
27 |
28 | def __init__(self, sample_rate, frame_rate):
29 | self._sample_rate = sample_rate
30 | self._frame_rate = frame_rate
31 |
32 | def set_rates(self, sample_rate, frame_rate):
33 | """Sets sample and frame rates, not known in gin initialization."""
34 | self._sample_rate = sample_rate
35 | self._frame_rate = frame_rate
36 |
37 |   def evaluate(self, batch, outputs, losses):
38 | """Computes metrics."""
39 | raise NotImplementedError()
40 |
41 | def sample(self, batch, outputs, step):
42 | """Computes and logs samples."""
43 | raise NotImplementedError()
44 |
45 | def flush(self, step):
46 | """Logs metrics."""
47 | raise NotImplementedError()
48 |
49 |
50 |
51 | @gin.register
52 | class BasicEvaluator(BaseEvaluator):
53 | """Computes audio samples and losses."""
54 |
55 | def __init__(self, sample_rate, frame_rate):
56 | super().__init__(sample_rate, frame_rate)
57 | self._avg_losses = {}
58 |
59 | def evaluate(self, batch, outputs, losses):
60 | del outputs # Unused.
61 | if not self._avg_losses:
62 | self._avg_losses = {
63 | name: tf.keras.metrics.Mean(name=name, dtype=tf.float32)
64 | for name in list(losses.keys())
65 | }
66 | # Loss.
67 | for k, v in losses.items():
68 | self._avg_losses[k].update_state(v)
69 |
70 | def sample(self, batch, outputs, step):
71 | audio = batch['audio']
72 | audio_gen = outputs['audio_gen']
73 |
74 | audio_gen = np.array(audio_gen)
75 |
76 | # Add audio.
77 | summaries.audio_summary(
78 | audio_gen, step, self._sample_rate, name='audio_generated')
79 | summaries.audio_summary(
80 | audio, step, self._sample_rate, name='audio_original')
81 |
82 | # Add plots.
83 | summaries.waveform_summary(audio, audio_gen, step)
84 | summaries.spectrogram_summary(audio, audio_gen, step)
85 |
86 | def flush(self, step):
87 | latest_losses = {}
88 | for k, metric in self._avg_losses.items():
89 | latest_losses[k] = metric.result()
90 | tf.summary.scalar('losses/{}'.format(k), metric.result(), step=step)
91 | metric.reset_states()
92 |
93 |
94 | @gin.register
95 | class F0LdEvaluator(BaseEvaluator):
96 | """Computes F0 and loudness metrics."""
97 |
98 | def __init__(self, sample_rate, frame_rate, run_f0_crepe=True):
99 | super().__init__(sample_rate, frame_rate)
100 | self._loudness_metrics = metrics.LoudnessMetrics(
101 | sample_rate=sample_rate, frame_rate=frame_rate)
102 | self._f0_metrics = metrics.F0Metrics(
103 | sample_rate=sample_rate, frame_rate=frame_rate)
104 | self._run_f0_crepe = run_f0_crepe
105 | if self._run_f0_crepe:
106 | self._f0_crepe_metrics = metrics.F0CrepeMetrics(
107 | sample_rate=sample_rate, frame_rate=frame_rate)
108 |
109 | def evaluate(self, batch, outputs, losses):
110 | del losses # Unused.
111 | audio_gen = outputs['audio_gen']
112 | self._loudness_metrics.update_state(batch, audio_gen)
113 |
114 | if 'f0_hz' in outputs and 'f0_hz' in batch:
115 | self._f0_metrics.update_state(batch, outputs['f0_hz'])
116 | elif self._run_f0_crepe:
117 | self._f0_crepe_metrics.update_state(batch, audio_gen)
118 |
119 | def sample(self, batch, outputs, step):
120 | if 'f0_hz' in outputs and 'f0_hz' in batch:
121 | summaries.f0_summary(batch['f0_hz'], outputs['f0_hz'], step,
122 | name='f0_harmonic')
123 |
124 | def flush(self, step):
125 | self._loudness_metrics.flush(step)
126 | self._f0_metrics.flush(step)
127 | if self._run_f0_crepe:
128 | self._f0_crepe_metrics.flush(step)
129 |
130 |
131 | @gin.register
132 | class TWMEvaluator(BaseEvaluator):
133 | """Evaluates F0s created with TWM heuristic."""
134 |
135 | def __init__(self,
136 | sample_rate,
137 | frame_rate,
138 | processor_name='sinusoidal',
139 | noisy=False):
140 | super().__init__(sample_rate, frame_rate)
141 | self._noisy = noisy
142 | self._processor_name = processor_name
143 | self._f0_twm_metrics = metrics.F0Metrics(
144 | sample_rate=sample_rate, frame_rate=frame_rate, name='f0_twm')
145 |
146 | def _compute_twm_f0(self, outputs):
147 | """Computes F0 from sinusoids using TWM heuristic."""
148 | processor_controls = outputs[self._processor_name]['controls']
149 | freqs = processor_controls['frequencies']
150 | amps = processor_controls['amplitudes']
151 | if self._noisy:
152 | noise_ratios = processor_controls['noise_ratios']
153 | amps = amps * (1.0 - noise_ratios)
154 | twm = ddsp.losses.TWMLoss()
155 | # Treat all freqs as candidate f0s.
156 | return twm.predict_f0(freqs, freqs, amps)
157 |
158 | def evaluate(self, batch, outputs, losses):
159 | del losses # Unused.
160 | twm_f0 = self._compute_twm_f0(outputs)
161 | self._f0_twm_metrics.update_state(batch, twm_f0)
162 |
163 | def sample(self, batch, outputs, step):
164 | twm_f0 = self._compute_twm_f0(outputs)
165 | summaries.f0_summary(batch['f0_hz'], twm_f0, step, name='f0_twm')
166 |
167 | def flush(self, step):
168 | self._f0_twm_metrics.flush(step)
169 |
170 |
171 | @gin.register
172 | class MidiAutoencoderEvaluator(BaseEvaluator):
173 | """Metrics for MIDI Autoencoder."""
174 |
175 | def __init__(self, sample_rate, frame_rate, db_key='loudness_db',
176 | f0_key='f0_hz'):
177 | super().__init__(sample_rate, frame_rate)
178 | self._midi_metrics = metrics.MidiMetrics(
179 | frames_per_second=frame_rate, tag='learned')
180 | self._db_key = db_key
181 | self._f0_key = f0_key
182 |
183 | def evaluate(self, batch, outputs, losses):
184 | del losses # Unused.
185 | self._midi_metrics.update_state(outputs, outputs['pianoroll'])
186 |
187 | def sample(self, batch, outputs, step):
188 | audio = batch['audio']
189 | summaries.audio_summary(
190 | audio, step, self._sample_rate, name='audio_original')
191 |
192 | audio_keys = ['midi_audio', 'synth_audio', 'midi_audio2', 'synth_audio2']
193 | for k in audio_keys:
194 | if k in outputs and outputs[k] is not None:
195 | summaries.audio_summary(outputs[k], step, self._sample_rate, name=k)
196 | summaries.spectrogram_summary(audio, outputs[k], step, tag=k)
197 | summaries.waveform_summary(audio, outputs[k], step, name=k)
198 |
199 | summaries.f0_summary(
200 | batch[self._f0_key], outputs[f'{self._f0_key}_pred'],
201 | step, name='f0_hz_rec')
202 |
203 | summaries.pianoroll_summary(outputs, step, 'pianoroll',
204 | self._frame_rate, 'pianoroll')
205 | summaries.midiae_f0_summary(batch[self._f0_key], outputs, step)
206 | ld_rec = f'{self._db_key}_rec'
207 | if ld_rec in outputs:
208 | summaries.midiae_ld_summary(batch[self._db_key], outputs, step,
209 | self._db_key)
210 |
211 | summaries.midiae_sp_summary(outputs, step)
212 |
213 | def flush(self, step):
214 | self._midi_metrics.flush(step)
215 |
216 |
217 | @gin.register
218 | class MidiHeuristicEvaluator(BaseEvaluator):
219 | """Metrics for MIDI heuristic."""
220 |
221 | def __init__(self, sample_rate, frame_rate):
222 | super().__init__(sample_rate, frame_rate)
223 | self._midi_metrics = metrics.MidiMetrics(
224 | tag='heuristic', frames_per_second=frame_rate)
225 |
226 | def _compute_heuristic_notes(self, outputs):
227 | return heuristics.segment_notes_batch(
228 | binarize_f=heuristics.midi_heuristic,
229 | pick_f0_f=heuristics.mean_f0,
230 | pick_amps_f=heuristics.median_amps,
231 | controls_batch=outputs)
232 |
233 | def evaluate(self, batch, outputs, losses):
234 | del losses # Unused.
235 | notes = self._compute_heuristic_notes(outputs)
236 | self._midi_metrics.update_state(outputs, notes)
237 |
238 | def sample(self, batch, outputs, step):
239 | notes = self._compute_heuristic_notes(outputs)
240 | outputs['heuristic_notes'] = notes
241 | summaries.midi_summary(outputs, step, 'heuristic', self._frame_rate,
242 | 'heuristic_notes')
243 | summaries.pianoroll_summary(outputs, step, 'heuristic',
244 | self._frame_rate, 'heuristic_notes')
245 |
246 | def flush(self, step):
247 | self._midi_metrics.flush(step)
248 |
249 |
250 |
--------------------------------------------------------------------------------
/ddsp/training/gin/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Gin Configs: Datasets
2 |
3 | This directory contains gin configs for datasets. Each file specifies the datasets to use for training, evaluation, and sampling.
4 | Each training run with `ddsp_run.py` must be supplied both a model file and a dataset file, each with the `--gin_file` flag.
5 | Evaluation and sampling jobs only require a dataset file.
6 |
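7 | For example, a training run might pair a model config with this directory's `tfrecord.gin`. This is a sketch: the save directory and file pattern below are placeholders.
8 |
9 | ```bash
10 | ddsp_run \
11 |   --mode=train \
12 |   --save_dir=/tmp/$USER-ddsp-0 \
13 |   --gin_file=models/solo_instrument.gin \
14 |   --gin_file=datasets/tfrecord.gin \
15 |   --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'" \
16 |   --alsologtostderr
17 | ```
18 |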
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/base.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | import ddsp.training
3 |
4 | # Evaluate
5 | evaluate.batch_size = 32
6 | evaluate.num_batches = 5 # Depends on dataset size.
7 |
8 | # Sample
9 | sample.batch_size = 16
10 | sample.num_batches = 1
11 | sample.ckpt_delay_secs = 300 # 5 minutes
12 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/nsynth.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'datasets/base.gin'
3 |
4 | # Dataset (Different scopes to pass different args to different instances)
5 | train.data_provider = @train_data/data.NSynthTfds()
6 | train_data/NSynthTfds.split = 'train'
7 |
8 | evaluate.data_provider = @eval_data/data.NSynthTfds()
9 | eval_data/data.NSynthTfds.split = 'valid'
10 |
11 | sample.data_provider = @sample_data/data.NSynthTfds()
12 | sample_data/data.NSynthTfds.split = 'test'
13 |
14 | # Evaluate
15 | evaluate.num_batches = 50 # Full test set ~17000 samples
16 |
17 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/tfrecord.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'datasets/base.gin'
3 |
4 | # Make dataset with ddsp/training/data_preparation/ddsp_prepare_tfrecord.py
5 | # --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'"
6 |
7 | # Dataset
8 | train.data_provider = @data.TFRecordProvider()
9 | evaluate.data_provider = @data.TFRecordProvider()
10 | sample.data_provider = @data.TFRecordProvider()
11 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/urmp/README.md:
--------------------------------------------------------------------------------
1 | # URMP gin configs
2 |
3 | These gin configs are for the [URMP dataset](http://www2.ece.rochester.edu/projects/air/projects/URMP.html).
4 |
5 | To use, add the base path as a param with `--gin_param="base_path='path/to/urmp_f0_interop/'"`.
6 |
7 | To select a specific instrument, set `--gin_param="instrument_key='sax'"`.
8 |
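9 | Putting it together, a training run on a single URMP instrument might look like this (a sketch; the model file and paths are placeholders):
10 |
11 | ```bash
12 | ddsp_run \
13 |   --mode=train \
14 |   --gin_file=models/solo_instrument.gin \
15 |   --gin_file=datasets/urmp/base.gin \
16 |   --gin_param="base_path='/path/to/urmp_f0_interop/'" \
17 |   --gin_param="instrument_key='sax'" \
18 |   --alsologtostderr
19 | ```
20 |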
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/urmp/all.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'datasets/urmp/base.gin'
3 |
4 | # All URMP instruments
5 | Urmp.instrument_key = 'all'
6 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/urmp/all_midi.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'datasets/urmp/midi_base.gin'
3 |
4 | # All URMP instruments
5 | UrmpMidi.instrument_key = 'all'
6 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/urmp/base.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'datasets/base.gin'
3 |
4 | train.data_provider = @train_data/data.Urmp()
5 | train_data/data.Urmp.split = 'train'
6 |
7 | evaluate.data_provider = @test_data/data.Urmp()
8 | sample.data_provider = @test_data/data.Urmp()
9 | test_data/data.Urmp.split = 'test'
10 |
--------------------------------------------------------------------------------
/ddsp/training/gin/datasets/urmp/midi_base.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'datasets/base.gin'
3 |
4 | train.data_provider = @train_data/data.UrmpMidi()
5 | train_data/data.UrmpMidi.split = 'train'
6 |
7 | evaluate.data_provider = @test_data/data.UrmpMidi()
8 | sample.data_provider = @test_data/data.UrmpMidi()
9 | eval_discrete.data_provider = @test_data/data.UrmpMidi()
10 | test_data/data.UrmpMidi.split = 'test'
11 |
--------------------------------------------------------------------------------
/ddsp/training/gin/eval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/eval/basic.gin:
--------------------------------------------------------------------------------
1 | evaluators = [
2 | @BasicEvaluator,
3 | ]
4 |
5 | evaluate.evaluator_classes = %evaluators
6 | sample.evaluator_classes = %evaluators
7 |
--------------------------------------------------------------------------------
/ddsp/training/gin/eval/basic_f0_ld.gin:
--------------------------------------------------------------------------------
1 | evaluators = [
2 | @BasicEvaluator,
3 | @F0LdEvaluator
4 | ]
5 |
6 | evaluate.evaluator_classes = %evaluators
7 | sample.evaluator_classes = %evaluators
8 |
--------------------------------------------------------------------------------
/ddsp/training/gin/eval/basic_f0_ld_twm.gin:
--------------------------------------------------------------------------------
1 | evaluators = [
2 | @BasicEvaluator,
3 | @F0LdEvaluator,
4 | @TWMEvaluator
5 | ]
6 |
7 | evaluate.evaluator_classes = %evaluators
8 | sample.evaluator_classes = %evaluators
9 |
--------------------------------------------------------------------------------
/ddsp/training/gin/eval/heuristic.gin:
--------------------------------------------------------------------------------
1 | segment_notes_batch.binarize_f = @heuristics.midi_heuristic
2 | segment_notes_batch.pick_f0_f = @heuristics.mean_f0
3 | segment_notes_batch.pick_amps_f = @heuristics.median_amps
4 |
--------------------------------------------------------------------------------
/ddsp/training/gin/eval/heuristic_power.gin:
--------------------------------------------------------------------------------
1 | segment_notes_batch.binarize_f = @heuristics.midi_heuristic_power
2 | segment_notes_batch.pick_f0_f = @heuristics.mean_f0
3 | segment_notes_batch.pick_amps_f = @heuristics.median_amps
4 |
--------------------------------------------------------------------------------
/ddsp/training/gin/eval/midi_ae.gin:
--------------------------------------------------------------------------------
1 | evaluators = [
2 | @BasicEvaluator,
3 | @MidiAutoencoderEvaluator,
4 | ]
5 |
6 | evaluate.evaluator_classes = %evaluators
7 | sample.evaluator_classes = %evaluators
8 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/README.md:
--------------------------------------------------------------------------------
1 | # Gin Configs: Models
2 |
3 | This directory contains gin configs for model architectures (including ddsp modules).
4 | Each training run with `ddsp_run.py` must be supplied both a model file and a dataset file, each with the `--gin_file` flag.
5 | Evaluation and sampling jobs only require a dataset file.
6 |
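7 | Individual hyperparameters in these files can also be overridden at launch with `--gin_param`. For instance, `ae.gin` sets `RnnFcDecoder.ch = 512`, so a smaller decoder is a one-flag change (a sketch; the dataset file is a placeholder):
8 |
9 | ```bash
10 | ddsp_run \
11 |   --mode=train \
12 |   --gin_file=models/ae.gin \
13 |   --gin_file=datasets/tfrecord.gin \
14 |   --gin_param="RnnFcDecoder.ch=256" \
15 |   --alsologtostderr
16 | ```
17 |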
--------------------------------------------------------------------------------
/ddsp/training/gin/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/ae.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # Autoencoder that decodes from (loudness, f0, z).
3 | # z = encoder(audio)
4 |
5 | import ddsp
6 | import ddsp.training
7 |
8 | # =====
9 | # Model
10 | # =====
11 | get_model.model = @models.Autoencoder()
12 |
13 | # Preprocessor
14 | Autoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor()
15 | F0LoudnessPreprocessor.time_steps = 1000
16 |
17 | # Encoder
18 | Autoencoder.encoder = @encoders.MfccTimeDistributedRnnEncoder()
19 | MfccTimeDistributedRnnEncoder.rnn_channels = 512
20 | MfccTimeDistributedRnnEncoder.rnn_type = 'gru'
21 | MfccTimeDistributedRnnEncoder.z_dims = 16
22 | MfccTimeDistributedRnnEncoder.z_time_steps = 125
23 |
24 | # Decoder
25 | Autoencoder.decoder = @decoders.RnnFcDecoder()
26 | RnnFcDecoder.rnn_channels = 512
27 | RnnFcDecoder.rnn_type = 'gru'
28 | RnnFcDecoder.ch = 512
29 | RnnFcDecoder.layers_per_stack = 3
30 | RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled', 'z')
31 | RnnFcDecoder.output_splits = (('amps', 1),
32 | ('harmonic_distribution', 100),
33 | ('noise_magnitudes', 65))
34 |
35 | # Losses
36 | Autoencoder.losses = [
37 | @losses.SpectralLoss(),
38 | ]
39 | SpectralLoss.loss_type = 'L1'
40 | SpectralLoss.mag_weight = 1.0
41 | SpectralLoss.logmag_weight = 1.0
42 |
43 | # ==============
44 | # ProcessorGroup
45 | # ==============
46 |
47 | Autoencoder.processor_group = @processors.ProcessorGroup()
48 |
49 | ProcessorGroup.dag = [
50 | (@synths.Harmonic(),
51 | ['amps', 'harmonic_distribution', 'f0_hz']),
52 | (@synths.FilteredNoise(),
53 | ['noise_magnitudes']),
54 | (@processors.Add(),
55 | ['filtered_noise/signal', 'harmonic/signal']),
56 | ]
57 |
58 | # Harmonic Synthesizer
59 | Harmonic.name = 'harmonic'
60 | Harmonic.n_samples = 64000
61 | Harmonic.sample_rate = 16000
62 | Harmonic.normalize_below_nyquist = True
63 | Harmonic.scale_fn = @core.exp_sigmoid
64 |
65 | # Filtered Noise Synthesizer
66 | FilteredNoise.name = 'filtered_noise'
67 | FilteredNoise.n_samples = 64000
68 | FilteredNoise.window_size = 0
69 | FilteredNoise.scale_fn = @core.exp_sigmoid
70 |
71 | # Add
72 | processors.Add.name = 'add'
73 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/README.md:
--------------------------------------------------------------------------------
1 | Configs for the MidiAutoencoder family of models (decoding from MIDI to audio).
2 |
3 | These configs map:
4 |
5 | * Synth Params -> MIDI -> Synth Params
6 |
7 | Instead of:
8 |
9 | * LD/DB -> MIDI -> LD/DB
10 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/midiae.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # MidiAutoencoder (base class)
3 |
4 | import ddsp
5 | import ddsp.training
6 |
7 | include 'eval/midi_ae.gin'
8 | include 'models/midiae/mixins/recon_lossgroup.gin'
9 |
10 | # Training
11 | train.num_steps = 500000
12 |
13 |
14 | # ==============================================================================
15 | # Model
16 | # ==============================================================================
17 | get_model.model = @models.MidiAutoencoder()
18 |
19 |
20 | # Preprocessor
21 | MidiAutoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor()
22 | F0LoudnessPreprocessor.time_steps = 1000
23 |
24 |
25 | # Synthcoder
26 | MidiAutoencoder.synthcoder = @decoders.DilatedConvDecoder()
27 | DilatedConvDecoder.ch = 128
28 | DilatedConvDecoder.layers_per_stack = 9
29 | DilatedConvDecoder.norm_type = 'layer'
30 | DilatedConvDecoder.input_keys = ('ld_scaled', 'f0_scaled')
31 | DilatedConvDecoder.stacks = 2
32 | DilatedConvDecoder.conditioning_keys = None
33 | DilatedConvDecoder.precondition_stack = None
34 | DilatedConvDecoder.output_splits = (('amplitudes', 1),
35 | ('harmonic_distribution', 60),
36 | ('magnitudes', 65))
37 |
38 |
39 | # Stop Gradient
40 | MidiAutoencoder.sg_before_midiae = True
41 |
42 |
43 | # MIDI Encoder
44 | MidiAutoencoder.midi_encoder = None
45 |
46 |
47 | # MIDI Decoder
48 | MidiAutoencoder.midi_decoder = @decoders.MidiToHarmonicDecoder()
49 | MidiToHarmonicDecoder.f0_residual = True
50 | MidiToHarmonicDecoder.norm = True
51 | MidiToHarmonicDecoder.output_splits = (('f0_midi', 1),
52 | ('amplitudes', 1),
53 | ('harmonic_distribution', 60),
54 | ('magnitudes', 65))
55 | MidiToHarmonicDecoder.net = @dec/nn.DilatedConvStack()
56 | dec/DilatedConvStack.ch = 128
57 | dec/DilatedConvStack.layers_per_stack = 5
58 | dec/DilatedConvStack.stacks = 4
59 | dec/DilatedConvStack.norm_type = 'layer'
60 | dec/DilatedConvStack.conditional = False
61 |
62 |
63 | # ==============================================================================
64 | # Losses
65 | # ==============================================================================
66 | MidiAutoencoder.reconstruction_losses = @recon_lossgroup/losses.LossGroup()
67 |
68 |
69 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/mixins/README.md:
--------------------------------------------------------------------------------
1 | # Mixins
2 |
3 | These gin files have effects independent of each other, so multiple of them
4 | can be included to mix in various configurations.
5 |
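6 | For example, the MIDI encoder and HMM prior mixins can be layered onto the base model with additional `--gin_file` flags (a sketch; the dataset file is a placeholder):
7 |
8 | ```bash
9 | ddsp_run \
10 |   --mode=train \
11 |   --gin_file=models/midiae/midiae.gin \
12 |   --gin_file=models/midiae/mixins/midi_encoder.gin \
13 |   --gin_file=models/midiae/mixins/hmm_prior.gin \
14 |   --gin_file=datasets/urmp/all_midi.gin \
15 |   --alsologtostderr
16 | ```
17 |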
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/mixins/_.gin:
--------------------------------------------------------------------------------
1 | # Empty config so that we can do hyper sweeps over mixin files, with none being
2 | # an option.
3 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/mixins/hmm_prior.gin:
--------------------------------------------------------------------------------
1 | # HMM Prior
2 | MidiAutoencoder.hmm_prior = @losses.HmmPrior()
3 | HmmPrior.avg_length = 200
4 | HmmPrior.midi_std = 0.5
5 | HmmPrior.n_timesteps = 1000
6 | HmmPrior.n_pitches = 128
7 | HmmPrior.weight = 5.0
8 |
9 | MidiAutoencoder.hmm_quantize = False
10 |
11 | # Latent Closeness Losses
12 | MidiAutoencoder.qpitch_f0rec_loss = @qpitch/losses.MarginLoss()
13 | qpitch/MarginLoss.weight = 50.0
14 | qpitch/MarginLoss.margin = 0.5
15 | qpitch/MarginLoss.name = 'q_pitch-f0_rec'
16 |
17 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/mixins/midi_encoder.gin:
--------------------------------------------------------------------------------
1 | # Turn on the MIDI Encoder network (it's off by default)
2 |
3 |
4 | # HACK(emanilow): initialize all the model types (gin will ignore unused)
5 | MidiAutoencoder.midi_encoder = @encoders.HarmonicToMidiEncoder()
6 | ZMidiAutoencoder.midi_encoder = @encoders.HarmonicToMidiEncoder()
7 | HmmMidiAutoencoder.midi_encoder = @encoders.HarmonicToMidiEncoder()
8 |
9 | # --- MIDI Encoder details
10 | HarmonicToMidiEncoder.f0_residual = True
11 | HarmonicToMidiEncoder.net = @enc/nn.DilatedConvStack()
12 | enc/DilatedConvStack.ch = 128
13 | enc/DilatedConvStack.layers_per_stack = 5
14 | enc/DilatedConvStack.stacks = 4
15 | enc/DilatedConvStack.norm_type = 'layer'
16 | enc/DilatedConvStack.conditional = False
17 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/mixins/recon_lossgroup.gin:
--------------------------------------------------------------------------------
1 | # Standard reconstruction losses for the MidiAutoencoder encapsulated in a
2 | # LossGroup
3 |
4 |
5 | recon_lossgroup/LossGroup.dag = [
6 | ['synth_spectral_loss', ['audio', 'synth_audio']],
7 | ['f0_loss', ['f0_midi', 'f0_midi_pred', 'f0_loss_weights']],
8 | ['amps_loss', ['amps', 'amps_pred']],
9 | ['hd_loss', ['hd', 'hd_pred']],
10 | ['noise_loss', ['noise', 'noise_pred']]
11 | ]
12 |
13 | recon_lossgroup/LossGroup.amps_loss = @amps_loss/losses.ParamLoss()
14 | amps_loss/ParamLoss.weight = 0.5
15 | amps_loss/ParamLoss.loss_type = 'L1'
16 | amps_loss/ParamLoss.name = 'amplitude_reconstruction'
17 |
18 | recon_lossgroup/LossGroup.f0_loss = @f0_loss/losses.ParamLoss()
19 | f0_loss/ParamLoss.weight = 50.0
20 | f0_loss/ParamLoss.loss_type = 'L2'
21 | f0_loss/ParamLoss.name = 'f0_reconstruction'
22 |
23 | recon_lossgroup/LossGroup.hd_loss = @hd_loss/losses.ParamLoss()
24 | hd_loss/ParamLoss.weight = 500.0
25 | hd_loss/ParamLoss.loss_type = 'L1'
26 | hd_loss/ParamLoss.name = 'harmonic_distribution_reconstruction'
27 |
28 | recon_lossgroup/LossGroup.noise_loss = @noise_loss/losses.ParamLoss()
29 | noise_loss/ParamLoss.weight = 0.5
30 | noise_loss/ParamLoss.loss_type = 'L1'
31 | noise_loss/ParamLoss.name = 'noise_reconstruction'
32 |
33 | recon_lossgroup/LossGroup.synth_spectral_loss = @losses.SpectralLoss()
34 | SpectralLoss.loss_type = 'L1'
35 | SpectralLoss.mag_weight = 1.0
36 | SpectralLoss.logmag_weight = 1.0
37 | SpectralLoss.name = 'spectral_loss_synth'
38 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/midiae/z_midiae.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # ZMidiAutoencoder
3 |
4 | import ddsp
5 | import ddsp.training
6 |
7 | include 'eval/midi_ae.gin'
8 | include 'models/midiae/mixins/recon_lossgroup.gin'
9 |
10 | # Training
11 | train.num_steps = 500000
12 |
13 |
14 | # ==============================================================================
15 | # Model
16 | # ==============================================================================
17 | get_model.model = @models.ZMidiAutoencoder()
18 |
19 |
20 | # Preprocessor
21 | ZMidiAutoencoder.preprocessor = @preprocessing.F0LoudnessPreprocessor()
22 | F0LoudnessPreprocessor.time_steps = 1000
23 |
24 |
25 | # Synth encoder
26 | ZMidiAutoencoder.z_synth_encoders = [
27 | @encoders.OneHotEncoder(),
28 | @encoders.MfccTimeDistributedRnnEncoder(),
29 | ]
30 | MfccTimeDistributedRnnEncoder.rnn_channels = 512
31 | MfccTimeDistributedRnnEncoder.rnn_type = 'gru'
32 | MfccTimeDistributedRnnEncoder.z_dims = 128
33 | MfccTimeDistributedRnnEncoder.z_time_steps = 1000
34 |
35 |
36 | # OneHotEncoder
37 | OneHotEncoder.one_hot_key = 'instrument_id'
38 | OneHotEncoder.vocab_size = 13 # num instruments
39 | OneHotEncoder.n_dims = 128
40 | OneHotEncoder.skip_expand = False
41 |
42 |
43 | # Synthcoder
44 | ZMidiAutoencoder.synthcoder = @decoders.DilatedConvDecoder()
45 | DilatedConvDecoder.ch = 128
46 | DilatedConvDecoder.layers_per_stack = 9
47 | DilatedConvDecoder.norm_type = 'layer'
48 | DilatedConvDecoder.input_keys = ('ld_scaled', 'f0_scaled')
49 | DilatedConvDecoder.stacks = 2
50 | DilatedConvDecoder.conditioning_keys = ('z',)
51 | DilatedConvDecoder.precondition_stack = None
52 | DilatedConvDecoder.output_splits = (('amplitudes', 1),
53 | ('harmonic_distribution', 60),
54 | ('magnitudes', 65))
55 |
56 |
57 | # Stop Gradient
58 | ZMidiAutoencoder.sg_before_midiae = True
59 |
60 |
61 | # MIDI Encoder
62 | ZMidiAutoencoder.midi_encoder = None
63 |
64 |
65 | # MIDI Global Latents
66 | ZMidiAutoencoder.z_global_encoders = [
67 | @encoders.OneHotEncoder(),
68 | @ee/encoders.ExpressionEncoder(),
69 | ]
70 |
71 |
72 | # Global Expression Encoder
73 | ee/ExpressionEncoder.input_keys = ('f0_scaled',
74 | 'amps_scaled',
75 | 'hd_scaled',
76 | 'noise_scaled',)
77 | ee/ExpressionEncoder.z_dims = 128
78 | ee/ExpressionEncoder.pool_time = True
79 | ee/ExpressionEncoder.net = @z/nn.DilatedConvStack()
80 |
81 |
82 | # Z Note Encoder
83 | ZMidiAutoencoder.z_note_encoder = @zn/encoders.ExpressionEncoder()
84 | zn/ExpressionEncoder.input_keys = ('f0_scaled',
85 | 'amps_scaled',
86 | 'hd_scaled',
87 | 'noise_scaled',)
88 | zn/ExpressionEncoder.z_dims = 128
89 | zn/ExpressionEncoder.pool_time = False
90 | zn/ExpressionEncoder.net = @z/nn.DilatedConvStack()
91 |
92 |
93 | # Same DilatedConvStack for Note encoder and Global Expression Encoder.
94 | z/DilatedConvStack.ch = 128
95 | z/DilatedConvStack.layers_per_stack = 5
96 | z/DilatedConvStack.stacks = 4
97 | z/DilatedConvStack.norm_type = 'layer'
98 | z/DilatedConvStack.conditional = False
99 |
100 |
101 | # Z Preconditioning Stack
102 | ZMidiAutoencoder.z_preconditioning_stack = @nn.FcStackOut()
103 | nn.FcStackOut.ch = 512
104 | nn.FcStackOut.n_out = 256
105 | nn.FcStackOut.layers = 5
106 |
107 |
108 | # MIDI Decoder
109 | ZMidiAutoencoder.midi_decoder = @decoders.MidiToHarmonicDecoder()
110 | MidiToHarmonicDecoder.f0_residual = True
111 | MidiToHarmonicDecoder.norm = True
112 | MidiToHarmonicDecoder.output_splits = (('f0_midi', 1),
113 | ('amplitudes', 1),
114 | ('harmonic_distribution', 60),
115 | ('magnitudes', 65))
116 | MidiToHarmonicDecoder.net = @dec/nn.DilatedConvStack()
117 | dec/DilatedConvStack.ch = 128
118 | dec/DilatedConvStack.layers_per_stack = 5
119 | dec/DilatedConvStack.stacks = 4
120 | dec/DilatedConvStack.norm_type = 'layer'
121 | dec/DilatedConvStack.conditional = True
122 |
123 |
124 | # ==============================================================================
125 | # Losses
126 | # ==============================================================================
127 | # Reconstruction Loss
128 | ZMidiAutoencoder.reconstruction_losses = @recon_lossgroup/losses.LossGroup()
129 |
130 |
131 | # ZMidiAutoencoder.z_global_prior = @zg/losses.GaussianPrior()
132 | # zg/GaussianPrior.weight = 1.0
133 | # zg/GaussianPrior.name = 'global_kl'
134 |
135 |
136 | # ZMidiAutoencoder.z_note_prior = @zn/losses.GaussianPrior()
137 | # zn/GaussianPrior.weight = 1.0
138 | # zn/GaussianPrior.name = 'note_kl'
139 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/solo_instrument.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # Decodes from (loudness, f0). Has a trainable reverb component as well.
3 | # Since it uses a trainable reverb, training data should all be from the same
4 | # acoustic environment.
5 |
6 | include 'models/ae.gin'
7 |
8 | # Encoder
9 | Autoencoder.encoder = None
10 |
11 | # Decoder
12 | Autoencoder.decoder = @decoders.RnnFcDecoder()
13 | RnnFcDecoder.rnn_channels = 512
14 | RnnFcDecoder.rnn_type = 'gru'
15 | RnnFcDecoder.ch = 512
16 | RnnFcDecoder.layers_per_stack = 3
17 | RnnFcDecoder.input_keys = ('ld_scaled', 'f0_scaled')
18 | RnnFcDecoder.output_splits = (('amps', 1),
19 | ('harmonic_distribution', 60),
20 | ('noise_magnitudes', 65))
21 |
22 | # ==============
23 | # ProcessorGroup
24 | # ==============
25 |
26 | ProcessorGroup.dag = [
27 | (@synths.Harmonic(),
28 | ['amps', 'harmonic_distribution', 'f0_hz']),
29 | (@synths.FilteredNoise(),
30 | ['noise_magnitudes']),
31 | (@processors.Add(),
32 | ['filtered_noise/signal', 'harmonic/signal']),
33 | (@effects.Reverb(),
34 | ['add/signal']),
35 | ]
36 |
37 | # Reverb
38 | Reverb.name = 'reverb'
39 | Reverb.reverb_length = 48000
40 | Reverb.trainable = True
41 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/vst/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/vst/vst.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # Autoencoder that decodes from (power, f0).
3 |
4 | import ddsp
5 | import ddsp.training
6 |
7 | # =====
8 | # Model
9 | # =====
10 | get_model.model = @models.Autoencoder()
11 |
12 | # Globals
13 | frame_rate = 50
14 | frame_size = 1024
15 | sample_rate = 16000
16 | n_samples = 64320 # Extra frame for center padding. 16000 * 4 + 320
17 |
18 |
19 | # Preprocessor
20 | # Use same preprocessor for creating dataset and for training / inference.
21 | Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor()
22 | OnlineF0PowerPreprocessor:
23 | frame_rate = %frame_rate
24 | frame_size = %frame_size
25 | padding = 'center'
26 | compute_power = True
27 | compute_f0 = False
28 | crepe_saved_model_path = None
29 |
30 |
31 | # Encoder
32 | Autoencoder.encoder = None
33 |
34 |
35 | # Decoder
36 | Autoencoder.decoder = @decoders.RnnFcDecoder()
37 | RnnFcDecoder:
38 | rnn_channels = 512
39 | rnn_type = 'gru'
40 | ch = 256
41 | layers_per_stack = 1
42 | input_keys = ('pw_scaled', 'f0_scaled')
43 | output_splits = (('amps', 1),
44 | ('harmonic_distribution', 60),
45 | ('noise_magnitudes', 65))
46 |
47 | # Losses
48 | Autoencoder.losses = [
49 | @losses.SpectralLoss(),
50 | ]
51 | SpectralLoss:
52 | loss_type = 'L1'
53 | mag_weight = 1.0
54 | logmag_weight = 1.0
55 |
56 | # ==============
57 | # ProcessorGroup
58 | # ==============
59 |
60 | Autoencoder.processor_group = @processors.ProcessorGroup()
61 |
62 | # ==============
63 | # ProcessorGroup
64 | # ==============
65 |
66 | # Has a "Crop" processor to remove the padding from centered frames.
67 |
68 | ProcessorGroup.dag = [
69 | (@synths.Harmonic(),
70 | ['amps', 'harmonic_distribution', 'f0_hz']),
71 | (@synths.FilteredNoise(),
72 | ['noise_magnitudes']),
73 | (@processors.Add(),
74 | ['filtered_noise/signal', 'harmonic/signal']),
75 | (@effects.FilteredNoiseReverb(),
76 | ['add/signal']),
77 | (@processors.Crop(),
78 | ['reverb/signal'])
79 | ]
80 |
81 | # Reverb
82 | FilteredNoiseReverb:
83 | name = 'reverb'
84 | reverb_length = 24000
85 | n_frames = 500
86 | n_filter_banks = 32
87 | trainable = True
88 |
89 | # Harmonic Synthesizer
90 | Harmonic:
91 | name = 'harmonic'
92 | n_samples = %n_samples
93 | sample_rate = %sample_rate
94 | normalize_below_nyquist = True
95 | scale_fn = @core.exp_sigmoid
96 | amp_resample_method = 'linear'
97 |
98 | # Filtered Noise Synthesizer
99 | FilteredNoise:
100 | name = 'filtered_noise'
101 | n_samples = %n_samples
102 | window_size = 0
103 | scale_fn = @core.exp_sigmoid
104 |
105 | # Add
106 | processors.Add.name = 'add'
107 |
108 | # Remove the extra frame of synthesis from centering,
109 | # since generation is forward.
110 | # Frame size is the frame of the "synthesis", which is just the hop size.
111 | Crop:
112 | frame_size = 320
113 | crop_location = 'back'
114 |
115 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/vst/vst_32k.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # Autoencoder that decodes from (power, f0).
3 | # All lengths from the 16kHz config * 2.
4 |
5 | import ddsp
6 | import ddsp.training
7 |
8 | # =====
9 | # Model
10 | # =====
11 | get_model.model = @models.Autoencoder()
12 |
13 | # Globals
14 | frame_rate = 50
15 | frame_size = 1024
16 | sample_rate = 32000
17 | n_samples = 128640 # Extra frame for center padding. 32000 * 4 + 640
18 |
19 |
20 | # Preprocessor
21 | # Use same preprocessor for creating dataset and for training / inference.
22 | Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor()
23 | OnlineF0PowerPreprocessor:
24 | frame_rate = %frame_rate
25 | frame_size = %frame_size
26 | padding = 'center'
27 | compute_power = True
28 | compute_f0 = False
29 | crepe_saved_model_path = None
30 |
31 |
32 | # Encoder
33 | Autoencoder.encoder = None
34 |
35 |
36 | # Decoder
37 | Autoencoder.decoder = @decoders.RnnFcDecoder()
38 | RnnFcDecoder:
39 | rnn_channels = 512
40 | rnn_type = 'gru'
41 | ch = 256
42 | layers_per_stack = 1
43 | input_keys = ('pw_scaled', 'f0_scaled')
44 | output_splits = (('amps', 1),
45 | ('harmonic_distribution', 100),
46 | ('noise_magnitudes', 98))
47 |
48 | # Losses
49 | Autoencoder.losses = [
50 | @losses.SpectralLoss(),
51 | ]
52 | SpectralLoss:
53 | loss_type = 'L1'
54 | mag_weight = 1.0
55 | logmag_weight = 1.0
56 | fft_sizes = [4096, 2048, 1024, 512, 256, 128]
57 |
58 | # 16kHz fft sizes: (2048, 1024, 512, 256, 128, 64)
59 |
60 | # ==============
61 | # ProcessorGroup
62 | # ==============
63 |
64 | Autoencoder.processor_group = @processors.ProcessorGroup()
65 |
66 | # ==============
67 | # ProcessorGroup
68 | # ==============
69 |
70 | # Has a "Crop" processor to remove the padding from centered frames.
71 |
72 | ProcessorGroup.dag = [
73 | (@synths.Harmonic(),
74 | ['amps', 'harmonic_distribution', 'f0_hz']),
75 | (@synths.FilteredNoise(),
76 | ['noise_magnitudes']),
77 | (@processors.Add(),
78 | ['filtered_noise/signal', 'harmonic/signal']),
79 | (@effects.FilteredNoiseReverb(),
80 | ['add/signal']),
81 | (@processors.Crop(),
82 | ['reverb/signal'])
83 | ]
84 |
85 | # Reverb
86 | FilteredNoiseReverb:
87 | name = 'reverb'
88 | reverb_length = 48000
89 | n_frames = 500
90 | n_filter_banks = 32
91 | initial_bias = -4.0
92 | trainable = True
93 |
94 | # Harmonic Synthesizer
95 | Harmonic:
96 | name = 'harmonic'
97 | n_samples = %n_samples
98 | sample_rate = %sample_rate
99 | normalize_below_nyquist = True
100 | scale_fn = @core.exp_sigmoid
101 | amp_resample_method = 'linear'
102 |   use_angular_cumsum = True # Necessary at higher sample rates as oscillators precess more.
103 |
104 | # Filtered Noise Synthesizer
105 | FilteredNoise:
106 | name = 'filtered_noise'
107 | n_samples = %n_samples
108 | window_size = 0
109 | scale_fn = @core.exp_sigmoid
110 |
111 | # Add
112 | processors.Add.name = 'add'
113 |
114 | # Remove the extra frame of synthesis from centering,
115 | # since generation is forward.
116 | # Frame size is the frame of the "synthesis", which is just the hop size.
117 | Crop:
118 | frame_size = 640
119 | crop_location = 'back'
120 |
121 |
--------------------------------------------------------------------------------
/ddsp/training/gin/models/vst/vst_48k.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # Autoencoder that decodes from (power, f0).
3 | # All lengths from the 16kHz config * 3.
4 |
5 | import ddsp
6 | import ddsp.training
7 |
8 | # =====
9 | # Model
10 | # =====
11 | get_model.model = @models.Autoencoder()
12 |
13 | # Globals
14 | frame_rate = 50
15 | frame_size = 1024
16 | sample_rate = 48000
17 | n_samples = 192960 # Extra frame for center padding. 48000 * 4 + 960
18 |
19 |
20 | # Preprocessor
21 | # Use same preprocessor for creating dataset and for training / inference.
22 | Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor()
23 | OnlineF0PowerPreprocessor:
24 | frame_rate = %frame_rate
25 | frame_size = %frame_size
26 | padding = 'center'
27 | compute_power = True
28 | compute_f0 = False
29 | crepe_saved_model_path = None
30 |
31 |
32 | # Encoder
33 | Autoencoder.encoder = None
34 |
35 |
36 | # Decoder
37 | Autoencoder.decoder = @decoders.RnnFcDecoder()
38 | RnnFcDecoder:
39 | rnn_channels = 512
40 | rnn_type = 'gru'
41 | ch = 256
42 | layers_per_stack = 1
43 | input_keys = ('pw_scaled', 'f0_scaled')
44 | output_splits = (('amps', 1),
45 | ('harmonic_distribution', 100),
46 | ('noise_magnitudes', 98))
47 |
48 | # Losses
49 | Autoencoder.losses = [
50 | @losses.SpectralLoss(),
51 | ]
52 | SpectralLoss:
53 | loss_type = 'L1'
54 | mag_weight = 1.0
55 | logmag_weight = 1.0
56 | fft_sizes = [6144, 3072, 1536, 768, 384, 192]
57 |
58 | # 16kHz fft sizes: (2048, 1024, 512, 256, 128, 64)
59 |
60 | # ==============
61 | # ProcessorGroup
62 | # ==============
63 |
64 | Autoencoder.processor_group = @processors.ProcessorGroup()
65 |
66 | # ==============
67 | # ProcessorGroup
68 | # ==============
69 |
70 | # Has a "Crop" processor to remove the padding from centered frames.
71 |
72 | ProcessorGroup.dag = [
73 | (@synths.Harmonic(),
74 | ['amps', 'harmonic_distribution', 'f0_hz']),
75 | (@synths.FilteredNoise(),
76 | ['noise_magnitudes']),
77 | (@processors.Add(),
78 | ['filtered_noise/signal', 'harmonic/signal']),
79 | (@effects.FilteredNoiseReverb(),
80 | ['add/signal']),
81 | (@processors.Crop(),
82 | ['reverb/signal'])
83 | ]
84 |
85 | # Reverb
86 | FilteredNoiseReverb:
87 | name = 'reverb'
88 | reverb_length = 72000
89 | n_frames = 500
90 | n_filter_banks = 32
91 | initial_bias = -4.0
92 | trainable = True
93 |
94 | # Harmonic Synthesizer
95 | Harmonic:
96 | name = 'harmonic'
97 | n_samples = %n_samples
98 | sample_rate = %sample_rate
99 | normalize_below_nyquist = True
100 | scale_fn = @core.exp_sigmoid
101 | amp_resample_method = 'linear'
102 | use_angular_cumsum = True # Necessary at 48k as oscillators precess more.
103 |
104 | # Filtered Noise Synthesizer
105 | FilteredNoise:
106 | name = 'filtered_noise'
107 | n_samples = %n_samples
108 | window_size = 0
109 | scale_fn = @core.exp_sigmoid
110 |
111 | # Add
112 | processors.Add.name = 'add'
113 |
114 | # Remove the extra frame of synthesis from centering,
115 | # since generation is forward.
116 | # Frame size is the frame of the "synthesis", which is just the hop size.
117 | Crop:
118 | frame_size = 960
119 | crop_location = 'back'
120 |
121 |
--------------------------------------------------------------------------------
/ddsp/training/gin/optimization/README.md:
--------------------------------------------------------------------------------
1 | # Gin Configs: Optimization
2 |
3 | This directory contains gin configs for training hyperparameters such as `batch_size`
4 | and `learning_rate`. Default values are loaded via `ddsp_run.py` and can be replaced with the `--gin_param` flag.
5 |
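6 | For example, `base.gin` exposes `learning_rate` and `batch_size` as globals, so both can be overridden from the command line (a sketch; the gin files are placeholders):
7 |
8 | ```bash
9 | ddsp_run \
10 |   --mode=train \
11 |   --gin_file=models/solo_instrument.gin \
12 |   --gin_file=datasets/tfrecord.gin \
13 |   --gin_param="batch_size=16" \
14 |   --gin_param="learning_rate=1e-4" \
15 |   --alsologtostderr
16 | ```
17 |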
--------------------------------------------------------------------------------
/ddsp/training/gin/optimization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/optimization/base.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | import ddsp
3 | import ddsp.training
4 |
5 | # Globals for easier configuration with --gin_param
6 | learning_rate = 3e-4
7 | batch_size = 32
8 |
9 | train.batch_size = %batch_size
10 | train.num_steps = 1000000
11 | train.steps_per_summary = 300
12 | train.steps_per_save = 300
13 |
14 | Trainer.learning_rate = %learning_rate
15 | Trainer.lr_decay_steps = 10000
16 | Trainer.lr_decay_rate = 0.98
17 | Trainer.grad_clip_norm = 3.0
18 | Trainer.checkpoints_to_keep = 100
19 |
--------------------------------------------------------------------------------
/ddsp/training/gin/optimization/base_tpu.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'optimization/base.gin'
3 |
4 | # Larger batch size for TPU.
5 | batch_size = 64 # (4x2, 4 per core)
6 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/README.md:
--------------------------------------------------------------------------------
1 | # Gin Configs: Papers
2 |
3 | This directory contains gin configs for replicating the experiments in published
4 | papers using ddsp.
5 |
6 | ## List of papers
7 |
8 | * [iclr2020](./iclr2020/): Experiments from the original DDSP ICLR 2020 paper ([paper](https://openreview.net/forum?id=B1x1ma4tDr), [blog](https://magenta.tensorflow.org/ddsp)).
9 |
10 | ```latex
11 | @inproceedings{
12 | engel2020ddsp,
13 | title={DDSP: Differentiable Digital Signal Processing},
14 | author={Jesse Engel and Lamtharn (Hanoi) Hantrakul and Chenjie Gu and Adam Roberts},
15 | booktitle={International Conference on Learning Representations},
16 | year={2020},
17 | url={https://openreview.net/forum?id=B1x1ma4tDr}
18 | }
19 | ```
20 |
21 | * [icml2020](./icml2020/): Experiments from the ICML 2020 SAS Workshop paper on Inverse Audio Synthesis ([paper](https://goo.gl/magenta/ddsp-inv)).
22 |
23 | ```latex
24 | @inproceedings{
25 | engel2020pitch,
26 | title={Self-supervised Pitch Detection with Inverse Audio Synthesis},
27 | author={Jesse Engel and Rigel Swavely and Lamtharn (Hanoi) Hantrakul and Adam Roberts and Curtis Hawthorne},
28 | booktitle={International Conference on Machine Learning, Self-supervised Audio and Speech Workshop},
29 | year={2020},
30 | url={https://openreview.net/forum?id=RlVTYWhsky7}
31 | }
32 | ```
33 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/iclr2020/README.md:
--------------------------------------------------------------------------------
1 | # DDSP: Differentiable Digital Signal Processing
2 | _ICLR 2020 ([paper](https://g.co/magenta/ddsp))_
3 |
4 | Gin configs for reproducing the results of the paper.
5 |
6 | Each config file specifies both the dataset and the model, so only a single `--gin_file` flag is required.
7 |
8 | ## Configs
9 |
10 | * `nsynth_ae.gin`: Autoencoder (with latent z, and f0 given by CREPE) on NSynth dataset.
11 | * `nsynth_ae_abs.gin`: Deprecated. Please install v0.0.6 to use this config. Autoencoder (with latent z, and f0 inferred by model) on NSynth dataset. For an improved version, see the [icml2020](../icml2020/) paper.
12 | * `solo_instrument.gin`: Decoder (with no z, and f0 given by CREPE) on your own dataset of a monophonic instrument. Make dataset with `ddsp/training/data_preparation/ddsp_prepare_tfrecord.py`.
13 | * `tiny_instrument.gin`: Same as `solo_instrument.gin`, but with a much smaller model.
14 |
15 |
16 |
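17 | For example, training the solo instrument model requires just the one paper config plus a pointer to your TFRecords (a sketch; paths are placeholders):
18 |
19 | ```bash
20 | ddsp_run \
21 |   --mode=train \
22 |   --save_dir=/tmp/$USER-solo-0 \
23 |   --gin_file=papers/iclr2020/solo_instrument.gin \
24 |   --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'" \
25 |   --alsologtostderr
26 | ```
27 |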
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/iclr2020/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/iclr2020/nsynth_ae.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'models/ae.gin'
3 | include 'datasets/nsynth.gin'
4 | include 'eval/basic_f0_ld.gin'
5 |
6 | # To recreate original experiment optimization params, uncomment lines below.
7 | # learning_rate = 1e-5
8 | # batch_size = 128
9 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/iclr2020/solo_instrument.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # Make dataset with ddsp/training/data_preparation/ddsp_prepare_tfrecord.py
3 | # --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'"
4 |
5 | include 'models/solo_instrument.gin'
6 | include 'datasets/tfrecord.gin'
7 | include 'eval/basic_f0_ld.gin'
8 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/iclr2020/tiny_instrument.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | # The 'Tiny Model' config shown in
3 | # https://storage.googleapis.com/ddsp/index.html#tiny
4 | include 'papers/iclr2020/solo_instrument.gin'
5 |
6 | RnnFcDecoder.layers_per_stack = 0
7 | RnnFcDecoder.rnn_channels = 256
8 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/icml2020/README.md:
--------------------------------------------------------------------------------
1 | # Self-supervised Pitch Detection by Inverse Audio Synthesis
2 | _ICML SAS Workshop 2020_ ([paper](https://openreview.net/forum?id=RlVTYWhsky7))
3 |
4 | Instructions for reproducing the results of the paper.
5 |
6 | ## Generate synthetic dataset
7 |
8 | After pip installing ddsp with the data preparation extras (`pip install -U ddsp[data_preparation]`), you can create a synthetic dataset for the transcribing autoencoder using the installed script.
9 | The command below creates a very small dataset as a test.
10 |
11 | ```bash
12 | ddsp_generate_synthetic_dataset \
13 | --output_tfrecord_path=/tmp/synthetic_data.tfrecord \
14 | --gin_param="generate_examples.generate_fn=@data_preparation.synthetic_data.generate_notes_v2" \
15 | --gin_param="generate_notes_v2.n_timesteps=125" \
16 | --gin_param="generate_notes_v2.n_harmonics=100" \
17 | --gin_param="generate_notes_v2.n_mags=65" \
18 | --num_examples=1000 \
19 | --num_shards=2 \
20 | --alsologtostderr
21 | ```
22 |
23 | The synthetic dataset in the paper has 10,000,000 examples split into 1000 shards.
24 | Because that's a large dataset, we have provided sharded TFRecord files for the synthetic dataset on GCP at `gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*`.
25 | As the name indicates, the dataset was made with 125 timesteps, 100 harmonics, and 65 noise bins.
26 | If training on GCP, it is fast to read directly from these buckets, but if training locally you will probably want to download the files (~1.7 TB) using the `gsutil` command line utility from the [gcloud sdk](https://cloud.google.com/sdk/docs/downloads-interactive).
27 |
28 |
29 | If you wish to make your own large dataset, it may require running in a distributed setup such as [Dataflow](https://cloud.google.com/dataflow) on GCP.
30 | If running on Dataflow, you'll need to set the `--pipeline_options` flag using the execution parameters described at https://cloud.google.com/dataflow/docs/guides/specifying-exec-params.
31 | E.g., `--pipeline_options=--runner=DataflowRunner,--project=,--temp_location=gs://,--region=us-central1,--job_name=`...
32 |
33 | ## Pretrain transcribing autoencoder
34 |
35 | ### Train
36 | Point the dataset file_pattern to the synthetic dataset created above.
37 |
38 | ```bash
39 | ddsp_run \
40 | --mode=train \
41 | --save_dir=/tmp/$USER-tae-pretrain-0 \
42 | --gin_file=papers/icml2020/pretrain_model.gin \
43 | --gin_file=papers/icml2020/pretrain_dataset.gin \
44 | --gin_param="SyntheticNotes.file_pattern='gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*'" \
45 | --gin_param="batch_size=64" \
46 | --alsologtostderr
47 | ```
48 |
49 | This command points to datasets on GCP, and the gin_params for file_pattern are redundant with the default values in the gin files, but are provided here to show how you would modify them for local dataset paths.
50 |
51 | In the paper we train for ~1.2M steps with a batch size of 64. A single V100 can fit a max batch size of 32, so you will need to use multiple GPUs to exactly reproduce the experiment. Given the large amount of pretraining, a pretrained checkpoint [is available here](https://storage.googleapis.com/ddsp-inv/ckpts/synthetic_pretrained_ckpt.zip)
52 | or on GCP at `gs://ddsp-inv/ckpts/synthetic_pretrained_ckpt`.
53 |
54 | ### Eval and Sample
55 |
56 | This will add evaluation metrics as scalar summaries for TensorBoard.
57 |
58 | ```bash
59 | ddsp_run \
60 | --mode=eval \
61 | --run_once \
62 | --restore_dir=/path/to/trained_ckpt \
63 | --save_dir=/path/to/trained_ckpt \
64 | --gin_file=papers/icml2020/pretrain_dataset.gin \
65 | --alsologtostderr
66 | ```
67 |
68 | This will add image and audio summaries for TensorBoard.
69 |
70 | ```bash
71 | ddsp_run \
72 | --mode=sample \
73 | --run_once \
74 | --restore_dir=/path/to/trained_ckpt \
75 | --save_dir=/path/to/trained_ckpt \
76 | --gin_file=papers/icml2020/pretrain_dataset.gin \
77 | --alsologtostderr
78 | ```
79 |
80 |
81 | ## Finetuning transcribing autoencoder
82 |
83 | ### Train
84 | Now we finetune the model from above on a specific dataset. Use the `--restore_dir` flag to point to your pretrained checkpoint.
85 |
86 | A model pretrained for 1.2M steps (batch size=64) on synthetic data [is available here](https://storage.googleapis.com/ddsp-inv/ckpts/synthetic_pretrained_ckpt.zip)
87 | or on GCP.
88 |
89 | ```bash
90 | gsutil cp -r gs://ddsp-inv/ckpts/synthetic_pretrained_ckpt /path/to/synthetic_pretrained_ckpt
91 | ```
92 |
93 | ```bash
94 | ddsp_run \
95 | --mode=train \
96 | --restore_dir=/path/to/synthetic_pretrained_ckpt \
97 | --save_dir=/tmp/$USER-tae-finetune-0 \
98 | --gin_file=papers/icml2020/finetune_model.gin \
99 | --gin_file=papers/icml2020/finetune_dataset.gin \
100 | --gin_param="SyntheticNotes.file_pattern='gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*'" \
101 | --gin_param="train_data/TFRecordProvider.file_pattern='gs://ddsp-inv/datasets/all_instruments_train.tfrecord*'" \
102 | --gin_param="test_data/TFRecordProvider.file_pattern='gs://ddsp-inv/datasets/all_instruments_test.tfrecord*'" \
103 | --gin_param="batch_size=64" \
104 | --alsologtostderr
105 | ```
106 |
107 | This command points to datasets on GCP, and the gin_params for file_pattern are redundant with the default values in the gin files, but are provided here to show how you would modify them for local dataset paths.
108 | We have provided sharded TFRecord files for the [URMP dataset](http://www2.ece.rochester.edu/projects/air/projects/URMP/annotations_5P.html) on GCP at `gs://ddsp-inv/datasets/all_instruments_train.tfrecord*` and `gs://ddsp-inv/datasets/all_instruments_test.tfrecord*`.
109 | If training on GCP, it is fast to read directly from these buckets, but if training locally you will probably want to download the files (~16 GB) using the `gsutil` command line utility from the [gcloud sdk](https://cloud.google.com/sdk/docs/downloads-interactive).
110 |
111 |
112 | In the paper, this model was trained with a batch size of 64 on 8 accelerators (8 per accelerator), and typically converges after 200-400k iterations. A single V100 can fit a max batch size of 12, so you will need to use multiple GPUs or TPUs to exactly reproduce the experiment. To use a TPU, start up an instance from the web interface and pass the internal IP address to the TPU flag `--tpu=grpc://`.
113 |
114 |
115 | Finetuned models for +400k steps (batch size=64) are available on GCP for the
116 | [URMP](http://www2.ece.rochester.edu/projects/air/projects/URMP/annotations_5P.html) ([ckpt](https://storage.googleapis.com/ddsp-inv/ckpts/urmp_ckpt.zip)),
117 | [MDB-stem-synth](https://zenodo.org/record/1481172#.Xzouy5NKhTY) ([ckpt](https://storage.googleapis.com/ddsp-inv/ckpts/mdb_stem_synth_ckpt.zip)),
118 | and [MIR1k](https://sites.google.com/site/unvoicedsoundseparation/mir-1k) ([ckpt](https://storage.googleapis.com/ddsp-inv/ckpts/mir1k_ckpt.zip)) datasets.
119 |
120 | ### Eval and Sample
121 |
122 | This will add evaluation metrics as scalar summaries for TensorBoard.
123 |
124 | ```bash
125 | ddsp_run \
126 | --mode=eval \
127 | --run_once \
128 | --restore_dir=/path/to/trained_ckpt \
129 | --save_dir=/path/to/trained_ckpt \
130 | --gin_file=papers/icml2020/finetune_dataset.gin \
131 | --alsologtostderr
132 | ```
133 |
134 | This will add image and audio summaries for TensorBoard.
135 |
136 | ```bash
137 | ddsp_run \
138 | --mode=sample \
139 | --run_once \
140 | --restore_dir=/path/to/trained_ckpt \
141 | --save_dir=/path/to/trained_ckpt \
142 | --gin_file=papers/icml2020/finetune_dataset.gin \
143 | --alsologtostderr
144 | ```
145 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/icml2020/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/icml2020/finetune_dataset.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'papers/icml2020/pretrain_dataset.gin'
3 |
4 | # Datasets
5 | train.data_provider = @train_data/data.ZippedProvider()
6 | evaluate.data_provider = @test_data/data.TFRecordProvider()
7 | sample.data_provider = @test_data/data.TFRecordProvider()
8 |
9 |
10 |
11 | # Training
12 | train_data/ZippedProvider.data_providers = [
13 | @data.SyntheticNotes(),
14 | @train_data/data.TFRecordProvider(),
15 | ]
16 |
17 |
18 | # Make dataset with ddsp/training/data_preparation/ddsp_prepare_tfrecord.py
19 | # --gin_param="TFRecordProvider.file_pattern='/path/to/dataset*.tfrecord'"
20 |
21 | train_data/TFRecordProvider.file_pattern = 'gs://ddsp-inv/datasets/all_instruments_train.tfrecord*'
22 | test_data/TFRecordProvider.file_pattern = 'gs://ddsp-inv/datasets/all_instruments_test.tfrecord*'
23 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/icml2020/finetune_model.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | include 'papers/icml2020/pretrain_model.gin'
3 |
4 | SpectralLoss.mag_weight = 1.0
5 | SpectralLoss.logmag_weight = 1.0
6 |
7 | KDEConsistencyLoss.weight_mean_amp = 0.1
8 | KDEConsistencyLoss.weight_a = 0.1
9 | KDEConsistencyLoss.weight_b = 0.1
10 | KDEConsistencyLoss.scale_a = 0.1
11 | KDEConsistencyLoss.scale_b = 0.1
12 |
13 | FilteredNoiseConsistencyLoss.weight = 100.0
14 |
15 | HarmonicConsistencyLoss.amp_weight = 10.0
16 | HarmonicConsistencyLoss.dist_weight = 100.0
17 | HarmonicConsistencyLoss.f0_weight = 1.0
18 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/icml2020/pretrain_dataset.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 | import ddsp.training
3 |
4 |
5 | # Evaluate
6 | evaluate.batch_size = 8
7 | evaluate.num_batches = 256
8 |
9 |
10 | # Sample
11 | sample.batch_size = 16
12 | sample.num_batches = 1
13 | sample.ckpt_delay_secs = 300 # 5 minutes
14 |
15 |
16 | # Dataset
17 | train.data_provider = @data.SyntheticNotes()
18 | evaluate.data_provider = @data.SyntheticNotes()
19 | sample.data_provider = @data.SyntheticNotes()
20 |
21 |
22 | ## Create a synthetic dataset with ddsp/training/data_preparation/ddsp_generate_synthetic_dataset.py
23 | # Synthetic data generator.
24 | data.SyntheticNotes.file_pattern = 'gs://ddsp-inv/datasets/notes_t125_h100_m65_v2.tfrecord*'
25 | data.SyntheticNotes.n_timesteps = 125
26 | data.SyntheticNotes.n_harmonics = 100
27 | data.SyntheticNotes.n_mags = 65
28 |
--------------------------------------------------------------------------------
/ddsp/training/gin/papers/icml2020/pretrain_model.gin:
--------------------------------------------------------------------------------
1 | # -*-Python-*-
2 |
3 | import ddsp
4 | import ddsp.training
5 |
6 | include 'eval/basic_f0_ld_twm.gin'
7 |
8 | # =====
9 | # Model
10 | # =====
11 | get_model.model = @models.InverseSynthesis()
12 |
13 | InverseSynthesis.reverb = False
14 |
15 | # Sinusoidal Encoder
16 | InverseSynthesis.sinusoidal_encoder = @encoders.ResnetSinusoidalEncoder()
17 | ResnetSinusoidalEncoder.size = 'small'
18 | ResnetSinusoidalEncoder.output_splits = (('frequencies', 6400),
19 | ('amplitudes', 100),
20 | ('noise_magnitudes', 65))
21 |
22 |
23 | # FFT input parameters from onsets and frames transcription experiments.
24 | ResnetSinusoidalEncoder.spectral_fn = @f0_spectral/spectral_ops.compute_logmel
25 | f0_spectral/compute_logmel.lo_hz = 0.0
26 | f0_spectral/compute_logmel.hi_hz = 8000.0
27 | f0_spectral/compute_logmel.bins = 229
28 | f0_spectral/compute_logmel.fft_size = 2048
29 | f0_spectral/compute_logmel.overlap = 0.75
30 | f0_spectral/compute_logmel.pad_end = True
31 |
32 | # Harmonic Encoder
33 | InverseSynthesis.harmonic_encoder = @encoders.SinusoidalToHarmonicEncoder()
34 | SinusoidalToHarmonicEncoder.net = @nn.RnnSandwich()
35 |
36 |
37 | # Audio Losses
38 | InverseSynthesis.losses = [
39 | @losses.SpectralLoss(),
40 | ]
41 | SpectralLoss.loss_type = 'L1'
42 | SpectralLoss.mag_weight = 0.0
43 | SpectralLoss.logmag_weight = 0.0
44 |
45 |
46 | # Sinusoidal Consistency loss:
47 | InverseSynthesis.sinusoidal_consistency_losses = @losses.KDEConsistencyLoss()
48 | KDEConsistencyLoss.weight_a = 1.0
49 | KDEConsistencyLoss.weight_b = 1.0
50 | KDEConsistencyLoss.scale_a = 0.1
51 | KDEConsistencyLoss.scale_b = 0.1
52 |
53 |
54 | # Harmonic Consistency Loss.
55 | InverseSynthesis.harmonic_consistency_losses = @losses.HarmonicConsistencyLoss()
56 | HarmonicConsistencyLoss.amp_weight = 1.0
57 | HarmonicConsistencyLoss.dist_weight = 1.0
58 | HarmonicConsistencyLoss.f0_weight = 1.0
59 |
60 |
61 | # Filtered Noise Consistency loss:
62 | InverseSynthesis.filtered_noise_consistency_loss = @losses.FilteredNoiseConsistencyLoss()
63 | FilteredNoiseConsistencyLoss.weight = 1.0
64 |
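65 | # Note: finetune_model.gin includes this file and overrides the spectral,
66 | # KDE, filtered-noise, and harmonic consistency loss weights for fine-tuning.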
--------------------------------------------------------------------------------
/ddsp/training/heuristics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.training.heuristics."""
16 |
17 | from absl.testing import parameterized
18 | from ddsp.training import heuristics
19 | import numpy as np
20 | import tensorflow.compat.v2 as tf
21 |
22 |
23 | class HeuristicsTest(parameterized.TestCase):
24 |
25 | def test_pad_for_frame(self):
26 | signal = np.zeros((10,))
27 | signal[0] = -1
28 | signal[-1] = 1
29 | frame_width = 4
30 | front_padded_signal = heuristics.pad_for_frame(
31 | signal, mode='front', frame_width=frame_width)
32 | framed = tf.signal.frame(front_padded_signal, frame_width, 1)
33 | self.assertEqual(framed.shape[0], signal.shape[0])
34 | np.testing.assert_array_equal(framed[0], [-1, -1, -1, -1])
35 | np.testing.assert_array_equal(framed[-1], [0, 0, 0, 1])
36 |
37 | back_padded_signal = heuristics.pad_for_frame(
38 | signal, mode='end', frame_width=frame_width)
39 | framed = tf.signal.frame(back_padded_signal, frame_width, 1)
40 | self.assertEqual(framed.shape[0], signal.shape[0])
41 | np.testing.assert_array_equal(framed[0], [-1, 0, 0, 0])
42 | np.testing.assert_array_equal(framed[-1], [1, 1, 1, 1])
43 |
44 | center_padded_signal = heuristics.pad_for_frame(
45 | signal, mode='center', frame_width=frame_width)
46 | framed = tf.signal.frame(center_padded_signal, frame_width, 1)
47 | self.assertEqual(framed.shape[0], signal.shape[0])
48 | np.testing.assert_array_equal(framed[0], [-1, -1, -1, 0])
49 | np.testing.assert_array_equal(framed[-1], [0, 0, 1, 1])
50 |
51 | def test_pad_for_frame_odd(self):
52 | signal = np.zeros((10,))
53 | signal[0] = -1
54 | signal[-1] = 1
55 | frame_width = 3
56 |
57 | center_padded_signal = heuristics.pad_for_frame(
58 | signal, mode='center', frame_width=frame_width)
59 | framed = tf.signal.frame(center_padded_signal, frame_width, 1)
60 | self.assertEqual(framed.shape[0], signal.shape[0])
61 | np.testing.assert_array_equal(framed[0], [-1, -1, 0])
62 | np.testing.assert_array_equal(framed[-1], [0, 1, 1])
63 |
64 | @parameterized.parameters(
65 | (10.0, 2.0, 0.5, (11, 32000)),
66 | (10.0, 10.0, 0.5, (3, 160000)),
67 | (10.0, 1.0, 0.5, (21, 16000)),
68 | (10.0, 1.0, 0.75, (14, 16000)),
69 | )
70 | def test_window_array(self, dur, win_len, frame_step_ratio, expected_shape):
71 | sr = 16000
72 | sw = np.sin(np.linspace(start=0.0, stop=dur, num=int(sr*dur)))
73 |
74 | wa = heuristics.window_array(sw, sr, win_len, frame_step_ratio)
75 | self.assertSequenceEqual(wa.shape, expected_shape)
76 | self.assertEqual(wa.shape[-1], int(sr*win_len))
77 |
78 |
79 | if __name__ == '__main__':
80 | tf.test.main()
81 |
--------------------------------------------------------------------------------
/ddsp/training/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module with all the global configurable models for training."""
16 |
17 | from ddsp.training.models.autoencoder import Autoencoder
18 | from ddsp.training.models.inverse_synthesis import InverseSynthesis
19 | from ddsp.training.models.midi_autoencoder import MidiAutoencoder
20 | from ddsp.training.models.midi_autoencoder import ZMidiAutoencoder
21 | from ddsp.training.models.model import Model
22 | import gin
23 |
24 | _configurable = lambda cls: gin.configurable(cls, module=__name__)
25 |
26 | Autoencoder = _configurable(Autoencoder)
27 | InverseSynthesis = _configurable(InverseSynthesis)
28 | MidiAutoencoder = _configurable(MidiAutoencoder)
29 | ZMidiAutoencoder = _configurable(ZMidiAutoencoder)
30 |
31 |
32 |
33 | @gin.configurable
34 | def get_model(model=gin.REQUIRED):
35 | """Gin configurable function to get a 'global' model for use in ddsp_run.py.
36 |
37 | Convenience for using the same model in train(), evaluate(), and sample().
38 | Args:
39 | model: An instantiated model, such as 'models.Autoencoder()'.
40 |
41 | Returns:
42 | The 'global' model specified in the gin config.
43 | """
44 | return model
45 |
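46 | # A minimal sketch (illustrative) of how a gin config binds this function,
47 | # mirroring e.g. papers/icml2020/pretrain_model.gin:
48 | #
49 | #   get_model.model = @models.Autoencoder()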
--------------------------------------------------------------------------------
/ddsp/training/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Model that encodes audio features and decodes with a ddsp processor group."""
16 |
17 | import ddsp
18 | from ddsp.training.models.model import Model
19 |
20 |
21 | class Autoencoder(Model):
22 | """Wrap the model function for dependency injection with gin."""
23 |
24 | def __init__(self,
25 | preprocessor=None,
26 | encoder=None,
27 | decoder=None,
28 | processor_group=None,
29 | losses=None,
30 | **kwargs):
31 | super().__init__(**kwargs)
32 | self.preprocessor = preprocessor
33 | self.encoder = encoder
34 | self.decoder = decoder
35 | self.processor_group = processor_group
36 | self.loss_objs = ddsp.core.make_iterable(losses)
37 |
38 | def encode(self, features, training=True):
39 | """Get conditioning by preprocessing then encoding."""
40 | if self.preprocessor is not None:
41 | features.update(self.preprocessor(features, training=training))
42 | if self.encoder is not None:
43 | features.update(self.encoder(features))
44 | return features
45 |
46 | def decode(self, features, training=True):
47 | """Get generated audio by decoding then processing."""
48 | features.update(self.decoder(features, training=training))
49 | return self.processor_group(features)
50 |
51 | def get_audio_from_outputs(self, outputs):
52 | """Extract audio output tensor from outputs dict of call()."""
53 | return outputs['audio_synth']
54 |
55 | def call(self, features, training=True):
56 | """Run the core of the network, get predictions and loss."""
57 | features = self.encode(features, training=training)
58 | features.update(self.decoder(features, training=training))
59 |
60 | # Run through processor group.
61 | pg_out = self.processor_group(features, return_outputs_dict=True)
62 |
63 | # Parse outputs
64 | outputs = pg_out['controls']
65 | outputs['audio_synth'] = pg_out['signal']
66 |
67 | if training:
68 | self._update_losses_dict(
69 | self.loss_objs, features['audio'], outputs['audio_synth'])
70 |
71 | return outputs
72 |
73 |
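74 | # A minimal usage sketch (illustrative; assumes a parsed gin config binding
75 | # the preprocessor, decoder, processor_group, and losses, and `batch` being a
76 | # feature dict from a data provider):
77 | #
78 | #   model = Autoencoder()
79 | #   outputs, losses = model(batch, return_losses=True, training=True)
80 | #   audio_synth = model.get_audio_from_outputs(outputs)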
--------------------------------------------------------------------------------
/ddsp/training/models/autoencoder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.training.models.autoencoder."""
16 |
17 | from absl.testing import parameterized
18 | from ddsp.core import tf_float32
19 | from ddsp.training import models
20 | import gin
21 | import numpy as np
22 | import pkg_resources
23 | import tensorflow as tf
24 |
25 | GIN_PATH = pkg_resources.resource_filename(__name__, '../gin')
26 | gin.add_config_file_search_path(GIN_PATH)
27 |
28 |
29 | class AutoencoderTest(parameterized.TestCase, tf.test.TestCase):
30 |
31 | @parameterized.named_parameters(
32 | ('nsynth_ae', 16000, 250, False, 'papers/iclr2020/nsynth_ae.gin'),
33 | ('solo_instrument', 16000, 250, False,
34 | 'papers/iclr2020/solo_instrument.gin'),
35 | ('vst_16kHz', 16000, 50, True, 'models/vst/vst.gin'),
36 | ('vst_32kHz', 32000, 50, True, 'models/vst/vst_32k.gin'),
37 | ('vst_48kHz', 48000, 50, True, 'models/vst/vst_48k.gin'),
38 | )
39 | def test_build_model(self, sample_rate, frame_rate, centered, gin_file):
40 | """Tests if Model builds properly and produces audio of correct shape.
41 |
42 | Args:
43 | sample_rate: Sample rate of audio.
44 | frame_rate: Frame rate of features.
45 | centered: Add an additional frame.
46 | gin_file: Name of gin_file to use.
47 | """
48 | n_batch = 1
49 | n_secs = 4
50 | # frame_size = sample_rate // frame_rate
51 | n_frames = int(frame_rate * n_secs)
52 | n_samples = int(sample_rate * n_secs)
53 | n_samples_16k = int(16000 * n_secs)
54 | if centered:
55 | n_frames += 1
56 | # n_samples += frame_size
57 | # n_samples_16k += 320
58 |
59 | inputs = {
60 | 'loudness_db': np.zeros([n_batch, n_frames]),
61 | 'f0_hz': np.zeros([n_batch, n_frames]),
62 | 'f0_confidence': np.zeros([n_batch, n_frames]),
63 | 'audio': np.random.randn(n_batch, n_samples),
64 | 'audio_16k': np.random.randn(n_batch, n_samples_16k),
65 | }
66 | inputs = {k: tf_float32(v) for k, v in inputs.items()}
67 |
68 | with gin.unlock_config():
69 | gin.clear_config()
70 | gin.parse_config_file(gin_file)
71 |
72 | model = models.Autoencoder()
73 | outputs = model(inputs)
74 | self.assertIsInstance(outputs, dict)
75 | # Confirm that model generates correctly sized audio.
76 | audio_gen = model.get_audio_from_outputs(outputs)
77 | audio_gen_shape = audio_gen.shape.as_list()
78 | self.assertEqual(audio_gen_shape, list(inputs['audio'].shape))
79 |
80 |
81 | if __name__ == '__main__':
82 | tf.test.main()
83 |
--------------------------------------------------------------------------------
/ddsp/training/models/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Model base class."""
16 |
17 | import time
18 |
19 | from absl import logging
20 | import ddsp
21 | from ddsp.core import copy_if_tf_function
22 | from ddsp.training import train_util
23 | import tensorflow as tf
24 |
25 |
26 | class Model(tf.keras.Model):
27 | """Base class for all models."""
28 |
29 | def __init__(self, **kwargs):
30 | super().__init__(**kwargs)
31 | self._losses_dict = {}
32 |
33 | def __call__(self, *args, return_losses=False, **kwargs):
34 | """Reset the losses dict on each call.
35 |
36 | Args:
37 | *args: Arguments passed on to call().
38 | return_losses: Return a dictionary of losses in addition to the call()
39 | function returns.
40 | **kwargs: Keyword arguments passed on to call().
41 |
42 | Returns:
43 | outputs: A dictionary of model outputs generated in call().
44 | {output_name: output_tensor or dict}.
45 | losses: If return_losses=True, also returns a dictionary of losses,
46 | {loss_name: loss_value}.
47 | """
48 | # Copy mutable dicts if in graph mode to prevent side-effects (pure func).
49 | args = [copy_if_tf_function(a) if isinstance(a, dict) else a for a in args]
50 |
51 | # Run model.
52 | self._losses_dict = {}
53 | outputs = super().__call__(*args, **kwargs)
54 |
55 | # Get total loss.
56 | if not return_losses:
57 | return outputs
58 | else:
59 | self._losses_dict['total_loss'] = self.sum_losses(self._losses_dict)
60 | return outputs, self._losses_dict
61 |
62 | def sum_losses(self, losses_dict):
63 | """Sum all the scalar losses in a dictionary."""
64 | return tf.reduce_sum(list(losses_dict.values()))
65 |
66 | def _update_losses_dict(self, loss_objs, *args, **kwargs):
67 | """Helper function to run loss objects on args and add to model losses."""
68 | for loss_obj in ddsp.core.make_iterable(loss_objs):
69 | if hasattr(loss_obj, 'get_losses_dict'):
70 | losses_dict = loss_obj.get_losses_dict(*args, **kwargs)
71 | self._losses_dict.update(losses_dict)
72 |
73 | def restore(self, checkpoint_path, verbose=True, restore_keys=None):
74 | """Restore model and optimizer from a checkpoint.
75 |
76 | Args:
77 | checkpoint_path: Path to checkpoint file or directory.
78 | verbose: Warn about missing variables.
79 | restore_keys: Optional list of strings for submodules to restore.
80 |
81 | Raises:
82 | FileNotFoundError: If no checkpoint is found.
83 | """
84 | start_time = time.time()
85 | latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path)
86 |
87 | if restore_keys is None:
88 | # If no keys are passed, restore the whole model.
89 | checkpoint = tf.train.Checkpoint(model=self)
90 | logging.info('Model restoring all components.')
91 | if verbose:
92 | checkpoint.restore(latest_checkpoint)
93 | else:
94 | checkpoint.restore(latest_checkpoint).expect_partial()
95 |
96 | else:
97 | # Restore only sub-modules by building a new subgraph.
98 | # Following https://www.tensorflow.org/guide/checkpoint#loading_mechanics.
99 | logging.info('Trainer restoring model subcomponents:')
100 | for k in restore_keys:
101 | to_restore = {k: getattr(self, k)}
102 | log_str = 'Restoring {}'.format(to_restore)
103 | logging.info(log_str)
104 | fake_model = tf.train.Checkpoint(**to_restore)
105 | new_root = tf.train.Checkpoint(model=fake_model)
106 | status = new_root.restore(latest_checkpoint)
107 | status.assert_existing_objects_matched()
108 |
109 | logging.info('Loaded checkpoint %s', latest_checkpoint)
110 | logging.info('Loading model took %.1f seconds', time.time() - start_time)
111 |
112 | def get_audio_from_outputs(self, outputs):
113 | """Extract audio output tensor from outputs dict of call()."""
114 | raise NotImplementedError('Must implement `self.get_audio_from_outputs()`.')
115 |
116 | def call(self, *args, training=False, **kwargs):
117 | """Run the forward pass, add losses, and create a dictionary of outputs.
118 |
119 | This function must run the forward pass, add losses to self._losses_dict and
120 | return a dictionary of all the relevant output tensors.
121 |
122 | Args:
123 | *args: Args for forward pass.
124 | training: Required `training` kwarg passed in by keras.
125 | **kwargs: kwargs for forward pass.
126 |
127 | Returns:
128 | Dictionary of all relevant tensors.
129 | """
130 | raise NotImplementedError('Must implement a `self.call()` method.')
131 |
132 |
133 |
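134 | # A minimal subclass sketch (illustrative), mirroring the pattern used by
135 | # Autoencoder in autoencoder.py:
136 | #
137 | #   class Passthrough(Model):
138 | #     """Identity model that returns its input audio."""
139 | #
140 | #     def call(self, features, training=False):
141 | #       return {'audio_synth': features['audio']}
142 | #
143 | #     def get_audio_from_outputs(self, outputs):
144 | #       return outputs['audio_synth']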
--------------------------------------------------------------------------------
/ddsp/training/plotting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Plotting utilities for the DDSP library. Useful in colab and elsewhere."""
16 |
17 | from ddsp import core
18 | from ddsp import spectral_ops
19 | from matplotlib import gridspec
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 | import tensorflow.compat.v2 as tf
23 |
24 | DEFAULT_SAMPLE_RATE = spectral_ops.CREPE_SAMPLE_RATE
25 |
26 |
27 | def specplot(audio,
28 | vmin=-5,
29 | vmax=1,
30 | rotate=True,
31 | size=512 + 256,
32 | **matshow_kwargs):
33 | """Plot the log magnitude spectrogram of audio."""
34 | # If batched, take first element.
35 | if len(audio.shape) == 2:
36 | audio = audio[0]
37 |
38 | logmag = spectral_ops.compute_logmag(core.tf_float32(audio), size=size)
39 | if rotate:
40 | logmag = np.rot90(logmag)
41 | # Plotting.
42 | plt.matshow(logmag,
43 | vmin=vmin,
44 | vmax=vmax,
45 | cmap=plt.cm.magma,
46 | aspect='auto',
47 | **matshow_kwargs)
48 | plt.xticks([])
49 | plt.yticks([])
50 | plt.xlabel('Time')
51 | plt.ylabel('Frequency')
52 |
53 |
54 | def transfer_function(ir, sample_rate=DEFAULT_SAMPLE_RATE):
55 | """Get true transfer function from an impulse_response."""
56 | n_fft = core.get_fft_size(0, ir.shape.as_list()[-1])
57 | frequencies = np.abs(
58 | np.fft.fftfreq(n_fft, 1 / sample_rate)[:int(n_fft / 2) + 1])
59 | magnitudes = tf.abs(tf.signal.rfft(ir, [n_fft]))
60 | return frequencies, magnitudes
61 |
62 |
63 | def plot_impulse_responses(impulse_response,
64 | desired_magnitudes,
65 | sample_rate=DEFAULT_SAMPLE_RATE):
66 | """Plot a target frequency response, and that of an impulse response."""
67 | n_fft = desired_magnitudes.shape[-1] * 2
68 | frequencies = np.fft.fftfreq(n_fft, 1 / sample_rate)[:n_fft // 2]
69 | true_frequencies, true_magnitudes = transfer_function(impulse_response)
70 |
71 | # Plot it.
72 | plt.figure(figsize=(12, 6))
73 | plt.subplot(121)
74 | # Desired transfer function.
75 | plt.semilogy(frequencies, desired_magnitudes, label='Desired')
76 | # True transfer function.
77 | plt.semilogy(true_frequencies, true_magnitudes[0, 0, :], label='True')
78 | plt.title('Transfer Function')
79 | plt.legend()
80 |
81 | plt.subplot(122)
82 | plt.plot(impulse_response[0, 0, :])
83 | plt.title('Impulse Response')
84 |
85 |
86 | def pianoroll_plot_setup(figsize=None, side_piano_ratio=0.025,
87 | faint_pr=True, xlim=None):
88 | """Makes a tiny piano left of the y-axis and a faint piano on the main figure.
89 |
90 | This function sets up the figure for pretty plotting a piano roll. It makes a
91 | small imshow plot to the left of the main plot that looks like a piano. This
92 | piano side plot is aligned along the y-axis of the main plot, such that y
93 | values align with MIDI values (y=0 is the lowest C, i.e. C-1; y=12 is C0, etc.).
94 | Additionally, a main figure is set up that shares the y-axis of the piano side
95 | plot. Optionally, a set of faint horizontal lines are drawn on the main figure
96 | that correspond to the black keys on the piano (and a line separating B & C
97 | and E & F). This function returns the formatted figure, the side piano axis,
98 | and the main axis for plotting your data.
99 |
100 | By default, this will draw 11 octaves of piano keys along the y-axis; you will
101 | probably want to reduce what is visible using `ax.set_ylim()` on either returned
102 | axis.
103 |
104 | Using with imshow piano roll data:
105 | A common use case is for using imshow() on the main axis to display a piano
106 | roll alongside the piano side plot AND the faint piano roll behind your
107 | data. In this case, if your data is a 2D array you have to use a masked
108 | numpy array to make certain values invisible on the plot, and therefore make
109 | the faint piano roll visible. Here's an example:
110 |
111 | midi = np.flipud([
112 | [0.0, 0.0, 1.0],
113 | [0.0, 1.0, 0.0],
114 | [1.0, 0.0, 0.0],
115 | ])
116 |
117 | midi_masked = np.ma.masked_values(midi, 0.0) # Mask out all 0.0's
118 | fig, ax, sp = plotting.pianoroll_plot_setup()
119 | ax.imshow(midi_masked, origin='lower', aspect='auto') # main subplot axis
120 | sp.set_ylabel('My favorite MIDI data') # side piano axis
121 | fig.show()
122 |
123 | The other option is to use imshow in RGBA mode, where your data is split
124 | into 4 channels. Every alpha value that is 0.0 will be transparent and show
125 | the faint piano roll below your data.
126 |
127 | Args:
128 | figsize: Size of the matplotlib figure. Will be passed to `plt.figure()`.
129 | Defaults to None.
130 | side_piano_ratio: Width of the y-axis piano as a ratio of the whole
131 | figure. Defaults to 1/40th.
132 | faint_pr: Whether to draw faint black & white keys across the main plot.
133 | Defaults to True.
134 | xlim: Tuple containing the min and max of the x values for the main plot.
135 | Only used to determine the x limits for the faint piano roll in the main
136 | plot. Defaults to (0, 1000).
137 |
138 | Returns:
139 | (figure, main_axis, left_piano_axis)
140 | figure: A matplotlib figure object containing both subplots set up with an
141 | aligned piano roll.
142 | main_axis: A matplotlib axis object to be used for plotting. Optionally
143 | has a faint piano roll in the background.
144 | left_piano_axis: A matplotlib axis object that has a small, aligned piano
145 | along the left side y-axis of the main_axis subplot.
146 | """
147 | octaves = 11
148 |
149 | # Setup figure and gridspec.
150 | fig = plt.figure(figsize=figsize)
151 | gs_ratio = int(1 / side_piano_ratio)
152 | gs = gridspec.GridSpec(1, 2, width_ratios=[1, gs_ratio])
153 | left_piano_ax = fig.add_subplot(gs[0])
154 |
155 | # Make a piano on the left side of the y-axis with imshow().
156 | keys = np.array(
157 | [0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0] # notes in descending order; B -> C
158 | )
159 | keys = np.tile(keys, octaves)[:, None]
160 | left_piano_ax.imshow(keys, cmap='binary', aspect='auto',
161 | extent=[0, 0.625, -0.5, octaves*12-0.5])
162 |
163 | # Make the lines between keys.
164 | for i in range(octaves):
165 | left_piano_ax.hlines(i*12 - 0.5, -0.5, 1, colors='black', linewidth=0.5)
166 | left_piano_ax.hlines(i*12 + 1.0, -0.5, 1, colors='black', linewidth=0.5)
167 | left_piano_ax.hlines(i*12 + 3.0, -0.5, 1, colors='black', linewidth=0.5)
168 | left_piano_ax.hlines(i*12 + 4.5, -0.5, 1, colors='black', linewidth=0.5)
169 | left_piano_ax.hlines(i*12 + 6.0, -0.5, 1, colors='black', linewidth=0.5)
170 | left_piano_ax.hlines(i*12 + 8.0, -0.5, 1, colors='black', linewidth=0.5)
171 | left_piano_ax.hlines(i*12 + 10.0, -0.5, 1, colors='black', linewidth=0.5)
172 |
173 | # Set the limits of the side piano and remove ticks so it looks nice.
174 | left_piano_ax.set_xlim(0, 0.995)
175 | left_piano_ax.set_xticks([])
176 |
177 | # Create the aligned axis we'll return to the user.
178 | main_ax = fig.add_subplot(gs[1], sharey=left_piano_ax)
179 |
180 | # Draw a faint piano roll behind the main axes (if the user wants).
181 | if faint_pr:
182 | xlim = (0, 1000) if xlim is None else xlim
183 | x_min, x_max = xlim
184 | x_delta = x_max - x_min
185 | main_ax.imshow(np.tile(keys, x_delta), cmap='binary', aspect='auto',
186 | alpha=0.05, extent=[x_min, x_max, -0.5, octaves*12-0.5])
187 | for i in range(octaves):
188 | main_ax.hlines(i * 12 + 4.5, x_min, x_max, colors='black',
189 | linewidth=0.5, alpha=0.25)
190 | main_ax.hlines(i * 12 - 0.5, x_min, x_max, colors='black',
191 | linewidth=0.5, alpha=0.25)
192 |
193 | main_ax.set_xlim(*xlim)
194 |
195 | # Some final cosmetic tweaks before returning the axis obj's and figure.
196 | plt.setp(main_ax.get_yticklabels(), visible=False)
197 | gs.tight_layout(fig)
198 | return fig, main_ax, left_piano_ax
199 |
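200 | # A quick usage sketch for specplot (illustrative):
201 | #
202 | #   sr = DEFAULT_SAMPLE_RATE  # 16 kHz, CREPE's sample rate.
203 | #   audio = 0.5 * np.sin(2.0 * np.pi * 440.0 * np.arange(sr) / sr)  # 1s of A440.
204 | #   specplot(audio)
205 | #   plt.show()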
--------------------------------------------------------------------------------
/ddsp/training/preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library of preprocess functions."""
16 |
17 | import ddsp
18 | from ddsp.training import nn
19 | import gin
20 | import tensorflow as tf
21 |
22 | F0_RANGE = ddsp.spectral_ops.F0_RANGE
23 | DB_RANGE = ddsp.spectral_ops.DB_RANGE
24 |
25 | tfkl = tf.keras.layers
26 |
27 |
28 | # ---------------------- Preprocess Helpers ------------------------------------
29 | def at_least_3d(x):
30 | """Optionally adds time, batch, then channel dimension."""
31 | x = x[tf.newaxis] if not x.shape else x
32 | x = x[tf.newaxis, :] if len(x.shape) == 1 else x
33 | x = x[:, :, tf.newaxis] if len(x.shape) == 2 else x
34 | return x
35 |
36 |
37 | def scale_db(db):
38 | """Scales [-DB_RANGE, 0] to [0, 1]."""
39 | return (db / DB_RANGE) + 1.0
40 |
41 |
42 | def inv_scale_db(db_scaled):
43 | """Scales [0, 1] to [-DB_RANGE, 0]."""
44 | return (db_scaled - 1.0) * DB_RANGE
45 |
46 |
47 | def scale_f0_hz(f0_hz):
48 | """Scales [0, Nyquist] Hz to [0, 1.0] MIDI-scaled."""
49 | return ddsp.core.hz_to_midi(f0_hz) / F0_RANGE
50 |
51 |
52 | def inv_scale_f0_hz(f0_scaled):
53 | """Scales [0, 1.0] MIDI-scaled to [0, Nyquist] Hz."""
54 | return ddsp.core.midi_to_hz(f0_scaled * F0_RANGE)
55 |
56 |
57 | # ---------------------- Preprocess objects ------------------------------------
58 | @gin.register
59 | class F0LoudnessPreprocessor(nn.DictLayer):
60 | """Resamples and scales 'f0_hz' and 'loudness_db' features."""
61 |
62 | def __init__(self,
63 | time_steps=1000,
64 | frame_rate=250,
65 | sample_rate=16000,
66 | compute_loudness=True,
67 | **kwargs):
68 | super().__init__(**kwargs)
69 | self.time_steps = time_steps
70 | self.frame_rate = frame_rate
71 | self.sample_rate = sample_rate
72 | self.compute_loudness = compute_loudness
73 |
74 | def call(self, loudness_db, f0_hz, audio=None) -> [
75 | 'f0_hz', 'loudness_db', 'f0_scaled', 'ld_scaled']:
76 | # Compute loudness fresh (it's fast).
77 | if self.compute_loudness:
78 | loudness_db = ddsp.spectral_ops.compute_loudness(
79 | audio,
80 | sample_rate=self.sample_rate,
81 | frame_rate=self.frame_rate)
82 |
83 | # Resample features to the frame_rate.
84 | f0_hz = self.resample(f0_hz)
85 | loudness_db = self.resample(loudness_db)
86 | # For NN training, scale frequency and loudness to the range [0, 1].
87 | # f0 is log-scaled (MIDI); loudness goes from [-DB_RANGE, 0] to [0, 1].
88 | f0_scaled = scale_f0_hz(f0_hz)
89 | ld_scaled = scale_db(loudness_db)
90 | return f0_hz, loudness_db, f0_scaled, ld_scaled
91 |
92 | @staticmethod
93 | def invert_scaling(f0_scaled, ld_scaled):
94 | """Takes in scaled f0 and loudness, and puts them back to hz & db scales."""
95 | f0_hz = inv_scale_f0_hz(f0_scaled)
96 | loudness_db = inv_scale_db(ld_scaled)
97 | return f0_hz, loudness_db
98 |
99 | def resample(self, x):
100 | x = at_least_3d(x)
101 | return ddsp.core.resample(x, self.time_steps)
102 |
103 |
104 | @gin.register
105 | class F0PowerPreprocessor(F0LoudnessPreprocessor):
106 | """Dynamically compute additional power_db feature."""
107 |
108 | def __init__(self,
109 | time_steps=1000,
110 | frame_rate=250,
111 | sample_rate=16000,
112 | frame_size=64,
113 | **kwargs):
114 | super().__init__(time_steps, **kwargs)
115 | self.frame_rate = frame_rate
116 | self.sample_rate = sample_rate
117 | self.frame_size = frame_size
118 |
119 | def call(self, f0_hz, power_db=None, audio=None) -> [
120 | 'f0_hz', 'pw_db', 'f0_scaled', 'pw_scaled']:
121 | """Compute power on the fly if it's not in the inputs."""
122 | # For NN training, scale frequency and loudness to the range [0, 1].
123 | f0_hz = self.resample(f0_hz)
124 | f0_scaled = scale_f0_hz(f0_hz)
125 |
126 | if power_db is not None:
127 | # Use dataset values if present.
128 | pw_db = power_db
129 | elif audio is not None:
130 | # Otherwise, compute power on the fly.
131 | pw_db = ddsp.spectral_ops.compute_power(audio,
132 | sample_rate=self.sample_rate,
133 | frame_rate=self.frame_rate,
134 | frame_size=self.frame_size)
135 | else:
136 | raise ValueError('Power preprocessing requires either '
137 | '"power_db" or "audio" keys to be provided '
138 | 'in the dataset.')
139 | # Resample.
140 | pw_db = self.resample(pw_db)
141 | # Scale power.
142 | pw_scaled = scale_db(pw_db)
143 |
144 | return f0_hz, pw_db, f0_scaled, pw_scaled
145 |
146 | @staticmethod
147 | def invert_scaling(f0_scaled, pw_scaled):
148 | """Puts scaled f0, loudness, and power back to hz & db scales."""
149 | f0_hz = inv_scale_f0_hz(f0_scaled)
150 | power_db = inv_scale_db(pw_scaled)
151 | return f0_hz, power_db
152 |
153 |
154 | @gin.register
155 | class OnlineF0PowerPreprocessor(nn.DictLayer):
156 | """Dynamically compute power_db and f0_hz with optional centered frames."""
157 |
158 | def __init__(self,
159 | frame_rate=250,
160 | frame_size=1024,
161 | padding='center',
162 | compute_power=True,
163 | compute_f0=True,
164 | crepe_saved_model_path='full',
165 | viterbi=False,
166 | **kwargs):
167 | super().__init__(**kwargs)
168 | # Preprocessing must happen at 16kHz because CREPE was trained at 16kHz.
169 | self.sample_rate = ddsp.spectral_ops.CREPE_SAMPLE_RATE
170 | self.frame_rate = frame_rate
171 | self.frame_size = frame_size
172 | self.hop_size = self.sample_rate // frame_rate
173 |
174 | self.compute_f0 = compute_f0
175 | self.compute_power = compute_power
176 |
177 | self.padding = padding
178 |
179 | # Crepe model, must either be a model size or path to SavedModel.
180 | if crepe_saved_model_path:
181 | self.crepe_model = ddsp.spectral_ops.PretrainedCREPE(
182 | model_size_or_path=crepe_saved_model_path, hop_size=self.hop_size
183 | )
184 |
185 | # Use viterbi decoding.
186 | self.viterbi = viterbi
187 |
188 | def call(
189 | self, audio, f0_hz=None, f0_confidence=None, audio_16k=None, pw_db=None
190 | ) -> ['f0_hz', 'pw_db', 'f0_scaled', 'pw_scaled', 'f0_confidence']:
191 | """Compute power on the fly if it's not in the inputs."""
192 | # Compute features at 16kHz (needed for CREPE).
193 | if audio_16k is not None:
194 | audio = audio_16k
195 |
196 | # Compute power and f0 on the fly.
197 | if self.compute_power:
198 | pw_db = ddsp.spectral_ops.compute_power(audio,
199 | sample_rate=self.sample_rate,
200 | frame_rate=self.frame_rate,
201 | frame_size=self.frame_size,
202 | padding=self.padding)
203 |
204 | if self.compute_f0:
205 | f0_hz, f0_confidence = self.crepe_model.predict_f0_and_confidence(
206 | audio, viterbi=self.viterbi, padding=self.padding)
207 | # Stop gradients from flowing to CREPE.
208 | f0_hz = tf.stop_gradient(f0_hz)
209 | f0_confidence = tf.stop_gradient(f0_confidence)
210 | elif f0_hz is None or f0_confidence is None:
211 | raise ValueError('Preprocessor must either have `compute_f0=True`, or '
212 | '__call__ must be supplied 3 arguments, '
213 | '[audio, f0_hz, and f0_confidence].')
214 |
215 | # For NN training, scale frequency and loudness to the range [0, 1].
216 | pw_db = at_least_3d(pw_db)
217 | f0_hz = at_least_3d(f0_hz)
218 |
219 | pw_scaled = scale_db(pw_db)
220 | f0_scaled = scale_f0_hz(f0_hz)
221 |
222 | # Sanity check:
223 | # time_steps must be computed consistently with the framing above so that
224 | # the model's expected output shape is known exactly (no interpolation).
225 | n_t = audio.shape[1]
226 | time_steps, _ = ddsp.spectral_ops.get_framed_lengths(
227 | n_t, self.frame_size, self.hop_size, self.padding)
228 | for k, output in {
229 | 'f0_hz': f0_hz,
230 | 'pw_db': pw_db,
231 | 'f0_scaled': f0_scaled,
232 | 'pw_scaled': pw_scaled,
233 | 'f0_confidence': f0_confidence}.items():
234 | if output.shape[1] != time_steps:
235 | raise ValueError(
236 | f'OnlineF0PowerPreprocessor output: ({k}) does not have '
237 | f'{time_steps} timesteps. Output shape: {output.shape}. '
238 | f'\nInputs: seconds ({n_t/self.sample_rate}), '
239 | f'frame_rate ({self.frame_rate}), '
240 | f'padding ("{self.padding}").')
241 |
242 | return f0_hz, pw_db, f0_scaled, pw_scaled, f0_confidence
243 |
244 |
245 |
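246 | # A round-trip sanity check for the scaling helpers above (illustrative;
247 | # assumes F0_RANGE = 127 from ddsp.spectral_ops):
248 | #
249 | #   f0_hz = tf.constant([[[440.0]]])
250 | #   f0_scaled = scale_f0_hz(f0_hz)        # hz_to_midi(440) / 127 ~= 0.543
251 | #   f0_back = inv_scale_f0_hz(f0_scaled)  # ~= 440.0 Hz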
--------------------------------------------------------------------------------
/ddsp/training/preprocessing_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for ddsp.training.preprocessing."""
16 |
17 | from absl.testing import parameterized
18 | from ddsp.core import resample
19 | from ddsp.spectral_ops import compute_power
20 | from ddsp.training import preprocessing
21 | import tensorflow as tf
22 |
23 | tfkl = tf.keras.layers
24 |
25 |
26 | class F0PowerPreprocessorTest(parameterized.TestCase, tf.test.TestCase):
27 |
28 | def setUp(self):
29 | """Create input dictionary and preprocessor."""
30 | super().setUp()
31 | sr = 16000
32 | frame_rate = 250
33 | frame_size = 256
34 | n_samples = 16000
35 | n_t = 250
36 | # Replicate preprocessor computations.
37 | audio = 0.5 * tf.sin(tf.range(0, n_samples, dtype=tf.float32))[None, :]
38 | power_db = compute_power(audio,
39 | sample_rate=sr,
40 | frame_rate=frame_rate,
41 | frame_size=frame_size)
42 | power_db = preprocessing.at_least_3d(power_db)
43 | power_db = resample(power_db, n_t)
44 | self.input_dict = {
45 | 'f0_hz': tf.ones([1, n_t]),
46 | 'audio': audio,
47 | 'power_db': power_db,
48 | }
49 | self.preprocessor = preprocessing.F0PowerPreprocessor(
50 | time_steps=n_t,
51 | frame_rate=frame_rate,
52 | sample_rate=sr)
53 |
54 | @parameterized.named_parameters(
55 | ('audio_only', ['audio']),
56 | ('power_only', ['power_db']),
57 | ('audio_and_power', ['audio', 'power_db']),
58 | )
59 | def test_input_key_combinations(self, input_keys):
60 | input_keys += ['f0_hz']
61 | inputs = {k: v for k, v in self.input_dict.items() if k in input_keys}
62 | outputs = self.preprocessor(inputs)
63 | self.assertAllClose(self.input_dict['power_db'],
64 | outputs['pw_db'],
65 | rtol=0.5,
66 | atol=30)
67 |
68 | if __name__ == '__main__':
69 | tf.test.main()
70 |
--------------------------------------------------------------------------------
/ddsp/training/trainers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library of Trainer objects that define the training step and wrap an optimizer."""
16 |
17 | import time
18 |
19 | from absl import logging
20 | from ddsp.training import train_util
21 | import gin
22 | import tensorflow.compat.v2 as tf
23 |
24 |
25 | @gin.configurable
26 | class Trainer(object):
27 | """Class to bind an optimizer, model, strategy, and training step function."""
28 |
29 | def __init__(self,
30 | model,
31 | strategy,
32 | checkpoints_to_keep=100,
33 | learning_rate=0.001,
34 | lr_decay_steps=10000,
35 | lr_decay_rate=0.98,
36 | grad_clip_norm=3.0,
37 | restore_keys=None):
38 | """Constructor.
39 |
40 | Args:
41 | model: Model to train.
42 | strategy: A distribution strategy.
43 | checkpoints_to_keep: Max number of checkpoints before deleting oldest.
44 | learning_rate: Scalar initial learning rate.
45 | lr_decay_steps: Exponential decay timescale.
46 | lr_decay_rate: Exponential decay magnitude.
47 | grad_clip_norm: Norm level by which to clip gradients.
48 | restore_keys: List of names of model properties to restore. If no keys are
49 | passed, restore the whole model.
50 | """
51 | self.model = model
52 | self.strategy = strategy
53 | self.checkpoints_to_keep = checkpoints_to_keep
54 | self.grad_clip_norm = grad_clip_norm
55 | self.restore_keys = restore_keys
56 |
57 | # Create an optimizer.
58 | lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
59 | initial_learning_rate=learning_rate,
60 | decay_steps=lr_decay_steps,
61 | decay_rate=lr_decay_rate)
62 |
63 | with self.strategy.scope():
64 | self.optimizer = tf.keras.optimizers.Adam(lr_schedule)
65 |
66 | def get_checkpoint(self, model=None):
67 | """Model arg can also be a tf.train.Checkpoint(**dict(submodules))."""
68 | model = model or self.model # Default to full model.
69 | return tf.train.Checkpoint(model=model, optimizer=self.optimizer)
70 |
71 | def save(self, save_dir):
72 | """Saves model and optimizer to a checkpoint."""
73 | # Save weights in checkpoint format because saved_model requires handling
74 | # variable batch sizes, which some synths and effects can't do.
75 | start_time = time.time()
76 | checkpoint = self.get_checkpoint()
77 | manager = tf.train.CheckpointManager(
78 | checkpoint, directory=save_dir, max_to_keep=self.checkpoints_to_keep)
79 | step = self.step.numpy()
80 | manager.save(checkpoint_number=step)
81 | logging.info('Saved checkpoint to %s at step %s', save_dir, step)
82 | logging.info('Saving model took %.1f seconds', time.time() - start_time)
83 |
84 | def restore(self, checkpoint_path, restore_keys=None):
85 | """Restore model and optimizer from a checkpoint if it exists.
86 |
87 | Args:
88 | checkpoint_path: Path to checkpoint file or directory.
89 | restore_keys: Optional list of strings for submodules to restore.
90 |
91 | Raises:
92 | FileNotFoundError: If no checkpoint is found.
93 | """
94 | logging.info('Restoring from checkpoint...')
95 | start_time = time.time()
96 |
97 | # Prefer function args over object properties.
98 | restore_keys = restore_keys or self.restore_keys
99 | if restore_keys is None:
100 | # If no keys are passed, restore the whole model.
101 | model = self.model
102 | logging.info('Trainer restoring the full model')
103 | else:
104 | # Restore only sub-modules by building a new subgraph.
105 | restore_dict = {k: getattr(self.model, k) for k in restore_keys}
106 | model = tf.train.Checkpoint(**restore_dict)
107 |
108 | logging.info('Trainer restoring model subcomponents:')
109 | for k, v in restore_dict.items():
110 | log_str = 'Restoring {}: {}'.format(k, v)
111 | logging.info(log_str)
112 |
113 | # Restore from latest checkpoint.
114 | checkpoint = self.get_checkpoint(model)
115 | latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path)
116 | # checkpoint.restore must be within a strategy.scope() so that optimizer
117 | # slot variables are mirrored.
118 | with self.strategy.scope():
119 | if restore_keys is None:
120 | checkpoint.restore(latest_checkpoint)
121 | else:
122 | checkpoint.restore(latest_checkpoint).expect_partial()
123 | logging.info('Loaded checkpoint %s', latest_checkpoint)
124 | logging.info('Loading model took %.1f seconds', time.time() - start_time)
125 |
126 | @property
127 | def step(self):
128 | """The number of training steps completed."""
129 | return self.optimizer.iterations
130 |
131 | def psum(self, x, axis=None):
132 | """Sum across processors."""
133 | return self.strategy.reduce(tf.distribute.ReduceOp.SUM, x, axis=axis)
134 |
135 | def run(self, fn, *args, **kwargs):
136 | """Distribute and run function on processors."""
137 | return self.strategy.run(fn, args=args, kwargs=kwargs)
138 |
139 | def build(self, batch):
140 | """Build the model by running a distributed batch through it."""
141 | logging.info('Building the model...')
142 | _ = self.run(tf.function(self.model.__call__), batch)
143 | self.model.summary()
144 |
145 | def distribute_dataset(self, dataset):
146 | """Create a distributed dataset."""
147 | if isinstance(dataset, tf.data.Dataset):
148 | return self.strategy.experimental_distribute_dataset(dataset)
149 | else:
150 | return dataset
151 |
152 | @tf.function
153 | def train_step(self, inputs):
154 | """Distributed training step."""
155 | # Wrap iterator in tf.function, slight speedup passing in iter vs batch.
156 | batch = next(inputs) if hasattr(inputs, '__next__') else inputs
157 | losses = self.run(self.step_fn, batch)
158 | # Add up the scalar losses across replicas.
159 | n_replicas = self.strategy.num_replicas_in_sync
160 | return {k: self.psum(v, axis=None) / n_replicas for k, v in losses.items()}
161 |
162 | @tf.function
163 | def step_fn(self, batch):
164 | """Per-Replica training step."""
165 | with tf.GradientTape() as tape:
166 | _, losses = self.model(batch, return_losses=True, training=True)
167 | # Clip and apply gradients.
168 | grads = tape.gradient(losses['total_loss'], self.model.trainable_variables)
169 | grads, _ = tf.clip_by_global_norm(grads, self.grad_clip_norm)
170 | self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
171 | return losses
172 |
173 |
174 | @gin.configurable
175 | def get_trainer_class(trainer_class=Trainer):
176 | """Gin configurable function to get a 'global' trainer for use in ddsp_run.py.
177 |
178 | Args:
179 | trainer_class: A trainer class such as `Trainer`.
180 |
181 | Returns:
182 | The 'global' trainer class specified in the gin config.
183 | """
184 | return trainer_class
185 |
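186 | # A minimal training-loop sketch (illustrative; `data_provider`, `num_steps`,
187 | # and the `models` import are placeholders; see train_util.train for the full
188 | # loop used by ddsp_run.py):
189 | #
190 | #   strategy = train_util.get_strategy()
191 | #   with strategy.scope():
192 | #     model = models.get_model()
193 | #     trainer = Trainer(model, strategy)
194 | #   dataset = trainer.distribute_dataset(data_provider.get_batch(batch_size=16))
195 | #   dataset_iter = iter(dataset)
196 | #   trainer.build(next(dataset_iter))
197 | #   for _ in range(num_steps):
198 | #     losses = trainer.train_step(dataset_iter)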
--------------------------------------------------------------------------------
/ddsp/version.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Separate file for storing the current version of DDSP.
16 |
17 | Stored in a separate file so that setup.py can reference the version without
18 | pulling in all the dependencies in __init__.py.
19 | """
20 |
21 | __version__ = '3.7.0'
22 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Install ddsp."""
16 |
17 | import os
18 | import sys
19 | import setuptools
20 |
21 | # To enable importing version.py directly, we add its path to sys.path.
22 | version_path = os.path.join(os.path.dirname(__file__), 'ddsp')
23 | sys.path.append(version_path)
24 | from version import __version__ # pylint: disable=g-import-not-at-top
25 |
26 | setuptools.setup(
27 | name='ddsp',
28 | version=__version__,
29 | description='Differentiable Digital Signal Processing',
30 | author='Google Inc.',
31 | author_email='no-reply@google.com',
32 | url='http://github.com/magenta/ddsp',
33 | license='Apache 2.0',
34 | packages=setuptools.find_packages(),
35 | package_data={
36 | '': ['*.gin'],
37 | },
38 | scripts=[],
39 | install_requires=[
40 | 'absl-py',
41 | 'apache-beam',
42 | 'cloudml-hypertune<=0.1.0.dev6',
43 | 'crepe<=0.0.12',
44 | 'dill<=0.3.4',
45 | 'future',
46 | 'gin-config>=0.3.0',
47 | 'google-cloud-storage',
48 | 'hmmlearn<=0.2.7',
49 | 'librosa<=0.10',
50 | 'pydub<=0.25.1',
51 | 'protobuf<=3.20', # temporary fix for proto dependency bug
52 | 'mir_eval<=0.7',
53 | 'note_seq<0.0.4',
54 | 'numpy<1.24',
55 | 'scipy<=1.10.1',
56 | 'six',
57 | 'tensorflow<=2.11',
58 | 'tensorflowjs<3.19',
59 | 'tensorflow-probability<=0.19',
60 | 'tensorflow-datasets<=4.9',
61 | 'tflite_support<=0.1'
62 | ],
63 | extras_require={
64 | 'gcp': [
65 | 'gevent', 'google-api-python-client', 'google-compute-engine',
66 | 'oauth2client'
67 | ],
68 | 'data_preparation': [
69 | 'apache_beam',
70 | # TODO(jesseengel): Remove versioning when beam import is fixed.
71 | 'pyparsing<=2.4.7'
72 | ],
73 | 'test': ['pytest', 'pylint!=2.5.0'],
74 | },
75 | # pylint: disable=line-too-long
76 | entry_points={
77 | 'console_scripts': [
78 | 'ddsp_export = ddsp.training.ddsp_export:console_entry_point',
79 | 'ddsp_run = ddsp.training.ddsp_run:console_entry_point',
80 | 'ddsp_prepare_tfrecord = ddsp.training.data_preparation.ddsp_prepare_tfrecord:console_entry_point',
81 | 'ddsp_generate_synthetic_dataset = ddsp.training.data_preparation.ddsp_generate_synthetic_dataset:console_entry_point',
82 | 'ddsp_ai_platform = ddsp.training.docker.ddsp_ai_platform:console_entry_point',
83 | ],
84 | },
85 | # pylint: enable=line-too-long
86 | classifiers=[
87 | 'Development Status :: 4 - Beta',
88 | 'Intended Audience :: Developers',
89 | 'Intended Audience :: Science/Research',
90 | 'License :: OSI Approved :: Apache Software License',
91 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
92 | ],
93 | tests_require=['pytest'],
94 | setup_requires=['pytest-runner'],
95 | keywords='audio dsp signalprocessing machinelearning music',
96 | )
97 |
--------------------------------------------------------------------------------
/update_gin_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The DDSP Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Script to update old '.gin' config files to work with current codebase.
16 |
17 | When there are changes in a function signature (adding or removing kwargs),
18 | those changes need to be reflected in gin config files. There is no way to
19 | automatically do this from within gin.
20 |
21 | This script maintains some backwards compatibility by keeping track of changes
22 | to function signatures and automatically updating gin config files to
23 | incorporate those changes.
24 |
25 | In particular, this is useful for operative_config-*.gin files that are created
26 | during training, to allow loading of old models.
27 |
28 | This script will update any target gin config files, and leave copies of the old
29 | files for safe-keeping.
30 |
31 | Usage:
32 | ======
33 | python update_gin_config.py path/to/config/operative_config-*.gin
34 | """
35 |
36 | import os
37 | import re
38 |
39 | from absl import app
40 | from absl import flags
41 | from tensorflow import io
42 |
43 | flags.DEFINE_string('path', '/tmp/operative_config-*.gin', 'Path to gin '
44 | 'config files. May be a path to a file or a glob '
45 | 'expression.')
46 | flags.DEFINE_boolean('overwrite', False, 'Overwrite original files.')
47 |
48 |
49 | FLAGS = flags.FLAGS
50 | gfile = io.gfile
51 |
52 |
53 | # ==============================================================================
54 | # Updates to perform
55 | # ==============================================================================
56 | # Remove lines with any of these strings in them.
57 | REMOVE = [
58 | 'SpectralLoss.delta_delta_freq_weight',
59 | 'SpectralLoss.delta_delta_time_weight',
60 | 'DilatedConvEncoder.resample',
61 | 'DilatedConvDecoder.resample',
62 | ]
63 |
64 |
65 | # Perform these line-by-line substitutions, (Regex, Replacement String).
66 | SUBSTITUTE = [
67 | ('ZRnnFcDecoder', 'RnnFcDecoder'),
68 | ]
69 |
70 |
71 | # Add the following kwargs, (GinConfigurable, kwarg, value).
72 | # -> 'GinConfigurable.kwarg = value'
73 | ADD = [
74 | ('RnnFcDecoder', 'input_keys', '("f0_scaled", "ld_scaled")'),
75 | ]
76 |
77 |
78 | # ==============================================================================
79 | # Main Program
80 | # ==============================================================================
81 | def add_kwarg(lines, gin_configurable, kwarg, value):
82 | """Add 'GinConfigurable.kwarg = value' to lines, if not already present."""
83 | gin_kwarg = gin_configurable + '.' + kwarg
84 | new_line = gin_kwarg + ' = ' + value + '\n'
85 | configurable_present = any([gin_configurable in line for line in lines])
86 | kwarg_present = any([gin_kwarg in line for line in lines])
87 | if configurable_present and not kwarg_present:
88 | # Add to the bottom of the config.
89 | lines.append('\n' + new_line)
90 | print(f'Added: {new_line.rstrip()}')
91 | elif configurable_present and kwarg_present:
92 | print(f'Skipped Add: {new_line.rstrip()}, {gin_kwarg} already present.')
93 | else:
94 | print(f'Skipped Add: {new_line.rstrip()}, {gin_configurable} not present.')
95 |
96 |
97 | def main(argv):
98 | # Parse input args.
99 | if len(argv) > 2:
100 | raise app.UsageError('Too many command-line arguments.')
101 | elif len(argv) == 2:
102 | path = argv[1]
103 | else:
104 | path = FLAGS.path
105 |
106 | # Get a list of files that match the pattern.
107 | files = gfile.glob(path)
108 |
109 | for fpath in files:
110 | # Create a new file path.
111 | dirname, filename = os.path.split(fpath)
112 | new_filename = filename if FLAGS.overwrite else 'updated_' + filename
113 | new_fpath = os.path.join(dirname, new_filename)
114 | print(f'\nUpdating: \n{fpath} -> \n{new_fpath}')
115 | print('================')
116 |
117 | # Read old config.
118 | with gfile.GFile(fpath, 'r') as f:
119 | lines = f.readlines()
120 |
121 | # Make new config.
122 | new_lines = []
123 | for line in lines:
124 | # Remove lines with old arguments.
125 | if any([tag in line for tag in REMOVE]):
126 | print(f'Removed: {line.rstrip()}')
127 | continue
128 |
129 | # Substitute.
130 | for regex, sub in SUBSTITUTE:
131 | old_line = line
132 | line, n = re.subn(regex, sub, line)
133 | if n:
134 | print(f'Swapped: {old_line.rstrip()} -> {line.rstrip()}')
135 |
136 | # Append the new line.
137 | new_lines.append(line)
138 |
139 | # Add new lines after substitutions.
140 | for gin_configurable, kwarg, value in ADD:
141 | add_kwarg(new_lines, gin_configurable, kwarg, value)
142 |
143 | # Delete target file if it exists.
144 | if gfile.exists(new_fpath):
145 | gfile.remove(new_fpath)
146 |
147 | # Write to a new file.
148 | with gfile.GFile(new_fpath, 'w') as f:
149 | _ = f.write(''.join(new_lines))
150 |
151 |
152 | if __name__ == '__main__':
153 | app.run(main)
154 |
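155 | # Example effect on a config line (illustrative kwarg name):
156 | #
157 | #   'ZRnnFcDecoder.rnn_channels = 512'  ->  'RnnFcDecoder.rnn_channels = 512'
158 | #
159 | # and since 'RnnFcDecoder' then appears in the config, ADD appends:
160 | #
161 | #   RnnFcDecoder.input_keys = ("f0_scaled", "ld_scaled")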
--------------------------------------------------------------------------------
/update_pip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Usage:
4 | # sh third_party/py/ddsp/update_pip.sh
5 | set -ex
6 |
7 | orig_dir=$(pwd)
8 | tmp_dir=$(mktemp -d -t ddsp-XXXX)
9 | git clone https://github.com/magenta/ddsp.git $tmp_dir
10 | cd $tmp_dir
11 |
12 | python setup.py sdist
13 | python setup.py bdist_wheel --universal
14 | twine upload dist/*
15 |
16 | cd $orig_dir
17 | rm -rf $tmp_dir
18 |
--------------------------------------------------------------------------------